ast_enter.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package ast
  2. import (
  3. "bytes"
  4. "go/ast"
  5. "go/format"
  6. "go/parser"
  7. "go/token"
  8. "golang.org/x/text/cases"
  9. "golang.org/x/text/language"
  10. "log"
  11. "os"
  12. "strconv"
  13. "strings"
  14. )
  // Visitor is the ast.Visitor that ImportReference walks over a parsed Go
  // file. Each field switches on one edit when it is non-empty:
  //
  //   - ImportCode: unquoted import path added to the file's import
  //     declaration.
  //   - StructName, PackageName, GroupName: field name, package qualifier and
  //     type name of the field appended to struct types whose name contains
  //     "Group" (all three must be set).
  //   - PackageName: also prefixes the "<PackageName>Router" variable added to
  //     the function named Routers.
  type Visitor struct {
  	ImportCode                         string
  	StructName, PackageName, GroupName string
  }
  21. func (vi *Visitor) Visit(node ast.Node) ast.Visitor {
  22. switch n := node.(type) {
  23. case *ast.GenDecl:
  24. // 查找有没有import context包
  25. // Notice:没有考虑没有import任何包的情况
  26. if n.Tok == token.IMPORT && vi.ImportCode != "" {
  27. vi.addImport(n)
  28. // 不需要再遍历子树
  29. return nil
  30. }
  31. if n.Tok == token.TYPE && vi.StructName != "" && vi.PackageName != "" && vi.GroupName != "" {
  32. vi.addStruct(n)
  33. return nil
  34. }
  35. case *ast.FuncDecl:
  36. if n.Name.Name == "Routers" {
  37. vi.addFuncBodyVar(n)
  38. return nil
  39. }
  40. }
  41. return vi
  42. }
  43. func (vi *Visitor) addStruct(genDecl *ast.GenDecl) ast.Visitor {
  44. for i := range genDecl.Specs {
  45. switch n := genDecl.Specs[i].(type) {
  46. case *ast.TypeSpec:
  47. if strings.Index(n.Name.Name, "Group") > -1 {
  48. switch t := n.Type.(type) {
  49. case *ast.StructType:
  50. f := &ast.Field{
  51. Names: []*ast.Ident{
  52. {
  53. Name: vi.StructName,
  54. Obj: &ast.Object{
  55. Kind: ast.Var,
  56. Name: vi.StructName,
  57. },
  58. },
  59. },
  60. Type: &ast.SelectorExpr{
  61. X: &ast.Ident{
  62. Name: vi.PackageName,
  63. },
  64. Sel: &ast.Ident{
  65. Name: vi.GroupName,
  66. },
  67. },
  68. }
  69. t.Fields.List = append(t.Fields.List, f)
  70. }
  71. }
  72. }
  73. }
  74. return vi
  75. }
  76. func (vi *Visitor) addImport(genDecl *ast.GenDecl) ast.Visitor {
  77. // 是否已经import
  78. hasImported := false
  79. for _, v := range genDecl.Specs {
  80. importSpec := v.(*ast.ImportSpec)
  81. // 如果已经包含
  82. if importSpec.Path.Value == strconv.Quote(vi.ImportCode) {
  83. hasImported = true
  84. }
  85. }
  86. if !hasImported {
  87. genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
  88. Path: &ast.BasicLit{
  89. Kind: token.STRING,
  90. Value: strconv.Quote(vi.ImportCode),
  91. },
  92. })
  93. }
  94. return vi
  95. }
  96. func (vi *Visitor) addFuncBodyVar(funDecl *ast.FuncDecl) ast.Visitor {
  97. hasVar := false
  98. for _, v := range funDecl.Body.List {
  99. switch varSpec := v.(type) {
  100. case *ast.AssignStmt:
  101. for i := range varSpec.Lhs {
  102. switch nn := varSpec.Lhs[i].(type) {
  103. case *ast.Ident:
  104. if nn.Name == vi.PackageName+"Router" {
  105. hasVar = true
  106. }
  107. }
  108. }
  109. }
  110. }
  111. if !hasVar {
  112. assignStmt := &ast.AssignStmt{
  113. Lhs: []ast.Expr{
  114. &ast.Ident{
  115. Name: vi.PackageName + "Router",
  116. Obj: &ast.Object{
  117. Kind: ast.Var,
  118. Name: vi.PackageName + "Router",
  119. },
  120. },
  121. },
  122. Tok: token.DEFINE,
  123. Rhs: []ast.Expr{
  124. &ast.SelectorExpr{
  125. X: &ast.SelectorExpr{
  126. X: &ast.Ident{
  127. Name: "router",
  128. },
  129. Sel: &ast.Ident{
  130. Name: "RouterGroupApp",
  131. },
  132. },
  133. Sel: &ast.Ident{
  134. Name: cases.Title(language.English).String(vi.PackageName),
  135. },
  136. },
  137. },
  138. }
  139. funDecl.Body.List = append(funDecl.Body.List, funDecl.Body.List[1])
  140. index := 1
  141. copy(funDecl.Body.List[index+1:], funDecl.Body.List[index:])
  142. funDecl.Body.List[index] = assignStmt
  143. }
  144. return vi
  145. }
  146. func ImportReference(filepath, importCode, structName, packageName, groupName string) error {
  147. fSet := token.NewFileSet()
  148. fParser, err := parser.ParseFile(fSet, filepath, nil, parser.ParseComments)
  149. if err != nil {
  150. return err
  151. }
  152. importCode = strings.TrimSpace(importCode)
  153. v := &Visitor{
  154. ImportCode: importCode,
  155. StructName: structName,
  156. PackageName: packageName,
  157. GroupName: groupName,
  158. }
  159. if importCode == "" {
  160. ast.Print(fSet, fParser)
  161. }
  162. ast.Walk(v, fParser)
  163. var output []byte
  164. buffer := bytes.NewBuffer(output)
  165. err = format.Node(buffer, fSet, fParser)
  166. if err != nil {
  167. log.Fatal(err)
  168. }
  169. // 写回数据
  170. return os.WriteFile(filepath, buffer.Bytes(), 0o600)
  171. }