Extend helpers and call list

- Update call list to work directly with call expression
- Add call list test cases
- Extend helpers to add GetCallInfo to resolve call name and package or
  type if it's a var.
- Add test cases to ensure correct behaviour
This commit is contained in:
Grant Murphy 2016-11-18 09:57:34 -08:00
parent d29c64800e
commit 5242a2c1df
4 changed files with 139 additions and 51 deletions

View file

@ -13,16 +13,13 @@
package core package core
type set map[string]bool import "go/ast"
type calls struct { type set map[string]bool
matchAny bool
functions set
}
/// CallList is used to check for usage of specific packages /// CallList is used to check for usage of specific packages
/// and functions. /// and functions.
type CallList map[string]*calls type CallList map[string]set
/// NewCallList creates a new empty CallList /// NewCallList creates a new empty CallList
func NewCallList() CallList { func NewCallList() CallList {
@ -30,36 +27,39 @@ func NewCallList() CallList {
} }
/// NewCallListFor createse a call list using the package path /// NewCallListFor createse a call list using the package path
func NewCallListFor(pkg string, funcs ...string) CallList { func NewCallListFor(selector string, idents ...string) CallList {
c := NewCallList() c := NewCallList()
if len(funcs) == 0 { c[selector] = make(set)
c[pkg] = &calls{true, make(set)} for _, ident := range idents {
} else { c.Add(selector, ident)
for _, fn := range funcs {
c.Add(pkg, fn)
}
} }
return c return c
} }
/// Add a new package and function to the call list /// Add a selector and call to the call list
func (c CallList) Add(pkg, fn string) { func (c CallList) Add(selector, ident string) {
if cl, ok := c[pkg]; ok { if _, ok := c[selector]; !ok {
if cl.matchAny { c[selector] = make(set)
cl.matchAny = false
}
} else {
c[pkg] = &calls{false, make(set)}
} }
c[pkg].functions[fn] = true c[selector][ident] = true
} }
/// Contains returns true if the package and function are /// Contains returns true if the package and function are
/// members of this call list. /// members of this call list.
func (c CallList) Contains(pkg, fn string) bool { func (c CallList) Contains(selector, ident string) bool {
if funcs, ok := c[pkg]; ok { if idents, ok := c[selector]; ok {
_, ok = funcs.functions[fn] _, found := idents[ident]
return ok || funcs.matchAny return found
} }
return false return false
} }
/// ContainsCallExpr resolves the call expression name and type
/// or package and determines if it exists within the CallList
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return false
}
return c.Contains(selector, ident)
}

58
core/call_list_test.go Normal file
View file

@ -0,0 +1,58 @@
package core
import (
"go/ast"
"testing"
)
type callListRule struct {
MetaData
callList CallList
matched int
}
func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
if r.callList.ContainsCallExpr(n, c) {
r.matched += 1
}
return nil, nil
}
func TestCallListContainsCallExpr(t *testing.T) {
config := map[string]interface{}{"ignoreNosec": false}
analyzer := NewAnalyzer(config, nil)
rule := &callListRule{
MetaData: MetaData{
Severity: Low,
Confidence: Low,
What: "A dummy rule",
},
callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"),
matched: 0,
}
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
source := `
package main
import (
"bytes"
"fmt"
)
func main() {
var b bytes.Buffer
b.Write([]byte("Hello "))
fmt.Fprintf(&b, "world!")
}`
analyzer.ProcessSource("dummy.go", source)
if rule.matched != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call")
}
}
func TestCallListContains(t *testing.T) {
callList := NewCallList()
callList.Add("fmt", "Printf")
if !callList.Contains("fmt", "Printf") {
t.Errorf("Expected call list to contain fmt.Printf")
}
}

View file

@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
importName = alias importName = alias
} }
switch node := n.(type) { if callExpr, ok := n.(*ast.CallExpr); ok {
case *ast.CallExpr: packageName, callName, err := GetCallInfo(callExpr, c)
switch fn := node.Fun.(type) { if err != nil {
case *ast.SelectorExpr: return nil, false
switch expr := fn.X.(type) { }
case *ast.Ident: if packageName == importName {
if expr.Name == importName { for _, name := range names {
for _, name := range names { if callName == name {
if fn.Sel.Name == name { return callExpr, true
return node, true
}
}
} }
} }
} }
@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write") // node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write")
// //
func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) { func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) {
switch callExpr := n.(type) { if callExpr, ok := n.(*ast.CallExpr); ok {
case *ast.CallExpr: typeName, callName, err := GetCallInfo(callExpr, ctx)
switch fn := callExpr.Fun.(type) { if err != nil {
case *ast.SelectorExpr: return nil, false
switch expr := fn.X.(type) { }
case *ast.Ident: if typeName == requiredType {
t := ctx.Info.TypeOf(expr) for _, call := range calls {
if t != nil && t.String() == requiredType { if call == callName {
for _, call := range calls { return callExpr, true
if fn.Sel.Name == call {
return callExpr, true
}
}
} }
} }
} }
@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) {
} }
return nil, nil return nil, nil
} }
// GetCallInfo returns the package or type and name associated with a
// call expression.
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
switch node := n.(type) {
case *ast.CallExpr:
switch fn := node.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
t := ctx.Info.TypeOf(expr)
if t != nil {
return t.String(), fn.Sel.Name, nil
} else {
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
} else {
return expr.Name, fn.Sel.Name, nil
}
}
}
}
return "", "", fmt.Errorf("unable to determine call info")
}

View file

@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) {
if rule.matched != 1 || len(rule.callExpr) != 1 { if rule.matched != 1 || len(rule.callExpr) != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call") t.Errorf("Expected to match a bytes.Buffer.Write call")
} }
typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context)
if err != nil {
t.Errorf("Unable to resolve call info: %v\n", err)
}
if typeName != "bytes.Buffer" {
t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName)
}
if callName != "Write" {
t.Errorf("Expected: %s, Got: %s\n", "Write", callName)
}
} }