diff --git a/rules/sql.go b/rules/sql.go index 7aed1b6..ccee0a6 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -110,6 +110,26 @@ type sqlStrFormat struct { noIssueQuoted gosec.CallList } +// see if we can figure out what it is +func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool { + n, ok := e.(*ast.Ident) + if !ok { + return false + } + + if n.Obj != nil { + return n.Obj.Kind == ast.Con + } + + // Try to resolve unresolved identifiers using other files in same package + for _, file := range c.PkgFiles { + if node, ok := file.Scope.Objects[n.String()]; ok { + return node.Kind == ast.Con + } + } + return false +} + // Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)" func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { @@ -153,16 +173,16 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) return nil, nil } - // If all formatter args are quoted, then the SQL construction is safe + // If all formatter args are quoted or constant, then the SQL construction is safe if argIndex+1 < len(node.Args) { - allQuoted := true + allSafe := true for _, arg := range node.Args[argIndex+1:] { - if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil { - allQuoted = false + if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) { + allSafe = false break } } - if allQuoted { + if allSafe { return nil, nil } } diff --git a/testutils/source.go b/testutils/source.go index 911450b..e9160c9 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -437,6 +437,26 @@ func main(){ } defer rows.Close() }`}, 0}, {[]string{` +// false positive +package main +import ( + "database/sql" + "fmt" +) + +const Table = "foo" +func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM %s where id = 1", Table) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +}`}, 0}, {[]string{` package main import ( "fmt"