From d29c64800e7d7f7f78e2882429c6bb16eda297cc Mon Sep 17 00:00:00 2001 From: Grant Murphy Date: Thu, 17 Nov 2016 20:18:31 -0800 Subject: [PATCH 1/3] Add match call by type --- core/helpers.go | 27 ++++++++++++++++++++ core/helpers_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 core/helpers_test.go diff --git a/core/helpers.go b/core/helpers.go index dd25821..ee1bddf 100644 --- a/core/helpers.go +++ b/core/helpers.go @@ -88,6 +88,33 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a return nil, false } +// MatchCallByType ensures that the node is a call expression to a +// specific object type. +// +// Usage: +// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write") +// +func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) { + switch callExpr := n.(type) { + case *ast.CallExpr: + switch fn := callExpr.Fun.(type) { + case *ast.SelectorExpr: + switch expr := fn.X.(type) { + case *ast.Ident: + t := ctx.Info.TypeOf(expr) + if t != nil && t.String() == requiredType { + for _, call := range calls { + if fn.Sel.Name == call { + return callExpr, true + } + } + } + } + } + } + return nil, false +} + // MatchCompLit will match an ast.CompositeLit if its string value obays the given regex. func MatchCompLit(n ast.Node, r *regexp.Regexp) *ast.CompositeLit { t := reflect.TypeOf(&ast.CompositeLit{}) diff --git a/core/helpers_test.go b/core/helpers_test.go new file mode 100644 index 0000000..beb7afb --- /dev/null +++ b/core/helpers_test.go @@ -0,0 +1,59 @@ +package core + +import ( + "go/ast" + "testing" +) + +type dummyCallback func(ast.Node, *Context, string, ...string) (*ast.CallExpr, bool) + +type dummyRule struct { + MetaData + pkgOrType string + funcsOrMethods []string + callback dummyCallback + callExpr []ast.Node + matched int +} + +func (r *dummyRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { + if callexpr, matched := r.callback(n, c, r.pkgOrType, r.funcsOrMethods...); matched { + r.matched += 1 + r.callExpr = append(r.callExpr, callexpr) + } + return nil, nil +} + +func TestMatchCallByType(t *testing.T) { + config := map[string]interface{}{"ignoreNosec": false} + analyzer := NewAnalyzer(config, nil) + rule := &dummyRule{ + MetaData: MetaData{ + Severity: Low, + Confidence: Low, + What: "A dummy rule", + }, + pkgOrType: "bytes.Buffer", + funcsOrMethods: []string{"Write"}, + callback: MatchCallByType, + callExpr: []ast.Node{}, + matched: 0, + } + analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) + source := ` + package main + import ( + "bytes" + "fmt" + ) + func main() { + var b bytes.Buffer + b.Write([]byte("Hello ")) + fmt.Fprintf(&b, "world!") + }` + + analyzer.ProcessSource("dummy.go", source) + if rule.matched != 1 || len(rule.callExpr) != 1 { + t.Errorf("Expected to match a bytes.Buffer.Write call") + } +} From 5242a2c1df10a0d7ec3e66c141fa2adef66904db Mon Sep 17 00:00:00 2001 From: Grant Murphy Date: Fri, 18 Nov 2016 09:57:34 -0800 Subject: [PATCH 2/3] Extend helpers and call list - Update call list to work directly with call expression - Add call list test cases - Extend helpers to add GetCallInfo to resolve call name and package or type if it's a var. - Add test cases to ensure correct behaviour --- core/call_list.go | 52 ++++++++++++++++---------------- core/call_list_test.go | 58 +++++++++++++++++++++++++++++++++++ core/helpers.go | 68 ++++++++++++++++++++++++++---------------- core/helpers_test.go | 12 ++++++++ 4 files changed, 139 insertions(+), 51 deletions(-) create mode 100644 core/call_list_test.go diff --git a/core/call_list.go b/core/call_list.go index 1e45513..9ace433 100644 --- a/core/call_list.go +++ b/core/call_list.go @@ -13,16 +13,13 @@ package core -type set map[string]bool +import "go/ast" -type calls struct { - matchAny bool - functions set -} +type set map[string]bool /// CallList is used to check for usage of specific packages /// and functions. -type CallList map[string]*calls +type CallList map[string]set /// NewCallList creates a new empty CallList func NewCallList() CallList { @@ -30,36 +27,39 @@ func NewCallList() CallList { } /// NewCallListFor createse a call list using the package path -func NewCallListFor(pkg string, funcs ...string) CallList { +func NewCallListFor(selector string, idents ...string) CallList { c := NewCallList() - if len(funcs) == 0 { - c[pkg] = &calls{true, make(set)} - } else { - for _, fn := range funcs { - c.Add(pkg, fn) - } + c[selector] = make(set) + for _, ident := range idents { + c.Add(selector, ident) } return c } -/// Add a new package and function to the call list -func (c CallList) Add(pkg, fn string) { - if cl, ok := c[pkg]; ok { - if cl.matchAny { - cl.matchAny = false - } - } else { - c[pkg] = &calls{false, make(set)} +/// Add a selector and call to the call list +func (c CallList) Add(selector, ident string) { + if _, ok := c[selector]; !ok { + c[selector] = make(set) } - c[pkg].functions[fn] = true + c[selector][ident] = true } /// Contains returns true if the package and function are /// members of this call list. -func (c CallList) Contains(pkg, fn string) bool { - if funcs, ok := c[pkg]; ok { - _, ok = funcs.functions[fn] - return ok || funcs.matchAny +func (c CallList) Contains(selector, ident string) bool { + if idents, ok := c[selector]; ok { + _, found := idents[ident] + return found } return false } + +/// ContainsCallExpr resolves the call expression name and type +/// or package and determines if it exists within the CallList +func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool { + selector, ident, err := GetCallInfo(n, ctx) + if err != nil { + return false + } + return c.Contains(selector, ident) +} diff --git a/core/call_list_test.go b/core/call_list_test.go new file mode 100644 index 0000000..aa4f67c --- /dev/null +++ b/core/call_list_test.go @@ -0,0 +1,58 @@ +package core + +import ( + "go/ast" + "testing" +) + +type callListRule struct { + MetaData + callList CallList + matched int +} + +func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { + if r.callList.ContainsCallExpr(n, c) { + r.matched += 1 + } + return nil, nil +} + +func TestCallListContainsCallExpr(t *testing.T) { + config := map[string]interface{}{"ignoreNosec": false} + analyzer := NewAnalyzer(config, nil) + rule := &callListRule{ + MetaData: MetaData{ + Severity: Low, + Confidence: Low, + What: "A dummy rule", + }, + callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"), + matched: 0, + } + analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) + source := ` + package main + import ( + "bytes" + "fmt" + ) + func main() { + var b bytes.Buffer + b.Write([]byte("Hello ")) + fmt.Fprintf(&b, "world!") + }` + + analyzer.ProcessSource("dummy.go", source) + if rule.matched != 1 { + t.Errorf("Expected to match a bytes.Buffer.Write call") + } +} + +func TestCallListContains(t *testing.T) { + callList := NewCallList() + callList.Add("fmt", "Printf") + if !callList.Contains("fmt", "Printf") { + t.Errorf("Expected call list to contain fmt.Printf") + } +} diff --git a/core/helpers.go b/core/helpers.go index ee1bddf..79a1617 100644 --- a/core/helpers.go +++ b/core/helpers.go @@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a importName = alias } - switch node := n.(type) { - case *ast.CallExpr: - switch fn := node.Fun.(type) { - case *ast.SelectorExpr: - switch expr := fn.X.(type) { - case *ast.Ident: - if expr.Name == importName { - for _, name := range names { - if fn.Sel.Name == name { - return node, true - } - } + if callExpr, ok := n.(*ast.CallExpr); ok { + packageName, callName, err := GetCallInfo(callExpr, c) + if err != nil { + return nil, false + } + if packageName == importName { + for _, name := range names { + if callName == name { + return callExpr, true } } } @@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a // node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write") // func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) { - switch callExpr := n.(type) { - case *ast.CallExpr: - switch fn := callExpr.Fun.(type) { - case *ast.SelectorExpr: - switch expr := fn.X.(type) { - case *ast.Ident: - t := ctx.Info.TypeOf(expr) - if t != nil && t.String() == requiredType { - for _, call := range calls { - if fn.Sel.Name == call { - return callExpr, true - } - } + if callExpr, ok := n.(*ast.CallExpr); ok { + typeName, callName, err := GetCallInfo(callExpr, ctx) + if err != nil { + return nil, false + } + if typeName == requiredType { + for _, call := range calls { + if call == callName { + return callExpr, true } } } @@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) { } return nil, nil } + +// GetCallInfo returns the package or type and name associated with a +// call expression. +func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { + switch node := n.(type) { + case *ast.CallExpr: + switch fn := node.Fun.(type) { + case *ast.SelectorExpr: + switch expr := fn.X.(type) { + case *ast.Ident: + if expr.Obj != nil && expr.Obj.Kind == ast.Var { + t := ctx.Info.TypeOf(expr) + if t != nil { + return t.String(), fn.Sel.Name, nil + } else { + return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") + } + } else { + return expr.Name, fn.Sel.Name, nil + } + } + } + } + return "", "", fmt.Errorf("unable to determine call info") +} diff --git a/core/helpers_test.go b/core/helpers_test.go index beb7afb..89648e7 100644 --- a/core/helpers_test.go +++ b/core/helpers_test.go @@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) { if rule.matched != 1 || len(rule.callExpr) != 1 { t.Errorf("Expected to match a bytes.Buffer.Write call") } + + typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context) + if err != nil { + t.Errorf("Unable to resolve call info: %v\n", err) + } + if typeName != "bytes.Buffer" { + t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName) + } + if callName != "Write" { + t.Errorf("Expected: %s, Got: %s\n", "Write", callName) + } + } From 129be1561b9e097f001ff986aa1e935a0110733a Mon Sep 17 00:00:00 2001 From: Grant Murphy Date: Fri, 18 Nov 2016 14:09:10 -0800 Subject: [PATCH 3/3] Update error test case There were several issues with the error test case that have been addressed in this commit. - It is possible to specify a whitelist of calls that error handling should be ignored for. - Additional support for ast.ExprStmt for cases where the error is implicitly ignored. There were several other additions to the helpers and call list in order to support this type of functionality. Fixes #54 --- core/call_list.go | 22 ++++++--- core/call_list_test.go | 4 +- core/helpers.go | 45 ++++++++++++----- rules/errors.go | 77 +++++++++++++++++++++-------- rules/errors_test.go | 108 ++++++++++++++++++++++++++++++----------- 5 files changed, 187 insertions(+), 69 deletions(-) diff --git a/core/call_list.go b/core/call_list.go index 9ace433..2002024 100644 --- a/core/call_list.go +++ b/core/call_list.go @@ -13,7 +13,9 @@ package core -import "go/ast" +import ( + "go/ast" +) type set map[string]bool @@ -26,14 +28,11 @@ func NewCallList() CallList { return make(CallList) } -/// NewCallListFor createse a call list using the package path -func NewCallListFor(selector string, idents ...string) CallList { - c := NewCallList() - c[selector] = make(set) +/// AddAll will add several calls to the call list at once +func (c CallList) AddAll(selector string, idents ...string) { for _, ident := range idents { c.Add(selector, ident) } - return c } /// Add a selector and call to the call list @@ -61,5 +60,14 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool { if err != nil { return false } - return c.Contains(selector, ident) + // Try direct resolution + if c.Contains(selector, ident) { + return true + } + + // Also support explicit path + if path, ok := GetImportPath(selector, ctx); ok { + return c.Contains(path, ident) + } + return false } diff --git a/core/call_list_test.go b/core/call_list_test.go index aa4f67c..ef58293 100644 --- a/core/call_list_test.go +++ b/core/call_list_test.go @@ -21,13 +21,15 @@ func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { func TestCallListContainsCallExpr(t *testing.T) { config := map[string]interface{}{"ignoreNosec": false} analyzer := NewAnalyzer(config, nil) + calls := NewCallList() + calls.AddAll("bytes.Buffer", "Write", "WriteTo") rule := &callListRule{ MetaData: MetaData{ Severity: Low, Confidence: Low, What: "A dummy rule", }, - callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"), + callList: calls, matched: 0, } analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) diff --git a/core/helpers.go b/core/helpers.go index 79a1617..d42ceca 100644 --- a/core/helpers.go +++ b/core/helpers.go @@ -56,25 +56,17 @@ func MatchCall(n ast.Node, r *regexp.Regexp) *ast.CallExpr { // func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*ast.CallExpr, bool) { - importName, imported := c.Imports.Imported[pkg] - if !imported { + importedName, found := GetImportedName(pkg, c) + if !found { return nil, false } - if _, initonly := c.Imports.InitOnly[pkg]; initonly { - return nil, false - } - - if alias, ok := c.Imports.Aliased[pkg]; ok { - importName = alias - } - if callExpr, ok := n.(*ast.CallExpr); ok { packageName, callName, err := GetCallInfo(callExpr, c) if err != nil { return nil, false } - if packageName == importName { + if packageName == importedName { for _, name := range names { if callName == name { return callExpr, true @@ -185,7 +177,38 @@ func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { return expr.Name, fn.Sel.Name, nil } } + case *ast.Ident: + return ctx.Pkg.Name(), fn.Name, nil } } return "", "", fmt.Errorf("unable to determine call info") } + +// GetImportedName returns the name used for the package within the +// code. It will resolve aliases and ignores initalization only imports. +func GetImportedName(path string, ctx *Context) (string, bool) { + importName, imported := ctx.Imports.Imported[path] + if !imported { + return "", false + } + + if _, initonly := ctx.Imports.InitOnly[path]; initonly { + return "", false + } + + if alias, ok := ctx.Imports.Aliased[path]; ok { + importName = alias + } + return importName, true +} + +// GetImportPath resolves the full import path of an identifer based on +// the imports in the current context. +func GetImportPath(name string, ctx *Context) (string, bool) { + for path, _ := range ctx.Imports.Imported { + if imported, ok := GetImportedName(path, ctx); ok && imported == name { + return path, true + } + } + return "", false +} diff --git a/rules/errors.go b/rules/errors.go index 4490312..2bf61c9 100644 --- a/rules/errors.go +++ b/rules/errors.go @@ -15,47 +15,82 @@ package rules import ( + gas "github.com/GoASTScanner/gas/core" "go/ast" "go/types" - "reflect" - - gas "github.com/GoASTScanner/gas/core" ) type NoErrorCheck struct { gas.MetaData + whitelist gas.CallList } -func (r *NoErrorCheck) Match(n ast.Node, c *gas.Context) (gi *gas.Issue, err error) { - if node, ok := n.(*ast.AssignStmt); ok { - sel := reflect.TypeOf(&ast.CallExpr{}) - if call, ok := gas.SimpleSelect(node.Rhs[0], sel).(*ast.CallExpr); ok { - if t := c.Info.Types[call].Type; t != nil { - if typeVal, typeErr := t.(*types.Tuple); typeErr { - for i := 0; i < typeVal.Len(); i++ { - if typeVal.At(i).Type().String() == "error" { // TODO(tkelsey): is there a better way? - if id, ok := node.Lhs[i].(*ast.Ident); ok && id.Name == "_" { - return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil - } - } - } - } else if t.String() == "error" { // TODO(tkelsey): is there a better way? - if id, ok := node.Lhs[0].(*ast.Ident); ok && id.Name == "_" { - return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil - } +func returnsError(callExpr *ast.CallExpr, ctx *gas.Context) int { + if tv := ctx.Info.TypeOf(callExpr); tv != nil { + switch t := tv.(type) { + case *types.Tuple: + for pos := 0; pos < t.Len(); pos += 1 { + variable := t.At(pos) + if variable != nil && variable.Type().String() == "error" { + return pos } } + case *types.Named: + if t.String() == "error" { + return 0 + } + } + } + return -1 +} + +func (r *NoErrorCheck) Match(n ast.Node, ctx *gas.Context) (*gas.Issue, error) { + switch stmt := n.(type) { + case *ast.AssignStmt: + for _, expr := range stmt.Rhs { + if callExpr, ok := expr.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) { + pos := returnsError(callExpr, ctx) + if pos < 0 || pos >= len(stmt.Lhs) { + return nil, nil + } + if id, ok := stmt.Lhs[pos].(*ast.Ident); ok && id.Name == "_" { + return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil + } + } + } + case *ast.ExprStmt: + if callExpr, ok := stmt.X.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) { + pos := returnsError(callExpr, ctx) + if pos >= 0 { + return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil + } } } return nil, nil } func NewNoErrorCheck(conf map[string]interface{}) (gas.Rule, []ast.Node) { + + // TODO(gm) Come up with sensible defaults here. Or flip it to use a + // black list instead. + whitelist := gas.NewCallList() + whitelist.AddAll("bytes.Buffer", "Write", "WriteByte", "WriteRune", "WriteString") + whitelist.AddAll("fmt", "Print", "Printf", "Println") + whitelist.Add("io.PipeWriter", "CloseWithError") + + if configured, ok := conf["G104"]; ok { + if whitelisted, ok := configured.(map[string][]string); ok { + for key, val := range whitelisted { + whitelist.AddAll(key, val...) + } + } + } return &NoErrorCheck{ MetaData: gas.MetaData{ Severity: gas.Low, Confidence: gas.High, What: "Errors unhandled.", }, - }, []ast.Node{(*ast.AssignStmt)(nil)} + whitelist: whitelist, + }, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} } diff --git a/rules/errors_test.go b/rules/errors_test.go index 4ae502b..d4a07a0 100644 --- a/rules/errors_test.go +++ b/rules/errors_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -28,17 +28,17 @@ func TestErrorsMulti(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() (val int, err error) { - return 0, nil - } + func test() (val int, err error) { + return 0, nil + } - func main() { - v, _ := test() - }`, analyzer) + func main() { + v, _ := test() + }`, analyzer) checkTestResults(t, issues, 1, "Errors unhandled") } @@ -51,19 +51,30 @@ func TestErrorsSingle(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() (err error) { - return nil - } + func a() error { + return fmt.Errorf("This is an error") + } - func main() { - _ := test() - }`, analyzer) + func b() { + fmt.Println("b") + } - checkTestResults(t, issues, 1, "Errors unhandled") + func c() string { + return fmt.Sprintf("This isn't anything") + } + + func main() { + _ = a() + a() + b() + _ = c() + c() + }`, analyzer) + checkTestResults(t, issues, 2, "Errors unhandled") } func TestErrorsGood(t *testing.T) { @@ -74,17 +85,56 @@ func TestErrorsGood(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() err error { - return 0, nil - } + func test() err error { + return 0, nil + } - func main() { - e := test() - }`, analyzer) + func main() { + e := test() + }`, analyzer) checkTestResults(t, issues, 0, "") } + +func TestErrorsWhitelisted(t *testing.T) { + config := map[string]interface{}{ + "ignoreNosec": false, + "G104": map[string][]string{ + "compress/zlib": []string{"NewReader"}, + "io": []string{"Copy"}, + }, + } + analyzer := gas.NewAnalyzer(config, nil) + analyzer.AddRule(NewNoErrorCheck(config)) + source := `package main + import ( + "io" + "os" + "fmt" + "bytes" + "compress/zlib" + ) + + func a() error { + return fmt.Errorf("This is an error ok") + } + + func main() { + // Expect at least one failure + _ = a() + + var b bytes.Buffer + // Default whitelist + nbytes, _ := b.Write([]byte("Hello ")) + + // Whitelisted via configuration + r, _ := zlib.NewReader(&b) + io.Copy(os.Stdout, r) + }` + issues := gasTestRunner(source, analyzer) + checkTestResults(t, issues, 1, "Errors unhandled") +}