Extend helpers and call list

- Update call list to work directly with call expression
- Add call list test cases
- Extend helpers to add GetCallInfo to resolve call name and package or
  type if it's a var.
- Add test cases to ensure correct behaviour
This commit is contained in:
Grant Murphy 2016-11-18 09:57:34 -08:00
parent d29c64800e
commit 5242a2c1df
4 changed files with 139 additions and 51 deletions

View file

@ -13,16 +13,13 @@
package core
type set map[string]bool
import "go/ast"
type calls struct {
matchAny bool
functions set
}
type set map[string]bool
/// CallList is used to check for usage of specific packages
/// and functions.
type CallList map[string]*calls
type CallList map[string]set
/// NewCallList creates a new empty CallList
func NewCallList() CallList {
@ -30,36 +27,39 @@ func NewCallList() CallList {
}
/// NewCallListFor createse a call list using the package path
func NewCallListFor(pkg string, funcs ...string) CallList {
func NewCallListFor(selector string, idents ...string) CallList {
c := NewCallList()
if len(funcs) == 0 {
c[pkg] = &calls{true, make(set)}
} else {
for _, fn := range funcs {
c.Add(pkg, fn)
}
c[selector] = make(set)
for _, ident := range idents {
c.Add(selector, ident)
}
return c
}
/// Add a new package and function to the call list
func (c CallList) Add(pkg, fn string) {
if cl, ok := c[pkg]; ok {
if cl.matchAny {
cl.matchAny = false
}
} else {
c[pkg] = &calls{false, make(set)}
/// Add a selector and call to the call list
func (c CallList) Add(selector, ident string) {
if _, ok := c[selector]; !ok {
c[selector] = make(set)
}
c[pkg].functions[fn] = true
c[selector][ident] = true
}
/// Contains returns true if the package and function are
/// members of this call list.
func (c CallList) Contains(pkg, fn string) bool {
if funcs, ok := c[pkg]; ok {
_, ok = funcs.functions[fn]
return ok || funcs.matchAny
func (c CallList) Contains(selector, ident string) bool {
if idents, ok := c[selector]; ok {
_, found := idents[ident]
return found
}
return false
}
/// ContainsCallExpr resolves the call expression name and type
/// or package and determines if it exists within the CallList
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return false
}
return c.Contains(selector, ident)
}

58
core/call_list_test.go Normal file
View file

@ -0,0 +1,58 @@
package core
import (
"go/ast"
"testing"
)
type callListRule struct {
MetaData
callList CallList
matched int
}
func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
if r.callList.ContainsCallExpr(n, c) {
r.matched += 1
}
return nil, nil
}
func TestCallListContainsCallExpr(t *testing.T) {
config := map[string]interface{}{"ignoreNosec": false}
analyzer := NewAnalyzer(config, nil)
rule := &callListRule{
MetaData: MetaData{
Severity: Low,
Confidence: Low,
What: "A dummy rule",
},
callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"),
matched: 0,
}
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
source := `
package main
import (
"bytes"
"fmt"
)
func main() {
var b bytes.Buffer
b.Write([]byte("Hello "))
fmt.Fprintf(&b, "world!")
}`
analyzer.ProcessSource("dummy.go", source)
if rule.matched != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call")
}
}
func TestCallListContains(t *testing.T) {
callList := NewCallList()
callList.Add("fmt", "Printf")
if !callList.Contains("fmt", "Printf") {
t.Errorf("Expected call list to contain fmt.Printf")
}
}

View file

@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
importName = alias
}
switch node := n.(type) {
case *ast.CallExpr:
switch fn := node.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
if expr.Name == importName {
for _, name := range names {
if fn.Sel.Name == name {
return node, true
}
}
if callExpr, ok := n.(*ast.CallExpr); ok {
packageName, callName, err := GetCallInfo(callExpr, c)
if err != nil {
return nil, false
}
if packageName == importName {
for _, name := range names {
if callName == name {
return callExpr, true
}
}
}
@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write")
//
func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) {
switch callExpr := n.(type) {
case *ast.CallExpr:
switch fn := callExpr.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
t := ctx.Info.TypeOf(expr)
if t != nil && t.String() == requiredType {
for _, call := range calls {
if fn.Sel.Name == call {
return callExpr, true
}
}
if callExpr, ok := n.(*ast.CallExpr); ok {
typeName, callName, err := GetCallInfo(callExpr, ctx)
if err != nil {
return nil, false
}
if typeName == requiredType {
for _, call := range calls {
if call == callName {
return callExpr, true
}
}
}
@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) {
}
return nil, nil
}
// GetCallInfo returns the package or type and name associated with a
// call expression.
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
switch node := n.(type) {
case *ast.CallExpr:
switch fn := node.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
t := ctx.Info.TypeOf(expr)
if t != nil {
return t.String(), fn.Sel.Name, nil
} else {
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
} else {
return expr.Name, fn.Sel.Name, nil
}
}
}
}
return "", "", fmt.Errorf("unable to determine call info")
}

View file

@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) {
if rule.matched != 1 || len(rule.callExpr) != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call")
}
typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context)
if err != nil {
t.Errorf("Unable to resolve call info: %v\n", err)
}
if typeName != "bytes.Buffer" {
t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName)
}
if callName != "Write" {
t.Errorf("Expected: %s, Got: %s\n", "Write", callName)
}
}