mirror of
https://github.com/securego/gosec.git
synced 2024-11-05 19:45:51 +00:00
fix: correctly identify infixed concats as potential SQL injections (#987)
This commit is contained in:
parent
2292ed5e91
commit
bf7feda2b9
3 changed files with 142 additions and 15 deletions
54
helpers.go
54
helpers.go
|
@ -96,11 +96,46 @@ func GetChar(n ast.Node) (byte, error) {
|
||||||
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
|
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return.
|
||||||
|
// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type,
|
||||||
|
// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's
|
||||||
|
// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type,
|
||||||
|
// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile.
|
||||||
|
//
|
||||||
|
// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below
|
||||||
|
//
|
||||||
|
// Do note that this will omit non-string values. So for example, if you were to use this node:
|
||||||
|
// ```go
|
||||||
|
// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1"
|
||||||
|
|
||||||
|
func GetStringRecursive(n ast.Node) (string, error) {
|
||||||
|
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
|
||||||
|
return strconv.Unquote(node.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expr, ok := n.(*ast.BinaryExpr); ok {
|
||||||
|
x, err := GetStringRecursive(expr.X)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
y, err := GetStringRecursive(expr.Y)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return x + y, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetString will read and return a string value from an ast.BasicLit
|
// GetString will read and return a string value from an ast.BasicLit
|
||||||
func GetString(n ast.Node) (string, error) {
|
func GetString(n ast.Node) (string, error) {
|
||||||
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
|
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
|
||||||
return strconv.Unquote(node.Value)
|
return strconv.Unquote(node.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("Unexpected AST node type: %T", n)
|
return "", fmt.Errorf("Unexpected AST node type: %T", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,22 +236,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string {
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentStringValues return the string values of an Ident if they can be resolved
|
func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string {
|
||||||
func GetIdentStringValues(ident *ast.Ident) []string {
|
|
||||||
values := []string{}
|
values := []string{}
|
||||||
obj := ident.Obj
|
obj := ident.Obj
|
||||||
if obj != nil {
|
if obj != nil {
|
||||||
switch decl := obj.Decl.(type) {
|
switch decl := obj.Decl.(type) {
|
||||||
case *ast.ValueSpec:
|
case *ast.ValueSpec:
|
||||||
for _, v := range decl.Values {
|
for _, v := range decl.Values {
|
||||||
value, err := GetString(v)
|
value, err := stringFinder(v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *ast.AssignStmt:
|
case *ast.AssignStmt:
|
||||||
for _, v := range decl.Rhs {
|
for _, v := range decl.Rhs {
|
||||||
value, err := GetString(v)
|
value, err := stringFinder(v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
@ -226,6 +260,18 @@ func GetIdentStringValues(ident *ast.Ident) []string {
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getIdentStringRecursive returns the string of values of an Ident if they can be resolved
|
||||||
|
// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively,
|
||||||
|
// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details
|
||||||
|
func GetIdentStringValuesRecursive(ident *ast.Ident) []string {
|
||||||
|
return getIdentStringValues(ident, GetStringRecursive)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentStringValues return the string values of an Ident if they can be resolved
|
||||||
|
func GetIdentStringValues(ident *ast.Ident) []string {
|
||||||
|
return getIdentStringValues(ident, GetString)
|
||||||
|
}
|
||||||
|
|
||||||
// GetBinaryExprOperands returns all operands of a binary expression by traversing
|
// GetBinaryExprOperands returns all operands of a binary expression by traversing
|
||||||
// the expression tree
|
// the expression tree
|
||||||
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {
|
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {
|
||||||
|
|
51
rules/sql.go
51
rules/sql.go
|
@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string {
|
||||||
return s.MetaData.ID
|
return s.MetaData.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
|
||||||
|
// This method assumes you've already verified that the branch contains SQL syntax
|
||||||
|
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
|
||||||
|
for _, node := range branch {
|
||||||
|
be, ok := node.(*ast.BinaryExpr)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
operands := gosec.GetBinaryExprOperands(be)
|
||||||
|
|
||||||
|
for _, op := range operands {
|
||||||
|
if _, ok := op.(*ast.BasicLit); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return be
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// see if we can figure out what it is
|
// see if we can figure out what it is
|
||||||
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
|
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
|
||||||
if n.Obj != nil {
|
if n.Obj != nil {
|
||||||
|
@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
|
||||||
|
if id, ok := query.(*ast.Ident); ok {
|
||||||
|
var match bool
|
||||||
|
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
|
||||||
|
if s.MatchPatterns(str) {
|
||||||
|
match = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !match {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch decl := id.Obj.Decl.(type) {
|
||||||
|
case *ast.AssignStmt:
|
||||||
|
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
|
||||||
|
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro
|
||||||
return s.checkQuery(sqlQueryCall, ctx)
|
return s.checkQuery(sqlQueryCall, ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
|
||||||
rule := &sqlStrConcat{
|
rule := &sqlStrConcat{
|
||||||
sqlStatement: sqlStatement{
|
sqlStatement: sqlStatement{
|
||||||
patterns: []*regexp.Regexp{
|
patterns: []*regexp.Regexp{
|
||||||
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
|
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"),
|
||||||
},
|
},
|
||||||
MetaData: issue.MetaData{
|
MetaData: issue.MetaData{
|
||||||
ID: id,
|
ID: id,
|
||||||
|
|
|
@ -1712,6 +1712,28 @@ func main() {
|
||||||
// SampleCodeG202 - SQL query string building via string concatenation
|
// SampleCodeG202 - SQL query string building via string concatenation
|
||||||
SampleCodeG202 = []CodeSample{
|
SampleCodeG202 = []CodeSample{
|
||||||
{[]string{`
|
{[]string{`
|
||||||
|
// infixed concatenation
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main(){
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')"
|
||||||
|
rows, err := db.Query(q)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -1729,7 +1751,8 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// case insensitive match
|
// case insensitive match
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1748,7 +1771,8 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// context match
|
// context match
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1768,7 +1792,8 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// DB transaction check
|
// DB transaction check
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1796,7 +1821,8 @@ func main(){
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// multiple string concatenation
|
// multiple string concatenation
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1815,7 +1841,8 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// false positive
|
// false positive
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1834,7 +1861,8 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}`}, 0, gosec.NewConfig()}, {[]string{`
|
}`}, 0, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -1856,7 +1884,8 @@ func main(){
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}
|
}
|
||||||
`}, 0, gosec.NewConfig()}, {[]string{`
|
`}, 0, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
package main
|
package main
|
||||||
|
|
||||||
const gender = "M"
|
const gender = "M"
|
||||||
|
@ -1882,7 +1911,8 @@ func main(){
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}
|
}
|
||||||
`}, 0, gosec.NewConfig()}, {[]string{`
|
`}, 0, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// ExecContext match
|
// ExecContext match
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1903,7 +1933,8 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
fmt.Println(result)
|
fmt.Println(result)
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
// Exec match
|
// Exec match
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
@ -1923,7 +1954,8 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
fmt.Println(result)
|
fmt.Println(result)
|
||||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
}`}, 1, gosec.NewConfig()},
|
||||||
|
{[]string{`
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
Loading…
Reference in a new issue