From 3e069e7756ba3f295c89db12e1252eba3c8d0474 Mon Sep 17 00:00:00 2001 From: Cosmin Cojocar Date: Tue, 28 Jan 2020 14:11:00 +0100 Subject: [PATCH] Fix the errors rule whitelist to work on types methods Signed-off-by: Cosmin Cojocar --- call_list.go | 41 ++++++++--- call_list_test.go | 46 ++++++++++-- helpers.go | 29 ++++++++ helpers_test.go | 138 ++++++++++++++++++++++++++++++++++++ rules/archive.go | 2 +- rules/bind.go | 2 +- rules/decompression-bomb.go | 12 ++-- rules/errors.go | 4 +- rules/integer_overflow.go | 5 +- rules/readfile.go | 4 +- rules/rsa.go | 2 +- rules/sql.go | 4 +- rules/ssrf.go | 2 +- rules/subproc.go | 2 +- rules/tempfiles.go | 2 +- rules/templates.go | 2 +- testutils/source.go | 27 ++++++- 17 files changed, 286 insertions(+), 38 deletions(-) diff --git a/call_list.go b/call_list.go index 556a1e8..f22c737 100644 --- a/call_list.go +++ b/call_list.go @@ -56,9 +56,22 @@ func (c CallList) Contains(selector, ident string) bool { return false } -// ContainsCallExpr resolves the call expression name and type -/// or package and determines if it exists within the CallList -func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr { +// ContainsPointer returns true if a pointer to the selector type or the type +// itslef is a members of this call list. +func (c CallList) ContainsPointer(selector, indent string) bool { + if strings.HasPrefix(selector, "*") { + if c.Contains(selector, indent) { + return true + } + s := strings.TrimPrefix(selector, "*") + return c.Contains(s, indent) + } + return false +} + +// ContainsPkgCallExpr resolves the call expression name and type, and then further looks +// up the package path for that type. Finally, it determines if the call exists within the CallList +func (c CallList) ContainsPkgCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr { selector, ident, err := GetCallInfo(n, ctx) if err != nil { return nil @@ -79,12 +92,18 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) * } return n.(*ast.CallExpr) - /* - // Try direct resolution - if c.Contains(selector, ident) { - log.Printf("c.Contains == true, %s, %s.", selector, ident) - return n.(*ast.CallExpr) - } - */ - +} + +// ContainsCallExpr resolves the call experssion name and type, and then determines +// if the call existis with the call list +func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr { + selector, ident, err := GetCallInfo(n, ctx) + if err != nil { + return nil + } + if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) { + return nil + } + + return n.(*ast.CallExpr) } diff --git a/call_list_test.go b/call_list_test.go index 4f1c72b..99f3b8d 100644 --- a/call_list_test.go +++ b/call_list_test.go @@ -46,6 +46,20 @@ var _ = Describe("Call List", func() { Expect(actual).Should(Equal(expected)) }) + It("should be possible to add pointer call", func() { + Expect(calls).Should(HaveLen(0)) + calls.Add("*bytes.Buffer", "WriteString") + actual := calls.ContainsPointer("*bytes.Buffer", "WriteString") + Expect(actual).Should(BeTrue()) + }) + + It("should be possible to check pointer call", func() { + Expect(calls).Should(HaveLen(0)) + calls.Add("bytes.Buffer", "WriteString") + actual := calls.ContainsPointer("*bytes.Buffer", "WriteString") + Expect(actual).Should(BeTrue()) + }) + It("should not return a match if none are present", func() { calls.Add("ioutil", "Copy") Expect(calls.Contains("fmt", "Println")).Should(BeFalse()) @@ -56,8 +70,7 @@ var _ = Describe("Call List", func() { Expect(calls.Contains("ioutil", "Copy")).Should(BeTrue()) }) - It("should match a call expression", func() { - + It("should match a package call expression", func() { // Create file to be scanned pkg := testutils.NewTestPackage() defer pkg.Close() @@ -73,14 +86,39 @@ var _ = Describe("Call List", func() { v := testutils.NewMockVisitor() v.Context = ctx v.Callback = func(n ast.Node, ctx *gosec.Context) bool { - if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil { + if _, ok := n.(*ast.CallExpr); ok && calls.ContainsPkgCallExpr(n, ctx, false) != nil { matched++ } return true } ast.Walk(v, ctx.Root) Expect(matched).Should(Equal(1)) - }) + It("should match a call expression", func() { + // Create file to be scanned + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", testutils.SampleCodeG104[5].Code[0]) + + ctx := pkg.CreateContext("main.go") + + calls.Add("bytes.Buffer", "WriteString") + calls.Add("strings.Builder", "WriteString") + calls.Add("io.Pipe", "CloseWithError") + calls.Add("fmt", "Fprintln") + + // Stub out visitor and count number of matched call expr + matched := 0 + v := testutils.NewMockVisitor() + v.Context = ctx + v.Callback = func(n ast.Node, ctx *gosec.Context) bool { + if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil { + matched++ + } + return true + } + ast.Walk(v, ctx.Root) + Expect(matched).Should(Equal(5)) + }) }) diff --git a/helpers.go b/helpers.go index d708626..473ddef 100644 --- a/helpers.go +++ b/helpers.go @@ -135,11 +135,40 @@ func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") } return expr.Name, fn.Sel.Name, nil + case *ast.CallExpr: + switch call := expr.Fun.(type) { + case *ast.Ident: + if call.Name == "new" { + t := ctx.Info.TypeOf(expr.Args[0]) + if t != nil { + return t.String(), fn.Sel.Name, nil + } + return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") + } + if call.Obj != nil { + switch decl := call.Obj.Decl.(type) { + case *ast.FuncDecl: + ret := decl.Type.Results + if ret != nil && len(ret.List) > 0 { + ret1 := ret.List[0] + if ret1 != nil { + t := ctx.Info.TypeOf(ret1.Type) + if t != nil { + return t.String(), fn.Sel.Name, nil + } + return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") + } + } + } + } + + } } case *ast.Ident: return ctx.Pkg.Name(), fn.Name, nil } } + return "", "", fmt.Errorf("unable to determine call info") } diff --git a/helpers_test.go b/helpers_test.go index f806758..b9c280f 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -1,6 +1,7 @@ package gosec_test import ( + "go/ast" "io/ioutil" "os" "path/filepath" @@ -9,6 +10,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/securego/gosec" + "github.com/securego/gosec/testutils" ) var _ = Describe("Helpers", func() { @@ -91,4 +93,140 @@ var _ = Describe("Helpers", func() { Expect(len(r)).Should(Equal(0)) }) }) + + Context("when getting call info", func() { + It("should return the type and call name for selector expression", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "bytes" + ) + + func main() { + b := new(bytes.Buffer) + _, err := b.WriteString("test") + if err != nil { + panic(err) + } + } + `) + ctx := pkg.CreateContext("main.go") + result := map[string]string{} + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + typeName, call, err := gosec.GetCallInfo(n, ctx) + if err == nil { + result[typeName] = call + } + return true + } + ast.Walk(visitor, ctx.Root) + + Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString")) + }) + + It("should return the type and call name for new selector expression", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "bytes" + ) + + func main() { + _, err := new(bytes.Buffer).WriteString("test") + if err != nil { + panic(err) + } + } + `) + ctx := pkg.CreateContext("main.go") + result := map[string]string{} + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + typeName, call, err := gosec.GetCallInfo(n, ctx) + if err == nil { + result[typeName] = call + } + return true + } + ast.Walk(visitor, ctx.Root) + + Expect(result).Should(HaveKeyWithValue("bytes.Buffer", "WriteString")) + }) + + It("should return the type and call name for function selector expression", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "bytes" + ) + + func createBuffer() *bytes.Buffer { + return new(bytes.Buffer) + } + + func main() { + _, err := createBuffer().WriteString("test") + if err != nil { + panic(err) + } + } + `) + ctx := pkg.CreateContext("main.go") + result := map[string]string{} + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + typeName, call, err := gosec.GetCallInfo(n, ctx) + if err == nil { + result[typeName] = call + } + return true + } + ast.Walk(visitor, ctx.Root) + + Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString")) + }) + + It("should return the type and call name for package function", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "fmt" + ) + + func main() { + fmt.Println("test") + } + `) + ctx := pkg.CreateContext("main.go") + result := map[string]string{} + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + typeName, call, err := gosec.GetCallInfo(n, ctx) + if err == nil { + result[typeName] = call + } + return true + } + ast.Walk(visitor, ctx.Root) + + Expect(result).Should(HaveKeyWithValue("fmt", "Println")) + }) + }) }) diff --git a/rules/archive.go b/rules/archive.go index 55f390c..d9e7ea9 100644 --- a/rules/archive.go +++ b/rules/archive.go @@ -19,7 +19,7 @@ func (a *archive) ID() string { // Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := a.calls.ContainsCallExpr(n, c, false); node != nil { + if node := a.calls.ContainsPkgCallExpr(n, c, false); node != nil { for _, arg := range node.Args { var argType types.Type if selector, ok := arg.(*ast.SelectorExpr); ok { diff --git a/rules/bind.go b/rules/bind.go index 7273588..eb2ed05 100644 --- a/rules/bind.go +++ b/rules/bind.go @@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string { } func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - callExpr := r.calls.ContainsCallExpr(n, c, false) + callExpr := r.calls.ContainsPkgCallExpr(n, c, false) if callExpr == nil { return nil, nil } diff --git a/rules/decompression-bomb.go b/rules/decompression-bomb.go index 2c71be9..9b9caf5 100644 --- a/rules/decompression-bomb.go +++ b/rules/decompression-bomb.go @@ -16,8 +16,9 @@ package rules import ( "fmt" - "github.com/securego/gosec" "go/ast" + + "github.com/securego/gosec" ) type decompressionBombCheck struct { @@ -31,15 +32,12 @@ func (d *decompressionBombCheck) ID() string { } func containsReaderCall(node ast.Node, ctx *gosec.Context, list gosec.CallList) bool { - if list.ContainsCallExpr(node, ctx, false) != nil { + if list.ContainsPkgCallExpr(node, ctx, false) != nil { return true } // Resolve type info of ident (for *archive/zip.File.Open) s, idt, _ := gosec.GetCallInfo(node, ctx) - if list.Contains(s, idt) { - return true - } - return false + return list.Contains(s, idt) } func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gosec.Issue, error) { @@ -70,7 +68,7 @@ func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gose } } case *ast.CallExpr: - if d.copyCalls.ContainsCallExpr(n, ctx, false) != nil { + if d.copyCalls.ContainsPkgCallExpr(n, ctx, false) != nil { if idt, ok := n.Args[1].(*ast.Ident); ok { if _, ok := readerVarObj[idt.Obj]; ok { // Detect io.Copy(x, r) diff --git a/rules/errors.go b/rules/errors.go index d2e98b5..de3163e 100644 --- a/rules/errors.go +++ b/rules/errors.go @@ -55,7 +55,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro cfg := ctx.Config if enabled, err := cfg.IsGlobalEnabled(gosec.Audit); err == nil && enabled { for _, expr := range stmt.Rhs { - if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil { + if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil { pos := returnsError(callExpr, ctx) if pos < 0 || pos >= len(stmt.Lhs) { return nil, nil @@ -67,7 +67,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro } } case *ast.ExprStmt: - if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil { + if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil { pos := returnsError(callExpr, ctx) if pos >= 0 { return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil diff --git a/rules/integer_overflow.go b/rules/integer_overflow.go index 33a3139..311ff52 100644 --- a/rules/integer_overflow.go +++ b/rules/integer_overflow.go @@ -16,8 +16,9 @@ package rules import ( "fmt" - "github.com/securego/gosec" "go/ast" + + "github.com/securego/gosec" ) type integerOverflowCheck struct { @@ -47,7 +48,7 @@ func (i *integerOverflowCheck) Match(node ast.Node, ctx *gosec.Context) (*gosec. switch n := node.(type) { case *ast.AssignStmt: for _, expr := range n.Rhs { - if callExpr, ok := expr.(*ast.CallExpr); ok && i.calls.ContainsCallExpr(callExpr, ctx, false) != nil { + if callExpr, ok := expr.(*ast.CallExpr); ok && i.calls.ContainsPkgCallExpr(callExpr, ctx, false) != nil { if idt, ok := n.Lhs[0].(*ast.Ident); ok && idt.Name != "_" { // Example: // v, _ := strconv.Atoi("1111") diff --git a/rules/readfile.go b/rules/readfile.go index 87158f0..6464360 100644 --- a/rules/readfile.go +++ b/rules/readfile.go @@ -34,7 +34,7 @@ func (r *readfile) ID() string { // isJoinFunc checks if there is a filepath.Join or other join function func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool { - if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil { + if call := r.pathJoin.ContainsPkgCallExpr(n, c, false); call != nil { for _, arg := range call.Args { // edge case: check if one of the args is a BinaryExpr if binExp, ok := arg.(*ast.BinaryExpr); ok { @@ -58,7 +58,7 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool { // Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile` func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := r.ContainsCallExpr(n, c, false); node != nil { + if node := r.ContainsPkgCallExpr(n, c, false); node != nil { for _, arg := range node.Args { // handles path joining functions in Arg // eg. os.Open(filepath.Join("/tmp/", file)) diff --git a/rules/rsa.go b/rules/rsa.go index 8f17afe..1ceee37 100644 --- a/rules/rsa.go +++ b/rules/rsa.go @@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string { } func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil { + if callExpr := w.calls.ContainsPkgCallExpr(n, c, false); callExpr != nil { if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) { return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil } diff --git a/rules/sql.go b/rules/sql.go index ccee0a6..885f105 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -137,7 +137,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) argIndex := 0 // TODO(gm) improve confidence if database/sql is being used - if node := s.calls.ContainsCallExpr(n, c, false); node != nil { + if node := s.calls.ContainsPkgCallExpr(n, c, false); node != nil { // if the function is fmt.Fprintf, search for SQL statement in Args[1] instead if sel, ok := node.Fun.(*ast.SelectorExpr); ok { if sel.Sel.Name == "Fprintf" { @@ -177,7 +177,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) if argIndex+1 < len(node.Args) { allSafe := true for _, arg := range node.Args[argIndex+1:] { - if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) { + if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) { allSafe = false break } diff --git a/rules/ssrf.go b/rules/ssrf.go index b1409a5..41b5386 100644 --- a/rules/ssrf.go +++ b/rules/ssrf.go @@ -42,7 +42,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool { // Match inspects AST nodes to determine if certain net/http methods are called with variable input func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { // Call expression is using http package directly - if node := r.ContainsCallExpr(n, c, false); node != nil { + if node := r.ContainsPkgCallExpr(n, c, false); node != nil { if r.ResolveVar(node, c) { return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil } diff --git a/rules/subproc.go b/rules/subproc.go index c452898..2513ec1 100644 --- a/rules/subproc.go +++ b/rules/subproc.go @@ -40,7 +40,7 @@ func (r *subprocess) ID() string { // // syscall.Exec("echo", "foobar" + tainted) func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := r.ContainsCallExpr(n, c, false); node != nil { + if node := r.ContainsPkgCallExpr(n, c, false); node != nil { for _, arg := range node.Args { if ident, ok := arg.(*ast.Ident); ok { obj := c.Info.ObjectOf(ident) diff --git a/rules/tempfiles.go b/rules/tempfiles.go index 095544d..4adafd1 100644 --- a/rules/tempfiles.go +++ b/rules/tempfiles.go @@ -32,7 +32,7 @@ func (t *badTempFile) ID() string { } func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) { - if node := t.calls.ContainsCallExpr(n, c, false); node != nil { + if node := t.calls.ContainsPkgCallExpr(n, c, false); node != nil { if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil { return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil } diff --git a/rules/templates.go b/rules/templates.go index 0bff687..3c663f1 100644 --- a/rules/templates.go +++ b/rules/templates.go @@ -30,7 +30,7 @@ func (t *templateCheck) ID() string { } func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := t.calls.ContainsCallExpr(n, c, false); node != nil { + if node := t.calls.ContainsPkgCallExpr(n, c, false); node != nil { for _, arg := range node.Args { if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil diff --git a/testutils/source.go b/testutils/source.go index 00d066f..ae8761e 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -245,7 +245,32 @@ func a() { } func main() { a() -}`}, 0, gosec.Config{"G104": map[string]interface{}{"io/ioutil": []interface{}{"WriteFile"}}}}} +}`}, 0, gosec.Config{"G104": map[string]interface{}{"ioutil": []interface{}{"WriteFile"}}}}, {[]string{` +package main + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" +) + +func createBuffer() *bytes.Buffer { + return new(bytes.Buffer) +} + +func main() { + new(bytes.Buffer).WriteString("*bytes.Buffer") + fmt.Fprintln(os.Stderr, "fmt") + new(strings.Builder).WriteString("*strings.Builder") + _, pw := io.Pipe() + pw.CloseWithError(io.EOF) + + createBuffer().WriteString("*bytes.Buffer") + b := createBuffer() + b.WriteString("*bytes.Buffer") +}`}, 0, gosec.NewConfig()}} // it shoudn't return any errors because all method calls are whitelisted by default // SampleCodeG104Audit finds errors that aren't being handled in audit mode SampleCodeG104Audit = []CodeSample{