123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- package ast
- import (
- "bytes"
- "go/ast"
- "go/format"
- "go/parser"
- "go/token"
- "golang.org/x/text/cases"
- "golang.org/x/text/language"
- "log"
- "os"
- "strconv"
- "strings"
- )
// Visitor is the ast.Visitor that ImportReference walks over a parsed Go
// file. Its fields describe the edits to apply; Visit decides, per
// declaration, which of them to run.
type Visitor struct {
	// ImportCode is the import path to add to import declarations that do
	// not already list it. Visit skips import editing when it is empty.
	ImportCode string
	// StructName is the name of the field appended to "Group" structs.
	StructName string
	// PackageName qualifies the appended field's type, and is also the
	// prefix of the "<PackageName>Router" variable added to Routers.
	PackageName string
	// GroupName is the type name selected from PackageName for the
	// appended field. Struct editing only runs when StructName,
	// PackageName and GroupName are all non-empty.
	GroupName string
}
- func (vi *Visitor) Visit(node ast.Node) ast.Visitor {
- switch n := node.(type) {
- case *ast.GenDecl:
- // 查找有没有import context包
- // Notice:没有考虑没有import任何包的情况
- if n.Tok == token.IMPORT && vi.ImportCode != "" {
- vi.addImport(n)
- // 不需要再遍历子树
- return nil
- }
- if n.Tok == token.TYPE && vi.StructName != "" && vi.PackageName != "" && vi.GroupName != "" {
- vi.addStruct(n)
- return nil
- }
- case *ast.FuncDecl:
- if n.Name.Name == "Routers" {
- vi.addFuncBodyVar(n)
- return nil
- }
- }
- return vi
- }
- func (vi *Visitor) addStruct(genDecl *ast.GenDecl) ast.Visitor {
- for i := range genDecl.Specs {
- switch n := genDecl.Specs[i].(type) {
- case *ast.TypeSpec:
- if strings.Index(n.Name.Name, "Group") > -1 {
- switch t := n.Type.(type) {
- case *ast.StructType:
- f := &ast.Field{
- Names: []*ast.Ident{
- {
- Name: vi.StructName,
- Obj: &ast.Object{
- Kind: ast.Var,
- Name: vi.StructName,
- },
- },
- },
- Type: &ast.SelectorExpr{
- X: &ast.Ident{
- Name: vi.PackageName,
- },
- Sel: &ast.Ident{
- Name: vi.GroupName,
- },
- },
- }
- t.Fields.List = append(t.Fields.List, f)
- }
- }
- }
- }
- return vi
- }
- func (vi *Visitor) addImport(genDecl *ast.GenDecl) ast.Visitor {
- // 是否已经import
- hasImported := false
- for _, v := range genDecl.Specs {
- importSpec := v.(*ast.ImportSpec)
- // 如果已经包含
- if importSpec.Path.Value == strconv.Quote(vi.ImportCode) {
- hasImported = true
- }
- }
- if !hasImported {
- genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
- Path: &ast.BasicLit{
- Kind: token.STRING,
- Value: strconv.Quote(vi.ImportCode),
- },
- })
- }
- return vi
- }
- func (vi *Visitor) addFuncBodyVar(funDecl *ast.FuncDecl) ast.Visitor {
- hasVar := false
- for _, v := range funDecl.Body.List {
- switch varSpec := v.(type) {
- case *ast.AssignStmt:
- for i := range varSpec.Lhs {
- switch nn := varSpec.Lhs[i].(type) {
- case *ast.Ident:
- if nn.Name == vi.PackageName+"Router" {
- hasVar = true
- }
- }
- }
- }
- }
- if !hasVar {
- assignStmt := &ast.AssignStmt{
- Lhs: []ast.Expr{
- &ast.Ident{
- Name: vi.PackageName + "Router",
- Obj: &ast.Object{
- Kind: ast.Var,
- Name: vi.PackageName + "Router",
- },
- },
- },
- Tok: token.DEFINE,
- Rhs: []ast.Expr{
- &ast.SelectorExpr{
- X: &ast.SelectorExpr{
- X: &ast.Ident{
- Name: "router",
- },
- Sel: &ast.Ident{
- Name: "RouterGroupApp",
- },
- },
- Sel: &ast.Ident{
- Name: cases.Title(language.English).String(vi.PackageName),
- },
- },
- },
- }
- funDecl.Body.List = append(funDecl.Body.List, funDecl.Body.List[1])
- index := 1
- copy(funDecl.Body.List[index+1:], funDecl.Body.List[index:])
- funDecl.Body.List[index] = assignStmt
- }
- return vi
- }
- func ImportReference(filepath, importCode, structName, packageName, groupName string) error {
- fSet := token.NewFileSet()
- fParser, err := parser.ParseFile(fSet, filepath, nil, parser.ParseComments)
- if err != nil {
- return err
- }
- importCode = strings.TrimSpace(importCode)
- v := &Visitor{
- ImportCode: importCode,
- StructName: structName,
- PackageName: packageName,
- GroupName: groupName,
- }
- if importCode == "" {
- ast.Print(fSet, fParser)
- }
- ast.Walk(v, fParser)
- var output []byte
- buffer := bytes.NewBuffer(output)
- err = format.Node(buffer, fSet, fParser)
- if err != nil {
- log.Fatal(err)
- }
- // 写回数据
- return os.WriteFile(filepath, buffer.Bytes(), 0o600)
- }
|