// (c) Copyright 2016 Hewlett Packard Enterprise Development LP // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gas import ( "fmt" "go/ast" "reflect" ) // SelectFunc is like an AST visitor, but has a richer interface. It // is called with the current ast.Node being visitied and that nodes depth in // the tree. The function can return true to continue traversing the tree, or // false to end traversal here. type SelectFunc func(ast.Node, int) bool func walkIdentList(list []*ast.Ident, depth int, fun SelectFunc) { for _, x := range list { depthWalk(x, depth, fun) } } func walkExprList(list []ast.Expr, depth int, fun SelectFunc) { for _, x := range list { depthWalk(x, depth, fun) } } func walkStmtList(list []ast.Stmt, depth int, fun SelectFunc) { for _, x := range list { depthWalk(x, depth, fun) } } func walkDeclList(list []ast.Decl, depth int, fun SelectFunc) { for _, x := range list { depthWalk(x, depth, fun) } } func depthWalk(node ast.Node, depth int, fun SelectFunc) { if !fun(node, depth) { return } switch n := node.(type) { // Comments and fields case *ast.Comment: case *ast.CommentGroup: for _, c := range n.List { depthWalk(c, depth+1, fun) } case *ast.Field: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } walkIdentList(n.Names, depth+1, fun) depthWalk(n.Type, depth+1, fun) if n.Tag != nil { depthWalk(n.Tag, depth+1, fun) } if n.Comment != nil { depthWalk(n.Comment, depth+1, fun) } case *ast.FieldList: for _, f := range n.List { depthWalk(f, depth+1, fun) } // Expressions case *ast.BadExpr, *ast.Ident, *ast.BasicLit: case *ast.Ellipsis: if n.Elt != nil { depthWalk(n.Elt, depth+1, fun) } case *ast.FuncLit: depthWalk(n.Type, depth+1, fun) depthWalk(n.Body, depth+1, fun) case *ast.CompositeLit: if n.Type != nil { depthWalk(n.Type, depth+1, fun) } walkExprList(n.Elts, depth+1, fun) case *ast.ParenExpr: depthWalk(n.X, depth+1, fun) case *ast.SelectorExpr: depthWalk(n.X, depth+1, fun) depthWalk(n.Sel, depth+1, fun) case *ast.IndexExpr: depthWalk(n.X, depth+1, fun) depthWalk(n.Index, depth+1, fun) case *ast.SliceExpr: depthWalk(n.X, depth+1, fun) if n.Low != nil { depthWalk(n.Low, depth+1, fun) } if n.High != nil { depthWalk(n.High, depth+1, fun) } if n.Max != nil { depthWalk(n.Max, depth+1, fun) } case *ast.TypeAssertExpr: depthWalk(n.X, depth+1, fun) if n.Type != nil { depthWalk(n.Type, depth+1, fun) } case *ast.CallExpr: depthWalk(n.Fun, depth+1, fun) walkExprList(n.Args, depth+1, fun) case *ast.StarExpr: depthWalk(n.X, depth+1, fun) case *ast.UnaryExpr: depthWalk(n.X, depth+1, fun) case *ast.BinaryExpr: depthWalk(n.X, depth+1, fun) depthWalk(n.Y, depth+1, fun) case *ast.KeyValueExpr: depthWalk(n.Key, depth+1, fun) depthWalk(n.Value, depth+1, fun) // Types case *ast.ArrayType: if n.Len != nil { depthWalk(n.Len, depth+1, fun) } depthWalk(n.Elt, depth+1, fun) case *ast.StructType: depthWalk(n.Fields, depth+1, fun) case *ast.FuncType: if n.Params != nil { depthWalk(n.Params, depth+1, fun) } if n.Results != nil { depthWalk(n.Results, depth+1, fun) } case *ast.InterfaceType: depthWalk(n.Methods, depth+1, fun) case *ast.MapType: depthWalk(n.Key, depth+1, fun) depthWalk(n.Value, depth+1, fun) case *ast.ChanType: depthWalk(n.Value, depth+1, fun) // Statements case *ast.BadStmt: case *ast.DeclStmt: depthWalk(n.Decl, depth+1, fun) case *ast.EmptyStmt: case *ast.LabeledStmt: depthWalk(n.Label, depth+1, fun) depthWalk(n.Stmt, depth+1, fun) case *ast.ExprStmt: depthWalk(n.X, depth+1, fun) case *ast.SendStmt: depthWalk(n.Chan, depth+1, fun) depthWalk(n.Value, depth+1, fun) case *ast.IncDecStmt: depthWalk(n.X, depth+1, fun) case *ast.AssignStmt: walkExprList(n.Lhs, depth+1, fun) walkExprList(n.Rhs, depth+1, fun) case *ast.GoStmt: depthWalk(n.Call, depth+1, fun) case *ast.DeferStmt: depthWalk(n.Call, depth+1, fun) case *ast.ReturnStmt: walkExprList(n.Results, depth+1, fun) case *ast.BranchStmt: if n.Label != nil { depthWalk(n.Label, depth+1, fun) } case *ast.BlockStmt: walkStmtList(n.List, depth+1, fun) case *ast.IfStmt: if n.Init != nil { depthWalk(n.Init, depth+1, fun) } depthWalk(n.Cond, depth+1, fun) depthWalk(n.Body, depth+1, fun) if n.Else != nil { depthWalk(n.Else, depth+1, fun) } case *ast.CaseClause: walkExprList(n.List, depth+1, fun) walkStmtList(n.Body, depth+1, fun) case *ast.SwitchStmt: if n.Init != nil { depthWalk(n.Init, depth+1, fun) } if n.Tag != nil { depthWalk(n.Tag, depth+1, fun) } depthWalk(n.Body, depth+1, fun) case *ast.TypeSwitchStmt: if n.Init != nil { depthWalk(n.Init, depth+1, fun) } depthWalk(n.Assign, depth+1, fun) depthWalk(n.Body, depth+1, fun) case *ast.CommClause: if n.Comm != nil { depthWalk(n.Comm, depth+1, fun) } walkStmtList(n.Body, depth+1, fun) case *ast.SelectStmt: depthWalk(n.Body, depth+1, fun) case *ast.ForStmt: if n.Init != nil { depthWalk(n.Init, depth+1, fun) } if n.Cond != nil { depthWalk(n.Cond, depth+1, fun) } if n.Post != nil { depthWalk(n.Post, depth+1, fun) } depthWalk(n.Body, depth+1, fun) case *ast.RangeStmt: if n.Key != nil { depthWalk(n.Key, depth+1, fun) } if n.Value != nil { depthWalk(n.Value, depth+1, fun) } depthWalk(n.X, depth+1, fun) depthWalk(n.Body, depth+1, fun) // Declarations case *ast.ImportSpec: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } if n.Name != nil { depthWalk(n.Name, depth+1, fun) } depthWalk(n.Path, depth+1, fun) if n.Comment != nil { depthWalk(n.Comment, depth+1, fun) } case *ast.ValueSpec: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } walkIdentList(n.Names, depth+1, fun) if n.Type != nil { depthWalk(n.Type, depth+1, fun) } walkExprList(n.Values, depth+1, fun) if n.Comment != nil { depthWalk(n.Comment, depth+1, fun) } case *ast.TypeSpec: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } depthWalk(n.Name, depth+1, fun) depthWalk(n.Type, depth+1, fun) if n.Comment != nil { depthWalk(n.Comment, depth+1, fun) } case *ast.BadDecl: case *ast.GenDecl: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } for _, s := range n.Specs { depthWalk(s, depth+1, fun) } case *ast.FuncDecl: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } if n.Recv != nil { depthWalk(n.Recv, depth+1, fun) } depthWalk(n.Name, depth+1, fun) depthWalk(n.Type, depth+1, fun) if n.Body != nil { depthWalk(n.Body, depth+1, fun) } // Files and packages case *ast.File: if n.Doc != nil { depthWalk(n.Doc, depth+1, fun) } depthWalk(n.Name, depth+1, fun) walkDeclList(n.Decls, depth+1, fun) // don't walk n.Comments - they have been // visited already through the individual // nodes case *ast.Package: for _, f := range n.Files { depthWalk(f, depth+1, fun) } default: panic(fmt.Sprintf("gas.depthWalk: unexpected node type %T", n)) } } type Selector interface { Final(ast.Node) Partial(ast.Node) bool } func Select(s Selector, n ast.Node, bits ...reflect.Type) { fun := func(n ast.Node, d int) bool { if d < len(bits) && reflect.TypeOf(n) == bits[d] { if d == len(bits)-1 { s.Final(n) return false } else if s.Partial(n) { return true } } return false } depthWalk(n, 0, fun) } // SimpleSelect will try to match a path through a sub-tree starting at a given AST node. // The type of each node in the path at a given depth must match its entry in list of // node types given. func SimpleSelect(n ast.Node, bits ...reflect.Type) ast.Node { var found ast.Node fun := func(n ast.Node, d int) bool { if found != nil { return false // short cut logic if we have found a match } if d < len(bits) && reflect.TypeOf(n) == bits[d] { if d == len(bits)-1 { found = n return false } return true } return false } depthWalk(n, 0, fun) return found }