ast_rollback.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package ast
  2. import (
  3. "bytes"
  4. "fmt"
  5. "go/ast"
  6. "go/parser"
  7. "go/printer"
  8. "go/token"
  9. "os"
  10. "path/filepath"
  11. "server/global"
  12. )
  13. func RollBackAst(pk, model string) {
  14. RollGormBack(pk, model)
  15. RollRouterBack(pk, model)
  16. }
  17. func RollGormBack(pk, model string) {
  18. // 首先分析存在多少个ttt作为调用方的node块
  19. // 如果多个 仅仅删除对应块即可
  20. // 如果单个 那么还需要剔除import
  21. path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm.go")
  22. src, err := os.ReadFile(path)
  23. if err != nil {
  24. fmt.Println(err)
  25. }
  26. fileSet := token.NewFileSet()
  27. astFile, err := parser.ParseFile(fileSet, "", src, 0)
  28. if err != nil {
  29. fmt.Println(err)
  30. }
  31. var n *ast.CallExpr
  32. var k int = -1
  33. var pkNum = 0
  34. ast.Inspect(astFile, func(node ast.Node) bool {
  35. if node, ok := node.(*ast.CallExpr); ok {
  36. for i := range node.Args {
  37. pkOK := false
  38. modelOK := false
  39. ast.Inspect(node.Args[i], func(item ast.Node) bool {
  40. if ii, ok := item.(*ast.Ident); ok {
  41. if ii.Name == pk {
  42. pkOK = true
  43. pkNum++
  44. }
  45. if ii.Name == model {
  46. modelOK = true
  47. }
  48. }
  49. if pkOK && modelOK {
  50. n = node
  51. k = i
  52. }
  53. return true
  54. })
  55. }
  56. }
  57. return true
  58. })
  59. if k > 0 {
  60. n.Args = append(append([]ast.Expr{}, n.Args[:k]...), n.Args[k+1:]...)
  61. }
  62. if pkNum == 1 {
  63. var imI int = -1
  64. var gp *ast.GenDecl
  65. ast.Inspect(astFile, func(node ast.Node) bool {
  66. if gen, ok := node.(*ast.GenDecl); ok {
  67. for i := range gen.Specs {
  68. if imspec, ok := gen.Specs[i].(*ast.ImportSpec); ok {
  69. if imspec.Path.Value == "\"server/model/"+pk+"\"" {
  70. gp = gen
  71. imI = i
  72. return false
  73. }
  74. }
  75. }
  76. }
  77. return true
  78. })
  79. if imI > -1 {
  80. gp.Specs = append(append([]ast.Spec{}, gp.Specs[:imI]...), gp.Specs[imI+1:]...)
  81. }
  82. }
  83. var out []byte
  84. bf := bytes.NewBuffer(out)
  85. printer.Fprint(bf, fileSet, astFile)
  86. os.Remove(path)
  87. os.WriteFile(path, bf.Bytes(), 0666)
  88. }
  89. func RollRouterBack(pk, model string) {
  90. // 首先抓到所有的代码块结构 {}
  91. // 分析结构中是否存在一个变量叫做 pk+Router
  92. // 然后获取到代码块指针 对内部需要回滚的代码进行剔除
  93. path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router.go")
  94. src, err := os.ReadFile(path)
  95. if err != nil {
  96. fmt.Println(err)
  97. }
  98. fileSet := token.NewFileSet()
  99. astFile, err := parser.ParseFile(fileSet, "", src, 0)
  100. if err != nil {
  101. fmt.Println(err)
  102. }
  103. var block *ast.BlockStmt
  104. var routerStmt *ast.FuncDecl
  105. ast.Inspect(astFile, func(node ast.Node) bool {
  106. if n, ok := node.(*ast.FuncDecl); ok {
  107. if n.Name.Name == "Routers" {
  108. routerStmt = n
  109. }
  110. }
  111. if n, ok := node.(*ast.BlockStmt); ok {
  112. ast.Inspect(n, func(bNode ast.Node) bool {
  113. if in, ok := bNode.(*ast.Ident); ok {
  114. if in.Name == pk+"Router" {
  115. block = n
  116. return false
  117. }
  118. }
  119. return true
  120. })
  121. return true
  122. }
  123. return true
  124. })
  125. var k int
  126. for i := range block.List {
  127. if stmtNode, ok := block.List[i].(*ast.ExprStmt); ok {
  128. ast.Inspect(stmtNode, func(node ast.Node) bool {
  129. if n, ok := node.(*ast.Ident); ok {
  130. if n.Name == "Init"+model+"Router" {
  131. k = i
  132. return false
  133. }
  134. }
  135. return true
  136. })
  137. }
  138. }
  139. block.List = append(append([]ast.Stmt{}, block.List[:k]...), block.List[k+1:]...)
  140. if len(block.List) == 1 {
  141. // 说明这个块就没任何意义了
  142. block.List = nil
  143. }
  144. for i, n := range routerStmt.Body.List {
  145. if n, ok := n.(*ast.BlockStmt); ok {
  146. if n.List == nil {
  147. routerStmt.Body.List = append(append([]ast.Stmt{}, routerStmt.Body.List[:i]...), routerStmt.Body.List[i+1:]...)
  148. i--
  149. }
  150. }
  151. }
  152. var out []byte
  153. bf := bytes.NewBuffer(out)
  154. printer.Fprint(bf, fileSet, astFile)
  155. os.Remove(path)
  156. os.WriteFile(path, bf.Bytes(), 0666)
  157. }