diff --git a/helpers.go b/helpers.go index 1d84c45..b4c23e5 100644 --- a/helpers.go +++ b/helpers.go @@ -96,11 +96,46 @@ func GetChar(n ast.Node) (byte, error) { return 0, fmt.Errorf("Unexpected AST node type: %T", n) } +// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return. +// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type, +// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's +// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type, +// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile. +// +// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below +// +// Do note that this will omit non-string values. So for example, if you were to use this node: +// ```go +// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1" + +func GetStringRecursive(n ast.Node) (string, error) { + if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING { + return strconv.Unquote(node.Value) + } + + if expr, ok := n.(*ast.BinaryExpr); ok { + x, err := GetStringRecursive(expr.X) + if err != nil { + return "", err + } + + y, err := GetStringRecursive(expr.Y) + if err != nil { + return "", err + } + + return x + y, nil + } + + return "", nil +} + // GetString will read and return a string value from an ast.BasicLit func GetString(n ast.Node) (string, error) { if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING { return strconv.Unquote(node.Value) } + return "", fmt.Errorf("Unexpected AST node type: %T", n) } @@ -201,22 +236,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string { return values } -// GetIdentStringValues return the string values of an Ident if they can be resolved -func GetIdentStringValues(ident *ast.Ident) []string { +func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string { values := []string{} obj := ident.Obj if obj != nil { switch decl := obj.Decl.(type) { case *ast.ValueSpec: for _, v := range decl.Values { - value, err := GetString(v) + value, err := stringFinder(v) if err == nil { values = append(values, value) } } case *ast.AssignStmt: for _, v := range decl.Rhs { - value, err := GetString(v) + value, err := stringFinder(v) if err == nil { values = append(values, value) } @@ -226,6 +260,18 @@ func GetIdentStringValues(ident *ast.Ident) []string { return values } +// getIdentStringRecursive returns the string of values of an Ident if they can be resolved +// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively, +// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details +func GetIdentStringValuesRecursive(ident *ast.Ident) []string { + return getIdentStringValues(ident, GetStringRecursive) +} + +// GetIdentStringValues return the string values of an Ident if they can be resolved +func GetIdentStringValues(ident *ast.Ident) []string { + return getIdentStringValues(ident, GetString) +} + // GetBinaryExprOperands returns all operands of a binary expression by traversing // the expression tree func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node { diff --git a/rules/sql.go b/rules/sql.go index 4085b5d..61222bf 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string { return s.MetaData.ID } +// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections +// This method assumes you've already verified that the branch contains SQL syntax +func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr { + for _, node := range branch { + be, ok := node.(*ast.BinaryExpr) + if !ok { + continue + } + + operands := gosec.GetBinaryExprOperands(be) + + for _, op := range operands { + if _, ok := op.(*ast.BasicLit); ok { + continue + } + + if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) { + continue + } + + return be + } + } + return nil +} + // see if we can figure out what it is func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool { if n.Obj != nil { @@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu } } + // Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" + if id, ok := query.(*ast.Ident); ok { + var match bool + for _, str := range gosec.GetIdentStringValuesRecursive(id) { + if s.MatchPatterns(str) { + match = true + break + } + } + + if !match { + return nil, nil + } + + switch decl := id.Obj.Decl.(type) { + case *ast.AssignStmt: + if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil { + return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil + } + } + } + return nil, nil } @@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro return s.checkQuery(sqlQueryCall, ctx) } } + return nil, nil } @@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) { rule := &sqlStrConcat{ sqlStatement: sqlStatement{ patterns: []*regexp.Regexp{ - regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `), + regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"), }, MetaData: issue.MetaData{ ID: id, diff --git a/testutils/source.go b/testutils/source.go index 63e6e4c..4a49bc7 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -1712,6 +1712,28 @@ func main() { // SampleCodeG202 - SQL query string building via string concatenation SampleCodeG202 = []CodeSample{ {[]string{` + // infixed concatenation +package main + +import ( + "database/sql" + "os" +) + +func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + + q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')" + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +}`}, 1, gosec.NewConfig()}, + {[]string{` package main import ( @@ -1729,7 +1751,8 @@ func main(){ panic(err) } defer rows.Close() -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // case insensitive match package main @@ -1748,7 +1771,8 @@ func main(){ panic(err) } defer rows.Close() -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // context match package main @@ -1768,7 +1792,8 @@ func main(){ panic(err) } defer rows.Close() -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // DB transaction check package main @@ -1796,7 +1821,8 @@ func main(){ if err := tx.Commit(); err != nil { panic(err) } -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // multiple string concatenation package main @@ -1815,7 +1841,8 @@ func main(){ panic(err) } defer rows.Close() -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // false positive package main @@ -1834,7 +1861,8 @@ func main(){ panic(err) } defer rows.Close() -}`}, 0, gosec.NewConfig()}, {[]string{` +}`}, 0, gosec.NewConfig()}, + {[]string{` package main import ( @@ -1856,7 +1884,8 @@ func main(){ } defer rows.Close() } -`}, 0, gosec.NewConfig()}, {[]string{` +`}, 0, gosec.NewConfig()}, + {[]string{` package main const gender = "M" @@ -1882,7 +1911,8 @@ func main(){ } defer rows.Close() } -`}, 0, gosec.NewConfig()}, {[]string{` +`}, 0, gosec.NewConfig()}, + {[]string{` // ExecContext match package main @@ -1903,7 +1933,8 @@ func main() { panic(err) } fmt.Println(result) -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` // Exec match package main @@ -1923,7 +1954,8 @@ func main() { panic(err) } fmt.Println(result) -}`}, 1, gosec.NewConfig()}, {[]string{` +}`}, 1, gosec.NewConfig()}, + {[]string{` package main import (