From 762ff3a70904bf2703f7119ae42f68041a63c1f1 Mon Sep 17 00:00:00 2001 From: Dale Hui Date: Tue, 25 Sep 2018 00:40:05 -0700 Subject: [PATCH] Allow quoted strings to be used to format SQL queries (#240) * Support stripping vendor paths when matching calls * Factor out matching of formatter string * Quoted strings are safe to use with SQL str formatted strings * Add test for allowing quoted strings with string formatters * Install the pq package for tests to pass --- .travis.yml | 1 + call_list.go | 22 +++++++++++++++++----- call_list_test.go | 2 +- rules/archive.go | 2 +- rules/bind.go | 2 +- rules/errors.go | 4 ++-- rules/readfile.go | 16 ++++++++-------- rules/rsa.go | 2 +- rules/sql.go | 41 +++++++++++++++++++++++++++++++---------- rules/ssrf.go | 2 +- rules/subproc.go | 2 +- rules/tempfiles.go | 2 +- rules/templates.go | 2 +- testutils/source.go | 21 +++++++++++++++++++++ 14 files changed, 88 insertions(+), 33 deletions(-) diff --git a/.travis.yml b/.travis.yml index ad4ee96..28f1a3e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ install: - go get -u github.com/onsi/ginkgo/ginkgo - go get -u github.com/onsi/gomega - go get -u golang.org/x/crypto/ssh + - go get -u github.com/lib/pq - go get -u github.com/securego/gosec/cmd/gosec/... - go get -v -t ./... - export PATH=$PATH:$HOME/gopath/bin diff --git a/call_list.go b/call_list.go index 8370f8f..556a1e8 100644 --- a/call_list.go +++ b/call_list.go @@ -15,8 +15,11 @@ package gosec import ( "go/ast" + "strings" ) +const vendorPath = "vendor/" + type set map[string]bool // CallList is used to check for usage of specific packages @@ -55,17 +58,27 @@ func (c CallList) Contains(selector, ident string) bool { // ContainsCallExpr resolves the call expression name and type /// or package and determines if it exists within the CallList -func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr { +func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr { selector, ident, err := GetCallInfo(n, ctx) if err != nil { return nil } - // Use only explicit path to reduce conflicts - if path, ok := GetImportPath(selector, ctx); ok && c.Contains(path, ident) { - return n.(*ast.CallExpr) + // Use only explicit path (optionally strip vendor path prefix) to reduce conflicts + path, ok := GetImportPath(selector, ctx) + if !ok { + return nil + } + if stripVendor { + if vendorIdx := strings.Index(path, vendorPath); vendorIdx >= 0 { + path = path[vendorIdx+len(vendorPath):] + } + } + if !c.Contains(path, ident) { + return nil } + return n.(*ast.CallExpr) /* // Try direct resolution if c.Contains(selector, ident) { @@ -74,5 +87,4 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr { } */ - return nil } diff --git a/call_list_test.go b/call_list_test.go index 41aa51d..284e2e8 100644 --- a/call_list_test.go +++ b/call_list_test.go @@ -73,7 +73,7 @@ var _ = Describe("call list", func() { v := testutils.NewMockVisitor() v.Context = ctx v.Callback = func(n ast.Node, ctx *gosec.Context) bool { - if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil { + if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil { matched++ } return true diff --git a/rules/archive.go b/rules/archive.go index 1fa2b40..55f390c 100644 --- a/rules/archive.go +++ b/rules/archive.go @@ -19,7 +19,7 @@ func (a *archive) ID() string { // Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := a.calls.ContainsCallExpr(n, c); node != nil { + if node := a.calls.ContainsCallExpr(n, c, false); node != nil { for _, arg := range node.Args { var argType types.Type if selector, ok := arg.(*ast.SelectorExpr); ok { diff --git a/rules/bind.go b/rules/bind.go index a7d599b..1448d03 100644 --- a/rules/bind.go +++ b/rules/bind.go @@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string { } func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - callExpr := r.calls.ContainsCallExpr(n, c) + callExpr := r.calls.ContainsCallExpr(n, c, false) if callExpr == nil { return nil, nil } diff --git a/rules/errors.go b/rules/errors.go index 5aea57d..1c8785e 100644 --- a/rules/errors.go +++ b/rules/errors.go @@ -53,7 +53,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro switch stmt := n.(type) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { - if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil { + if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil { pos := returnsError(callExpr, ctx) if pos < 0 || pos >= len(stmt.Lhs) { return nil, nil @@ -64,7 +64,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro } } case *ast.ExprStmt: - if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil { + if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil { pos := returnsError(callExpr, ctx) if pos >= 0 { return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil diff --git a/rules/readfile.go b/rules/readfile.go index 2b38852..dbd18b0 100644 --- a/rules/readfile.go +++ b/rules/readfile.go @@ -34,7 +34,7 @@ func (r *readfile) ID() string { // isJoinFunc checks if there is a filepath.Join or other join function func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool { - if call := r.pathJoin.ContainsCallExpr(n, c); call != nil { + if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil { for _, arg := range call.Args { // edge case: check if one of the args is a BinaryExpr if binExp, ok := arg.(*ast.BinaryExpr); ok { @@ -44,21 +44,21 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool { } } - // try and resolve identity - if ident, ok := arg.(*ast.Ident); ok { - obj := c.Info.ObjectOf(ident) - if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) { - return true + // try and resolve identity + if ident, ok := arg.(*ast.Ident); ok { + obj := c.Info.ObjectOf(ident) + if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) { + return true + } } } } -} return false } // Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile` func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := r.ContainsCallExpr(n, c); node != nil { + if node := r.ContainsCallExpr(n, c, false); node != nil { for _, arg := range node.Args { // handles path joining functions in Arg // eg. os.Open(filepath.Join("/tmp/", file)) diff --git a/rules/rsa.go b/rules/rsa.go index 4a42905..8f17afe 100644 --- a/rules/rsa.go +++ b/rules/rsa.go @@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string { } func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if callExpr := w.calls.ContainsCallExpr(n, c); callExpr != nil { + if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil { if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) { return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil } diff --git a/rules/sql.go b/rules/sql.go index 38f8598..18d8937 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -98,8 +98,9 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { type sqlStrFormat struct { sqlStatement - calls gosec.CallList - noIssue gosec.CallList + calls gosec.CallList + noIssue gosec.CallList + noIssueQuoted gosec.CallList } // Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)" @@ -109,7 +110,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) argIndex := 0 // TODO(gm) improve confidence if database/sql is being used - if node := s.calls.ContainsCallExpr(n, c); node != nil { + if node := s.calls.ContainsCallExpr(n, c, false); node != nil { // if the function is fmt.Fprintf, search for SQL statement in Args[1] instead if sel, ok := node.Fun.(*ast.SelectorExpr); ok { if sel.Sel.Name == "Fprintf" { @@ -125,17 +126,35 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) argIndex = 1 } } + + var formatter string + // concats callexpr arg strings together if needed before regex evaluation if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok { if fullStr, ok := gosec.ConcatString(argExpr); ok { - if s.MatchPatterns(fullStr) { - return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), - nil - } + formatter = fullStr } + } else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil { + formatter = arg + } + if len(formatter) <= 0 { + return nil, nil } - if arg, e := gosec.GetString(node.Args[argIndex]); s.MatchPatterns(arg) && e == nil { + // If all formatter args are quoted, then the SQL construction is safe + if argIndex+1 < len(node.Args) { + allQuoted := true + for _, arg := range node.Args[argIndex+1:] { + if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil { + allQuoted = false + break + } + } + if allQuoted { + return nil, nil + } + } + if s.MatchPatterns(formatter) { return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil } } @@ -145,8 +164,9 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) // NewSQLStrFormat looks for cases where we're building SQL query strings using format strings func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { rule := &sqlStrFormat{ - calls: gosec.NewCallList(), - noIssue: gosec.NewCallList(), + calls: gosec.NewCallList(), + noIssue: gosec.NewCallList(), + noIssueQuoted: gosec.NewCallList(), sqlStatement: sqlStatement{ patterns: []*regexp.Regexp{ regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "), @@ -162,5 +182,6 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { } rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.noIssue.AddAll("os", "Stdout", "Stderr") + rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier") return rule, []ast.Node{(*ast.CallExpr)(nil)} } diff --git a/rules/ssrf.go b/rules/ssrf.go index 9be9b40..34aa5d4 100644 --- a/rules/ssrf.go +++ b/rules/ssrf.go @@ -35,7 +35,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool { // Match inspects AST nodes to determine if certain net/http methods are called with variable input func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { // Call expression is using http package directly - if node := r.ContainsCallExpr(n, c); node != nil { + if node := r.ContainsCallExpr(n, c, false); node != nil { if r.ResolveVar(node, c) { return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil } diff --git a/rules/subproc.go b/rules/subproc.go index 80a3464..809d808 100644 --- a/rules/subproc.go +++ b/rules/subproc.go @@ -40,7 +40,7 @@ func (r *subprocess) ID() string { // // syscall.Exec("echo", "foobar" + tainted) func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := r.ContainsCallExpr(n, c); node != nil { + if node := r.ContainsCallExpr(n, c, false); node != nil { for _, arg := range node.Args { if ident, ok := arg.(*ast.Ident); ok { obj := c.Info.ObjectOf(ident) diff --git a/rules/tempfiles.go b/rules/tempfiles.go index 6963404..095544d 100644 --- a/rules/tempfiles.go +++ b/rules/tempfiles.go @@ -32,7 +32,7 @@ func (t *badTempFile) ID() string { } func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) { - if node := t.calls.ContainsCallExpr(n, c); node != nil { + if node := t.calls.ContainsCallExpr(n, c, false); node != nil { if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil { return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil } diff --git a/rules/templates.go b/rules/templates.go index 30a964b..c452a28 100644 --- a/rules/templates.go +++ b/rules/templates.go @@ -30,7 +30,7 @@ func (t *templateCheck) ID() string { } func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { - if node := t.calls.ContainsCallExpr(n, c); node != nil { + if node := t.calls.ContainsCallExpr(n, c, false); node != nil { for _, arg := range node.Args { if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil diff --git a/testutils/source.go b/testutils/source.go index e0a6834..4e502b6 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -292,6 +292,27 @@ func main(){ panic(err) } defer rows.Close() +}`, 0}, {` +// Format string false positive, quoted formatter argument. +package main +import ( + "database/sql" + "fmt" + "os" + "github.com/lib/pq" +) + +func main(){ + db, err := sql.Open("postgres", "localhost") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM %s where id = 1", pq.QuoteIdentifier(os.Args[1])) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() }`, 0}} // SampleCodeG202 - SQL query string building via string concatenation