One approach for fixing the false positive identified in #325.

This commit is contained in:
Ben Bytheway 2019-06-12 13:53:17 -06:00 committed by Grant Murphy
parent 196edd34b6
commit 04dc713f22
2 changed files with 45 additions and 5 deletions

View file

@ -110,6 +110,26 @@ type sqlStrFormat struct {
noIssueQuoted gosec.CallList noIssueQuoted gosec.CallList
} }
// see if we can figure out what it is
func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
n, ok := e.(*ast.Ident)
if !ok {
return false
}
if n.Obj != nil {
return n.Obj.Kind == ast.Con
}
// Try to resolve unresolved identifiers using other files in same package
for _, file := range c.PkgFiles {
if node, ok := file.Scope.Objects[n.String()]; ok {
return node.Kind == ast.Con
}
}
return false
}
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)" // Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
@ -153,16 +173,16 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
return nil, nil return nil, nil
} }
// If all formatter args are quoted, then the SQL construction is safe // If all formatter args are quoted or constant, then the SQL construction is safe
if argIndex+1 < len(node.Args) { if argIndex+1 < len(node.Args) {
allQuoted := true allSafe := true
for _, arg := range node.Args[argIndex+1:] { for _, arg := range node.Args[argIndex+1:] {
if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil { if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) {
allQuoted = false allSafe = false
break break
} }
} }
if allQuoted { if allSafe {
return nil, nil return nil, nil
} }
} }

View file

@ -437,6 +437,26 @@ func main(){
} }
defer rows.Close() defer rows.Close()
}`}, 0}, {[]string{` }`}, 0}, {[]string{`
// false positive
package main
import (
"database/sql"
"fmt"
)
const Table = "foo"
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := fmt.Sprintf("SELECT * FROM %s where id = 1", Table)
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 0}, {[]string{`
package main package main
import ( import (
"fmt" "fmt"