Merge pull request #92 from GoASTScanner/experimental

Resolve issues with error rules
This commit is contained in:
Grant Murphy 2016-12-02 09:01:30 -08:00 committed by GitHub
commit 8f78248b61
6 changed files with 391 additions and 99 deletions

View file

@ -13,53 +13,61 @@
package core
type set map[string]bool
import (
"go/ast"
)
type calls struct {
matchAny bool
functions set
}
type set map[string]bool
/// CallList is used to check for usage of specific packages
/// and functions.
type CallList map[string]*calls
type CallList map[string]set
/// NewCallList creates a new empty CallList
func NewCallList() CallList {
return make(CallList)
}
/// NewCallListFor createse a call list using the package path
func NewCallListFor(pkg string, funcs ...string) CallList {
c := NewCallList()
if len(funcs) == 0 {
c[pkg] = &calls{true, make(set)}
} else {
for _, fn := range funcs {
c.Add(pkg, fn)
}
/// AddAll will add several calls to the call list at once
func (c CallList) AddAll(selector string, idents ...string) {
for _, ident := range idents {
c.Add(selector, ident)
}
return c
}
/// Add a new package and function to the call list
func (c CallList) Add(pkg, fn string) {
if cl, ok := c[pkg]; ok {
if cl.matchAny {
cl.matchAny = false
}
} else {
c[pkg] = &calls{false, make(set)}
/// Add a selector and call to the call list
func (c CallList) Add(selector, ident string) {
if _, ok := c[selector]; !ok {
c[selector] = make(set)
}
c[pkg].functions[fn] = true
c[selector][ident] = true
}
/// Contains returns true if the package and function are
/// members of this call list.
func (c CallList) Contains(pkg, fn string) bool {
if funcs, ok := c[pkg]; ok {
_, ok = funcs.functions[fn]
return ok || funcs.matchAny
func (c CallList) Contains(selector, ident string) bool {
if idents, ok := c[selector]; ok {
_, found := idents[ident]
return found
}
return false
}
/// ContainsCallExpr resolves the call expression name and type
/// or package and determines if it exists within the CallList
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return false
}
// Try direct resolution
if c.Contains(selector, ident) {
return true
}
// Also support explicit path
if path, ok := GetImportPath(selector, ctx); ok {
return c.Contains(path, ident)
}
return false
}

60
core/call_list_test.go Normal file
View file

@ -0,0 +1,60 @@
package core
import (
"go/ast"
"testing"
)
type callListRule struct {
MetaData
callList CallList
matched int
}
func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
if r.callList.ContainsCallExpr(n, c) {
r.matched += 1
}
return nil, nil
}
func TestCallListContainsCallExpr(t *testing.T) {
config := map[string]interface{}{"ignoreNosec": false}
analyzer := NewAnalyzer(config, nil)
calls := NewCallList()
calls.AddAll("bytes.Buffer", "Write", "WriteTo")
rule := &callListRule{
MetaData: MetaData{
Severity: Low,
Confidence: Low,
What: "A dummy rule",
},
callList: calls,
matched: 0,
}
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
source := `
package main
import (
"bytes"
"fmt"
)
func main() {
var b bytes.Buffer
b.Write([]byte("Hello "))
fmt.Fprintf(&b, "world!")
}`
analyzer.ProcessSource("dummy.go", source)
if rule.matched != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call")
}
}
func TestCallListContains(t *testing.T) {
callList := NewCallList()
callList.Add("fmt", "Printf")
if !callList.Contains("fmt", "Printf") {
t.Errorf("Expected call list to contain fmt.Printf")
}
}

View file

@ -56,31 +56,43 @@ func MatchCall(n ast.Node, r *regexp.Regexp) *ast.CallExpr {
//
func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*ast.CallExpr, bool) {
importName, imported := c.Imports.Imported[pkg]
if !imported {
importedName, found := GetImportedName(pkg, c)
if !found {
return nil, false
}
if _, initonly := c.Imports.InitOnly[pkg]; initonly {
return nil, false
if callExpr, ok := n.(*ast.CallExpr); ok {
packageName, callName, err := GetCallInfo(callExpr, c)
if err != nil {
return nil, false
}
if packageName == importedName {
for _, name := range names {
if callName == name {
return callExpr, true
}
}
}
}
return nil, false
}
if alias, ok := c.Imports.Aliased[pkg]; ok {
importName = alias
}
switch node := n.(type) {
case *ast.CallExpr:
switch fn := node.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
if expr.Name == importName {
for _, name := range names {
if fn.Sel.Name == name {
return node, true
}
}
// MatchCallByType ensures that the node is a call expression to a
// specific object type.
//
// Usage:
// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write")
//
func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) {
if callExpr, ok := n.(*ast.CallExpr); ok {
typeName, callName, err := GetCallInfo(callExpr, ctx)
if err != nil {
return nil, false
}
if typeName == requiredType {
for _, call := range calls {
if call == callName {
return callExpr, true
}
}
}
@ -144,3 +156,59 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) {
}
return nil, nil
}
// GetCallInfo returns the package or type and name associated with a
// call expression.
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
switch node := n.(type) {
case *ast.CallExpr:
switch fn := node.Fun.(type) {
case *ast.SelectorExpr:
switch expr := fn.X.(type) {
case *ast.Ident:
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
t := ctx.Info.TypeOf(expr)
if t != nil {
return t.String(), fn.Sel.Name, nil
} else {
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
} else {
return expr.Name, fn.Sel.Name, nil
}
}
case *ast.Ident:
return ctx.Pkg.Name(), fn.Name, nil
}
}
return "", "", fmt.Errorf("unable to determine call info")
}
// GetImportedName returns the name used for the package within the
// code. It will resolve aliases and ignores initalization only imports.
func GetImportedName(path string, ctx *Context) (string, bool) {
importName, imported := ctx.Imports.Imported[path]
if !imported {
return "", false
}
if _, initonly := ctx.Imports.InitOnly[path]; initonly {
return "", false
}
if alias, ok := ctx.Imports.Aliased[path]; ok {
importName = alias
}
return importName, true
}
// GetImportPath resolves the full import path of an identifer based on
// the imports in the current context.
func GetImportPath(name string, ctx *Context) (string, bool) {
for path, _ := range ctx.Imports.Imported {
if imported, ok := GetImportedName(path, ctx); ok && imported == name {
return path, true
}
}
return "", false
}

71
core/helpers_test.go Normal file
View file

@ -0,0 +1,71 @@
package core
import (
"go/ast"
"testing"
)
type dummyCallback func(ast.Node, *Context, string, ...string) (*ast.CallExpr, bool)
type dummyRule struct {
MetaData
pkgOrType string
funcsOrMethods []string
callback dummyCallback
callExpr []ast.Node
matched int
}
func (r *dummyRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
if callexpr, matched := r.callback(n, c, r.pkgOrType, r.funcsOrMethods...); matched {
r.matched += 1
r.callExpr = append(r.callExpr, callexpr)
}
return nil, nil
}
func TestMatchCallByType(t *testing.T) {
config := map[string]interface{}{"ignoreNosec": false}
analyzer := NewAnalyzer(config, nil)
rule := &dummyRule{
MetaData: MetaData{
Severity: Low,
Confidence: Low,
What: "A dummy rule",
},
pkgOrType: "bytes.Buffer",
funcsOrMethods: []string{"Write"},
callback: MatchCallByType,
callExpr: []ast.Node{},
matched: 0,
}
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
source := `
package main
import (
"bytes"
"fmt"
)
func main() {
var b bytes.Buffer
b.Write([]byte("Hello "))
fmt.Fprintf(&b, "world!")
}`
analyzer.ProcessSource("dummy.go", source)
if rule.matched != 1 || len(rule.callExpr) != 1 {
t.Errorf("Expected to match a bytes.Buffer.Write call")
}
typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context)
if err != nil {
t.Errorf("Unable to resolve call info: %v\n", err)
}
if typeName != "bytes.Buffer" {
t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName)
}
if callName != "Write" {
t.Errorf("Expected: %s, Got: %s\n", "Write", callName)
}
}

View file

@ -15,47 +15,82 @@
package rules
import (
gas "github.com/GoASTScanner/gas/core"
"go/ast"
"go/types"
"reflect"
gas "github.com/GoASTScanner/gas/core"
)
type NoErrorCheck struct {
gas.MetaData
whitelist gas.CallList
}
func (r *NoErrorCheck) Match(n ast.Node, c *gas.Context) (gi *gas.Issue, err error) {
if node, ok := n.(*ast.AssignStmt); ok {
sel := reflect.TypeOf(&ast.CallExpr{})
if call, ok := gas.SimpleSelect(node.Rhs[0], sel).(*ast.CallExpr); ok {
if t := c.Info.Types[call].Type; t != nil {
if typeVal, typeErr := t.(*types.Tuple); typeErr {
for i := 0; i < typeVal.Len(); i++ {
if typeVal.At(i).Type().String() == "error" { // TODO(tkelsey): is there a better way?
if id, ok := node.Lhs[i].(*ast.Ident); ok && id.Name == "_" {
return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil
}
}
}
} else if t.String() == "error" { // TODO(tkelsey): is there a better way?
if id, ok := node.Lhs[0].(*ast.Ident); ok && id.Name == "_" {
return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil
}
func returnsError(callExpr *ast.CallExpr, ctx *gas.Context) int {
if tv := ctx.Info.TypeOf(callExpr); tv != nil {
switch t := tv.(type) {
case *types.Tuple:
for pos := 0; pos < t.Len(); pos += 1 {
variable := t.At(pos)
if variable != nil && variable.Type().String() == "error" {
return pos
}
}
case *types.Named:
if t.String() == "error" {
return 0
}
}
}
return -1
}
func (r *NoErrorCheck) Match(n ast.Node, ctx *gas.Context) (*gas.Issue, error) {
switch stmt := n.(type) {
case *ast.AssignStmt:
for _, expr := range stmt.Rhs {
if callExpr, ok := expr.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) {
pos := returnsError(callExpr, ctx)
if pos < 0 || pos >= len(stmt.Lhs) {
return nil, nil
}
if id, ok := stmt.Lhs[pos].(*ast.Ident); ok && id.Name == "_" {
return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil
}
}
}
case *ast.ExprStmt:
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) {
pos := returnsError(callExpr, ctx)
if pos >= 0 {
return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil
}
}
}
return nil, nil
}
func NewNoErrorCheck(conf map[string]interface{}) (gas.Rule, []ast.Node) {
// TODO(gm) Come up with sensible defaults here. Or flip it to use a
// black list instead.
whitelist := gas.NewCallList()
whitelist.AddAll("bytes.Buffer", "Write", "WriteByte", "WriteRune", "WriteString")
whitelist.AddAll("fmt", "Print", "Printf", "Println")
whitelist.Add("io.PipeWriter", "CloseWithError")
if configured, ok := conf["G104"]; ok {
if whitelisted, ok := configured.(map[string][]string); ok {
for key, val := range whitelisted {
whitelist.AddAll(key, val...)
}
}
}
return &NoErrorCheck{
MetaData: gas.MetaData{
Severity: gas.Low,
Confidence: gas.High,
What: "Errors unhandled.",
},
}, []ast.Node{(*ast.AssignStmt)(nil)}
whitelist: whitelist,
}, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
}

View file

@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@ -28,17 +28,17 @@ func TestErrorsMulti(t *testing.T) {
issues := gasTestRunner(
`package main
import (
"fmt"
)
import (
"fmt"
)
func test() (val int, err error) {
return 0, nil
}
func test() (val int, err error) {
return 0, nil
}
func main() {
v, _ := test()
}`, analyzer)
func main() {
v, _ := test()
}`, analyzer)
checkTestResults(t, issues, 1, "Errors unhandled")
}
@ -51,19 +51,30 @@ func TestErrorsSingle(t *testing.T) {
issues := gasTestRunner(
`package main
import (
"fmt"
)
import (
"fmt"
)
func test() (err error) {
return nil
}
func a() error {
return fmt.Errorf("This is an error")
}
func main() {
_ := test()
}`, analyzer)
func b() {
fmt.Println("b")
}
checkTestResults(t, issues, 1, "Errors unhandled")
func c() string {
return fmt.Sprintf("This isn't anything")
}
func main() {
_ = a()
a()
b()
_ = c()
c()
}`, analyzer)
checkTestResults(t, issues, 2, "Errors unhandled")
}
func TestErrorsGood(t *testing.T) {
@ -74,17 +85,56 @@ func TestErrorsGood(t *testing.T) {
issues := gasTestRunner(
`package main
import (
"fmt"
)
import (
"fmt"
)
func test() err error {
return 0, nil
}
func test() err error {
return 0, nil
}
func main() {
e := test()
}`, analyzer)
func main() {
e := test()
}`, analyzer)
checkTestResults(t, issues, 0, "")
}
func TestErrorsWhitelisted(t *testing.T) {
config := map[string]interface{}{
"ignoreNosec": false,
"G104": map[string][]string{
"compress/zlib": []string{"NewReader"},
"io": []string{"Copy"},
},
}
analyzer := gas.NewAnalyzer(config, nil)
analyzer.AddRule(NewNoErrorCheck(config))
source := `package main
import (
"io"
"os"
"fmt"
"bytes"
"compress/zlib"
)
func a() error {
return fmt.Errorf("This is an error ok")
}
func main() {
// Expect at least one failure
_ = a()
var b bytes.Buffer
// Default whitelist
nbytes, _ := b.Write([]byte("Hello "))
// Whitelisted via configuration
r, _ := zlib.NewReader(&b)
io.Copy(os.Stdout, r)
}`
issues := gasTestRunner(source, analyzer)
checkTestResults(t, issues, 1, "Errors unhandled")
}