diff --git a/rules/sql.go b/rules/sql.go index 9d98398..2220668 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -15,10 +15,10 @@ package rules import ( - gas "github.com/HewlettPackard/gas/core" "go/ast" - "reflect" "regexp" + + gas "github.com/HewlettPackard/gas/core" ) type SqlStatement struct { @@ -30,13 +30,27 @@ type SqlStrConcat struct { SqlStatement } +// see if we can figgure out what it is +func (s *SqlStrConcat) checkObject(n *ast.Ident) bool { + if n.Obj != nil { + return (n.Obj.Kind != ast.Var || n.Obj.Kind != ast.Fun) + } + return false +} + // Look for "SELECT * FROM table WHERE " + " ' OR 1=1" func (s *SqlStrConcat) Match(n ast.Node, c *gas.Context) (*gas.Issue, error) { - a := reflect.TypeOf(&ast.BinaryExpr{}) - b := reflect.TypeOf(&ast.BasicLit{}) - if node := gas.SimpleSelect(n, a, b); node != nil { - if str, _ := gas.GetString(node); s.pattern.MatchString(str) { - return gas.NewIssue(c, n, s.What, s.Severity, s.Confidence), nil + if node, ok := n.(*ast.BinaryExpr); ok { + if start, ok := node.X.(*ast.BasicLit); ok { + if str, _ := gas.GetString(start); s.pattern.MatchString(str) { + if _, ok := node.Y.(*ast.BasicLit); ok { + return nil, nil // string cat OK + } + if second, ok := node.Y.(*ast.Ident); ok && s.checkObject(second) { + return nil, nil + } + return gas.NewIssue(c, n, s.What, s.Severity, s.Confidence), nil + } } } return nil, nil diff --git a/rules/sql_test.go b/rules/sql_test.go index 76cf920..3dc6171 100644 --- a/rules/sql_test.go +++ b/rules/sql_test.go @@ -15,8 +15,9 @@ package rules import ( - gas "github.com/HewlettPackard/gas/core" "testing" + + gas "github.com/HewlettPackard/gas/core" ) func TestSQLInjectionViaConcatenation(t *testing.T) { @@ -144,3 +145,74 @@ func TestSQLInjectionFalsePositiveB(t *testing.T) { checkTestResults(t, issues, 0, "Not expected to match") } + +func TestSQLInjectionFalsePositiveC(t *testing.T) { + analyzer := gas.NewAnalyzer(false, nil) + analyzer.AddRule(NewSqlStrConcat()) + analyzer.AddRule(NewSqlStrFormat()) + + source := ` + + package main + import ( + "database/sql" + "fmt" + "os" + _ "github.com/mattn/go-sqlite3" + ) + + var staticQuery = "SELECT * FROM foo WHERE age < " + + func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + rows, err := db.Query(staticQuery + "32") + if err != nil { + panic(err) + } + defer rows.Close() + } + + ` + issues := gasTestRunner(source, analyzer) + + checkTestResults(t, issues, 0, "Not expected to match") +} + +func TestSQLInjectionFalsePositiveD(t *testing.T) { + analyzer := gas.NewAnalyzer(false, nil) + analyzer.AddRule(NewSqlStrConcat()) + analyzer.AddRule(NewSqlStrFormat()) + + source := ` + + package main + import ( + "database/sql" + "fmt" + "os" + _ "github.com/mattn/go-sqlite3" + ) + + const age = "32" + var staticQuery = "SELECT * FROM foo WHERE age < " + + func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + rows, err := db.Query(staticQuery + age) + if err != nil { + panic(err) + } + defer rows.Close() + } + + ` + issues := gasTestRunner(source, analyzer) + + checkTestResults(t, issues, 0, "Not expected to match") +}