fix: correctly identify infixed concats as potential SQL injections (#987)

This commit is contained in:
Audun 2023-07-25 17:13:07 +02:00 committed by GitHub
parent 2292ed5e91
commit bf7feda2b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 15 deletions

View file

@ -96,11 +96,46 @@ func GetChar(n ast.Node) (byte, error) {
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
}
// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return.
// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type,
// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's
// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type,
// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile.
//
// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below
//
// Do note that this will omit non-string values. So for example, if you were to use this node:
// ```go
// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1"
func GetStringRecursive(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}
if expr, ok := n.(*ast.BinaryExpr); ok {
x, err := GetStringRecursive(expr.X)
if err != nil {
return "", err
}
y, err := GetStringRecursive(expr.Y)
if err != nil {
return "", err
}
return x + y, nil
}
return "", nil
}
// GetString will read and return a string value from an ast.BasicLit
func GetString(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}
return "", fmt.Errorf("Unexpected AST node type: %T", n)
}
@ -201,22 +236,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string {
return values
}
// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string {
values := []string{}
obj := ident.Obj
if obj != nil {
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
for _, v := range decl.Values {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
}
case *ast.AssignStmt:
for _, v := range decl.Rhs {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
@ -226,6 +260,18 @@ func GetIdentStringValues(ident *ast.Ident) []string {
return values
}
// getIdentStringRecursive returns the string of values of an Ident if they can be resolved
// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively,
// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details
func GetIdentStringValuesRecursive(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetStringRecursive)
}
// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetString)
}
// GetBinaryExprOperands returns all operands of a binary expression by traversing
// the expression tree
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {

View file

@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string {
return s.MetaData.ID
}
// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
// This method assumes you've already verified that the branch contains SQL syntax
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
for _, node := range branch {
be, ok := node.(*ast.BinaryExpr)
if !ok {
continue
}
operands := gosec.GetBinaryExprOperands(be)
for _, op := range operands {
if _, ok := op.(*ast.BasicLit); ok {
continue
}
if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
continue
}
return be
}
}
return nil
}
// see if we can figure out what it is
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
if n.Obj != nil {
@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu
}
}
// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
if id, ok := query.(*ast.Ident); ok {
var match bool
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
if s.MatchPatterns(str) {
match = true
break
}
}
if !match {
return nil, nil
}
switch decl := id.Obj.Decl.(type) {
case *ast.AssignStmt:
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
}
return nil, nil
}
@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro
return s.checkQuery(sqlQueryCall, ctx)
}
}
return nil, nil
}
@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
rule := &sqlStrConcat{
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"),
},
MetaData: issue.MetaData{
ID: id,

View file

@ -1712,6 +1712,28 @@ func main() {
// SampleCodeG202 - SQL query string building via string concatenation
SampleCodeG202 = []CodeSample{
{[]string{`
// infixed concatenation
package main
import (
"database/sql"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')"
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()},
{[]string{`
package main
import (
@ -1729,7 +1751,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// case insensitive match
package main
@ -1748,7 +1771,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// context match
package main
@ -1768,7 +1792,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// DB transaction check
package main
@ -1796,7 +1821,8 @@ func main(){
if err := tx.Commit(); err != nil {
panic(err)
}
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// multiple string concatenation
package main
@ -1815,7 +1841,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// false positive
package main
@ -1834,7 +1861,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 0, gosec.NewConfig()}, {[]string{`
}`}, 0, gosec.NewConfig()},
{[]string{`
package main
import (
@ -1856,7 +1884,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
package main
const gender = "M"
@ -1882,7 +1911,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
// ExecContext match
package main
@ -1903,7 +1933,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// Exec match
package main
@ -1923,7 +1954,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
package main
import (