Refactor SQL rules for better extensibility (#841)

Remove hardwired assumption and heuristics on index of arg taking a SQL
string, be explicit about it instead.
This commit is contained in:
Ville Skyttä 2022-08-02 16:25:30 +03:00 committed by GitHub
parent 1b0873a235
commit 6a26c231fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -15,9 +15,9 @@
package rules package rules
import ( import (
"fmt"
"go/ast" "go/ast"
"regexp" "regexp"
"strings"
"github.com/securego/gosec/v2" "github.com/securego/gosec/v2"
) )
@ -30,6 +30,51 @@ type sqlStatement struct {
patterns []*regexp.Regexp patterns []*regexp.Regexp
} }
var sqlCallIdents = map[string]map[string]int{
"*database/sql.DB": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
"*database/sql.Tx": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
}
// findQueryArg locates the argument taking raw SQL
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}
i := -1
if ni, ok := sqlCallIdents[typeName]; ok {
if i, ok = ni[fnName]; !ok {
i = -1
}
}
if i == -1 {
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
}
if i >= len(call.Args) {
return nil, nil
}
query := call.Args[i]
return query, nil
}
func (s *sqlStatement) ID() string { func (s *sqlStatement) ID() string {
return s.MetaData.ID return s.MetaData.ID
} }
@ -69,16 +114,10 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
// checkQuery verifies if the query parameters is a string concatenation // checkQuery verifies if the query parameters is a string concatenation
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) { func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx) query, err := findQueryArg(call, ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}
if be, ok := query.(*ast.BinaryExpr); ok { if be, ok := query.(*ast.BinaryExpr); ok {
operands := gosec.GetBinaryExprOperands(be) operands := gosec.GetBinaryExprOperands(be)
@ -137,8 +176,11 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
}, },
} }
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") for s, si := range sqlCallIdents {
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") for i := range si {
rule.Add(s, i)
}
}
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
} }
@ -171,16 +213,10 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
} }
func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) { func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx) query, err := findQueryArg(call, ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}
if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil { if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
decl := ident.Obj.Decl decl := ident.Obj.Decl
@ -306,8 +342,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
}, },
}, },
} }
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") for s, si := range sqlCallIdents {
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext") for i := range si {
rule.Add(s, i)
}
}
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr") rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier") rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")