Fix the errors rule whitelist to work on types methods

Signed-off-by: Cosmin Cojocar <cosmin.cojocar@gmx.ch>
This commit is contained in:
Cosmin Cojocar 2020-01-28 14:11:00 +01:00 committed by Cosmin Cojocar
parent 459e2d3e91
commit 3e069e7756
17 changed files with 286 additions and 38 deletions

View file

@ -56,9 +56,22 @@ func (c CallList) Contains(selector, ident string) bool {
return false
}
// ContainsCallExpr resolves the call expression name and type
/// or package and determines if it exists within the CallList
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
// ContainsPointer returns true if a pointer to the selector type or the type
// itslef is a members of this call list.
func (c CallList) ContainsPointer(selector, indent string) bool {
if strings.HasPrefix(selector, "*") {
if c.Contains(selector, indent) {
return true
}
s := strings.TrimPrefix(selector, "*")
return c.Contains(s, indent)
}
return false
}
// ContainsPkgCallExpr resolves the call expression name and type, and then further looks
// up the package path for that type. Finally, it determines if the call exists within the CallList
func (c CallList) ContainsPkgCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return nil
@ -79,12 +92,18 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *
}
return n.(*ast.CallExpr)
/*
// Try direct resolution
if c.Contains(selector, ident) {
log.Printf("c.Contains == true, %s, %s.", selector, ident)
return n.(*ast.CallExpr)
}
*/
}
// ContainsCallExpr resolves the call experssion name and type, and then determines
// if the call existis with the call list
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return nil
}
if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) {
return nil
}
return n.(*ast.CallExpr)
}

View file

@ -46,6 +46,20 @@ var _ = Describe("Call List", func() {
Expect(actual).Should(Equal(expected))
})
It("should be possible to add pointer call", func() {
Expect(calls).Should(HaveLen(0))
calls.Add("*bytes.Buffer", "WriteString")
actual := calls.ContainsPointer("*bytes.Buffer", "WriteString")
Expect(actual).Should(BeTrue())
})
It("should be possible to check pointer call", func() {
Expect(calls).Should(HaveLen(0))
calls.Add("bytes.Buffer", "WriteString")
actual := calls.ContainsPointer("*bytes.Buffer", "WriteString")
Expect(actual).Should(BeTrue())
})
It("should not return a match if none are present", func() {
calls.Add("ioutil", "Copy")
Expect(calls.Contains("fmt", "Println")).Should(BeFalse())
@ -56,8 +70,7 @@ var _ = Describe("Call List", func() {
Expect(calls.Contains("ioutil", "Copy")).Should(BeTrue())
})
It("should match a call expression", func() {
It("should match a package call expression", func() {
// Create file to be scanned
pkg := testutils.NewTestPackage()
defer pkg.Close()
@ -73,14 +86,39 @@ var _ = Describe("Call List", func() {
v := testutils.NewMockVisitor()
v.Context = ctx
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil {
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsPkgCallExpr(n, ctx, false) != nil {
matched++
}
return true
}
ast.Walk(v, ctx.Root)
Expect(matched).Should(Equal(1))
})
It("should match a call expression", func() {
// Create file to be scanned
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("main.go", testutils.SampleCodeG104[5].Code[0])
ctx := pkg.CreateContext("main.go")
calls.Add("bytes.Buffer", "WriteString")
calls.Add("strings.Builder", "WriteString")
calls.Add("io.Pipe", "CloseWithError")
calls.Add("fmt", "Fprintln")
// Stub out visitor and count number of matched call expr
matched := 0
v := testutils.NewMockVisitor()
v.Context = ctx
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil {
matched++
}
return true
}
ast.Walk(v, ctx.Root)
Expect(matched).Should(Equal(5))
})
})

View file

@ -135,11 +135,40 @@ func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
return expr.Name, fn.Sel.Name, nil
case *ast.CallExpr:
switch call := expr.Fun.(type) {
case *ast.Ident:
if call.Name == "new" {
t := ctx.Info.TypeOf(expr.Args[0])
if t != nil {
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
if call.Obj != nil {
switch decl := call.Obj.Decl.(type) {
case *ast.FuncDecl:
ret := decl.Type.Results
if ret != nil && len(ret.List) > 0 {
ret1 := ret.List[0]
if ret1 != nil {
t := ctx.Info.TypeOf(ret1.Type)
if t != nil {
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
}
}
}
}
}
}
case *ast.Ident:
return ctx.Pkg.Name(), fn.Name, nil
}
}
return "", "", fmt.Errorf("unable to determine call info")
}

View file

@ -1,6 +1,7 @@
package gosec_test
import (
"go/ast"
"io/ioutil"
"os"
"path/filepath"
@ -9,6 +10,7 @@ import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/securego/gosec"
"github.com/securego/gosec/testutils"
)
var _ = Describe("Helpers", func() {
@ -91,4 +93,140 @@ var _ = Describe("Helpers", func() {
Expect(len(r)).Should(Equal(0))
})
})
Context("when getting call info", func() {
It("should return the type and call name for selector expression", func() {
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("main.go", `
package main
import(
"bytes"
)
func main() {
b := new(bytes.Buffer)
_, err := b.WriteString("test")
if err != nil {
panic(err)
}
}
`)
ctx := pkg.CreateContext("main.go")
result := map[string]string{}
visitor := testutils.NewMockVisitor()
visitor.Context = ctx
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
typeName, call, err := gosec.GetCallInfo(n, ctx)
if err == nil {
result[typeName] = call
}
return true
}
ast.Walk(visitor, ctx.Root)
Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString"))
})
It("should return the type and call name for new selector expression", func() {
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("main.go", `
package main
import(
"bytes"
)
func main() {
_, err := new(bytes.Buffer).WriteString("test")
if err != nil {
panic(err)
}
}
`)
ctx := pkg.CreateContext("main.go")
result := map[string]string{}
visitor := testutils.NewMockVisitor()
visitor.Context = ctx
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
typeName, call, err := gosec.GetCallInfo(n, ctx)
if err == nil {
result[typeName] = call
}
return true
}
ast.Walk(visitor, ctx.Root)
Expect(result).Should(HaveKeyWithValue("bytes.Buffer", "WriteString"))
})
It("should return the type and call name for function selector expression", func() {
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("main.go", `
package main
import(
"bytes"
)
func createBuffer() *bytes.Buffer {
return new(bytes.Buffer)
}
func main() {
_, err := createBuffer().WriteString("test")
if err != nil {
panic(err)
}
}
`)
ctx := pkg.CreateContext("main.go")
result := map[string]string{}
visitor := testutils.NewMockVisitor()
visitor.Context = ctx
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
typeName, call, err := gosec.GetCallInfo(n, ctx)
if err == nil {
result[typeName] = call
}
return true
}
ast.Walk(visitor, ctx.Root)
Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString"))
})
It("should return the type and call name for package function", func() {
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("main.go", `
package main
import(
"fmt"
)
func main() {
fmt.Println("test")
}
`)
ctx := pkg.CreateContext("main.go")
result := map[string]string{}
visitor := testutils.NewMockVisitor()
visitor.Context = ctx
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
typeName, call, err := gosec.GetCallInfo(n, ctx)
if err == nil {
result[typeName] = call
}
return true
}
ast.Walk(visitor, ctx.Root)
Expect(result).Should(HaveKeyWithValue("fmt", "Println"))
})
})
})

View file

@ -19,7 +19,7 @@ func (a *archive) ID() string {
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := a.calls.ContainsCallExpr(n, c, false); node != nil {
if node := a.calls.ContainsPkgCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
var argType types.Type
if selector, ok := arg.(*ast.SelectorExpr); ok {

View file

@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string {
}
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
callExpr := r.calls.ContainsCallExpr(n, c, false)
callExpr := r.calls.ContainsPkgCallExpr(n, c, false)
if callExpr == nil {
return nil, nil
}

View file

@ -16,8 +16,9 @@ package rules
import (
"fmt"
"github.com/securego/gosec"
"go/ast"
"github.com/securego/gosec"
)
type decompressionBombCheck struct {
@ -31,15 +32,12 @@ func (d *decompressionBombCheck) ID() string {
}
func containsReaderCall(node ast.Node, ctx *gosec.Context, list gosec.CallList) bool {
if list.ContainsCallExpr(node, ctx, false) != nil {
if list.ContainsPkgCallExpr(node, ctx, false) != nil {
return true
}
// Resolve type info of ident (for *archive/zip.File.Open)
s, idt, _ := gosec.GetCallInfo(node, ctx)
if list.Contains(s, idt) {
return true
}
return false
return list.Contains(s, idt)
}
func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
@ -70,7 +68,7 @@ func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gose
}
}
case *ast.CallExpr:
if d.copyCalls.ContainsCallExpr(n, ctx, false) != nil {
if d.copyCalls.ContainsPkgCallExpr(n, ctx, false) != nil {
if idt, ok := n.Args[1].(*ast.Ident); ok {
if _, ok := readerVarObj[idt.Obj]; ok {
// Detect io.Copy(x, r)

View file

@ -55,7 +55,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
cfg := ctx.Config
if enabled, err := cfg.IsGlobalEnabled(gosec.Audit); err == nil && enabled {
for _, expr := range stmt.Rhs {
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil {
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil {
pos := returnsError(callExpr, ctx)
if pos < 0 || pos >= len(stmt.Lhs) {
return nil, nil
@ -67,7 +67,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
}
}
case *ast.ExprStmt:
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil {
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil {
pos := returnsError(callExpr, ctx)
if pos >= 0 {
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil

View file

@ -16,8 +16,9 @@ package rules
import (
"fmt"
"github.com/securego/gosec"
"go/ast"
"github.com/securego/gosec"
)
type integerOverflowCheck struct {
@ -47,7 +48,7 @@ func (i *integerOverflowCheck) Match(node ast.Node, ctx *gosec.Context) (*gosec.
switch n := node.(type) {
case *ast.AssignStmt:
for _, expr := range n.Rhs {
if callExpr, ok := expr.(*ast.CallExpr); ok && i.calls.ContainsCallExpr(callExpr, ctx, false) != nil {
if callExpr, ok := expr.(*ast.CallExpr); ok && i.calls.ContainsPkgCallExpr(callExpr, ctx, false) != nil {
if idt, ok := n.Lhs[0].(*ast.Ident); ok && idt.Name != "_" {
// Example:
// v, _ := strconv.Atoi("1111")

View file

@ -34,7 +34,7 @@ func (r *readfile) ID() string {
// isJoinFunc checks if there is a filepath.Join or other join function
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil {
if call := r.pathJoin.ContainsPkgCallExpr(n, c, false); call != nil {
for _, arg := range call.Args {
// edge case: check if one of the args is a BinaryExpr
if binExp, ok := arg.(*ast.BinaryExpr); ok {
@ -58,7 +58,7 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := r.ContainsCallExpr(n, c, false); node != nil {
if node := r.ContainsPkgCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
// handles path joining functions in Arg
// eg. os.Open(filepath.Join("/tmp/", file))

View file

@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string {
}
func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil {
if callExpr := w.calls.ContainsPkgCallExpr(n, c, false); callExpr != nil {
if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) {
return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil
}

View file

@ -137,7 +137,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
argIndex := 0
// TODO(gm) improve confidence if database/sql is being used
if node := s.calls.ContainsCallExpr(n, c, false); node != nil {
if node := s.calls.ContainsPkgCallExpr(n, c, false); node != nil {
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
if sel.Sel.Name == "Fprintf" {
@ -177,7 +177,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
if argIndex+1 < len(node.Args) {
allSafe := true
for _, arg := range node.Args[argIndex+1:] {
if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) {
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) {
allSafe = false
break
}

View file

@ -42,7 +42,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool {
// Match inspects AST nodes to determine if certain net/http methods are called with variable input
func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
// Call expression is using http package directly
if node := r.ContainsCallExpr(n, c, false); node != nil {
if node := r.ContainsPkgCallExpr(n, c, false); node != nil {
if r.ResolveVar(node, c) {
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
}

View file

@ -40,7 +40,7 @@ func (r *subprocess) ID() string {
//
// syscall.Exec("echo", "foobar" + tainted)
func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := r.ContainsCallExpr(n, c, false); node != nil {
if node := r.ContainsPkgCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
if ident, ok := arg.(*ast.Ident); ok {
obj := c.Info.ObjectOf(ident)

View file

@ -32,7 +32,7 @@ func (t *badTempFile) ID() string {
}
func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) {
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
if node := t.calls.ContainsPkgCallExpr(n, c, false); node != nil {
if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil {
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
}

View file

@ -30,7 +30,7 @@ func (t *templateCheck) ID() string {
}
func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
if node := t.calls.ContainsPkgCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil

View file

@ -245,7 +245,32 @@ func a() {
}
func main() {
a()
}`}, 0, gosec.Config{"G104": map[string]interface{}{"io/ioutil": []interface{}{"WriteFile"}}}}}
}`}, 0, gosec.Config{"G104": map[string]interface{}{"ioutil": []interface{}{"WriteFile"}}}}, {[]string{`
package main
import (
"bytes"
"fmt"
"io"
"os"
"strings"
)
func createBuffer() *bytes.Buffer {
return new(bytes.Buffer)
}
func main() {
new(bytes.Buffer).WriteString("*bytes.Buffer")
fmt.Fprintln(os.Stderr, "fmt")
new(strings.Builder).WriteString("*strings.Builder")
_, pw := io.Pipe()
pw.CloseWithError(io.EOF)
createBuffer().WriteString("*bytes.Buffer")
b := createBuffer()
b.WriteString("*bytes.Buffer")
}`}, 0, gosec.NewConfig()}} // it shoudn't return any errors because all method calls are whitelisted by default
// SampleCodeG104Audit finds errors that aren't being handled in audit mode
SampleCodeG104Audit = []CodeSample{