Improve the SQL concatenation and string formatting rules to be applied only in the database/sql context

In addition makes pattern matching used by the rules cases insensitive.
Signed-off-by: Cosmin Cojocar <cosmin.cojocar@gmx.ch>
This commit is contained in:
Cosmin Cojocar 2020-05-25 14:19:00 +02:00 committed by Cosmin Cojocar
parent 32be4a5cc6
commit 68bce94323
4 changed files with 257 additions and 37 deletions

View file

@ -1,6 +1,7 @@
package gosec_test
import (
"fmt"
"go/ast"
. "github.com/onsi/ginkgo"
@ -91,21 +92,23 @@ var _ = Describe("Issue", func() {
})
It("should provide accurate line and file information for multi-line statements", func() {
var target *ast.BinaryExpr
source := `package main
import "os"
func main(){`
source += "q := `SELECT * FROM table WHERE` + \n os.Args[1] + `= ?` // nolint: gosec\n"
source += `println(q)}`
var target *ast.CallExpr
source := `
package main
import (
"net"
)
func main() {
_, _ := net.Listen("tcp", "0.0.0.0:2000")
}
`
pkg := testutils.NewTestPackage()
defer pkg.Close()
pkg.AddFile("foo.go", source)
ctx := pkg.CreateContext("foo.go")
v := testutils.NewMockVisitor()
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
if node, ok := n.(*ast.BinaryExpr); ok {
if node, ok := n.(*ast.CallExpr); ok {
target = node
}
return true
@ -114,15 +117,16 @@ var _ = Describe("Issue", func() {
ast.Walk(v, ctx.Root)
Expect(target).ShouldNot(BeNil())
// Use SQL rule to check binary expr
fmt.Printf("target: %v\n", target)
cfg := gosec.NewConfig()
rule, _ := rules.NewSQLStrConcat("TEST", cfg)
rule, _ := rules.NewBindsToAllNetworkInterfaces("TEST", cfg)
issue, err := rule.Match(target, ctx)
Expect(err).ShouldNot(HaveOccurred())
Expect(issue).ShouldNot(BeNil())
Expect(issue.File).Should(MatchRegexp("foo.go"))
Expect(issue.Line).Should(MatchRegexp("3-4"))
Expect(issue.Col).Should(Equal("21"))
Expect(issue.Line).Should(MatchRegexp("7"))
Expect(issue.Col).Should(Equal("10"))
})
It("should maintain the provided severity score", func() {

View file

@ -1,7 +1,6 @@
package rules
import (
"fmt"
"go/ast"
"go/token"
@ -99,7 +98,7 @@ func NewImplicitAliasing(id string, conf gosec.Config) (gosec.Rule, []ast.Node)
ID: id,
Severity: gosec.Medium,
Confidence: gosec.Medium,
What: fmt.Sprintf("Implicit memory aliasing in for loop."),
What: "Implicit memory aliasing in for loop.",
},
}, []ast.Node{(*ast.RangeStmt)(nil), (*ast.UnaryExpr)(nil), (*ast.ReturnStmt)(nil)}
}

View file

@ -16,13 +16,16 @@ package rules
import (
"go/ast"
"go/token"
"regexp"
"strings"
"github.com/securego/gosec/v2"
)
type sqlStatement struct {
gosec.MetaData
gosec.CallList
// Contains a list of patterns which must all match for the rule to match.
patterns []*regexp.Regexp
@ -65,33 +68,66 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
return false
}
// Look for "SELECT * FROM table WHERE " + " ' OR 1=1"
func (s *sqlStrConcat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node, ok := n.(*ast.BinaryExpr); ok {
if start, ok := node.X.(*ast.BasicLit); ok {
// checkQuery verifies if the query parameters is a string concatenation
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}
if be, ok := query.(*ast.BinaryExpr); ok {
// Skip all operations which aren't concatenation
if be.Op != token.ADD {
return nil, nil
}
if start, ok := be.X.(*ast.BasicLit); ok {
if str, e := gosec.GetString(start); e == nil {
if !s.MatchPatterns(str) {
return nil, nil
}
if _, ok := node.Y.(*ast.BasicLit); ok {
if _, ok := be.Y.(*ast.BasicLit); ok {
return nil, nil // string cat OK
}
if second, ok := node.Y.(*ast.Ident); ok && s.checkObject(second, c) {
if second, ok := be.Y.(*ast.Ident); ok && s.checkObject(second, ctx) {
return nil, nil
}
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
return gosec.NewIssue(ctx, be, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
}
return nil, nil
}
// Checks SQL query concatenation issues such as "SELECT * FROM table WHERE " + " ' OR 1=1"
func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
switch stmt := n.(type) {
case *ast.AssignStmt:
for _, expr := range stmt.Rhs {
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
return s.checkQuery(sqlQueryCall, ctx)
}
}
case *ast.ExprStmt:
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
return s.checkQuery(sqlQueryCall, ctx)
}
}
return nil, nil
}
// NewSQLStrConcat looks for cases where we are building SQL strings via concatenation
func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
return &sqlStrConcat{
rule := &sqlStrConcat{
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile(`(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
},
MetaData: gosec.MetaData{
ID: id,
@ -99,13 +135,19 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
Confidence: gosec.High,
What: "SQL string concatenation",
},
CallList: gosec.NewCallList(),
},
}, []ast.Node{(*ast.BinaryExpr)(nil)}
}
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
}
type sqlStrFormat struct {
gosec.CallList
sqlStatement
calls gosec.CallList
fmtCalls gosec.CallList
noIssue gosec.CallList
noIssueQuoted gosec.CallList
}
@ -130,14 +172,37 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
return false
}
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}
if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
decl := ident.Obj.Decl
if assign, ok := decl.(*ast.AssignStmt); ok {
for _, expr := range assign.Rhs {
issue, err := s.checkFormatting(expr, ctx)
if issue != nil {
return issue, err
}
}
}
}
return nil, nil
}
func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
// argIndex changes the function argument which gets matched to the regex
argIndex := 0
// TODO(gm) improve confidence if database/sql is being used
if node := s.calls.ContainsPkgCallExpr(n, c, false); node != nil {
if node := s.fmtCalls.ContainsPkgCallExpr(n, ctx, false); node != nil {
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
if sel.Sel.Name == "Fprintf" {
@ -177,7 +242,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
if argIndex+1 < len(node.Args) {
allSafe := true
for _, arg := range node.Args[argIndex+1:] {
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) {
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true); n == nil && !s.constObject(arg, ctx) {
allSafe = false
break
}
@ -187,7 +252,24 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
}
}
if s.MatchPatterns(formatter) {
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
return gosec.NewIssue(ctx, n, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
return nil, nil
}
// Check SQL query formatting issues such as "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
switch stmt := n.(type) {
case *ast.AssignStmt:
for _, expr := range stmt.Rhs {
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
return s.checkQuery(sqlQueryCall, ctx)
}
}
case *ast.ExprStmt:
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
return s.checkQuery(sqlQueryCall, ctx)
}
}
return nil, nil
@ -196,12 +278,13 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
rule := &sqlStrFormat{
calls: gosec.NewCallList(),
CallList: gosec.NewCallList(),
fmtCalls: gosec.NewCallList(),
noIssue: gosec.NewCallList(),
noIssueQuoted: gosec.NewCallList(),
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
regexp.MustCompile("%[^bdoxXfFp]"),
},
MetaData: gosec.MetaData{
@ -212,8 +295,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
},
},
}
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
return rule, []ast.Node{(*ast.CallExpr)(nil)}
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
}

View file

@ -789,6 +789,76 @@ func main(){
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// Format string without proper quoting case insensitive
package main
import (
"database/sql"
"fmt"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// Format string without proper quoting with context
package main
import (
"context"
"database/sql"
"fmt"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
rows, err := db.QueryContext(context.Background(), q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// Format string without proper quoting with transation
package main
import (
"context"
"database/sql"
"fmt"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
defer tx.Rollback()
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
rows, err := tx.QueryContext(context.Background(), q)
if err != nil {
panic(err)
}
defer rows.Close()
if err := tx.Commit(); err != nil {
panic(err)
}
}`}, 1, gosec.NewConfig()}, {[]string{`
// Format string false positive, safe string spec.
package main
import (
@ -895,6 +965,67 @@ func main(){
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// case insensitive match
package main
import (
"database/sql"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
rows, err := db.Query("select * from foo where name = " + os.Args[1])
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// context match
package main
import (
"context"
"database/sql"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
rows, err := db.QueryContext(context.Background(), "select * from foo where name = " + os.Args[1])
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
// DB transation check
package main
import (
"context"
"database/sql"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
defer tx.Rollback()
rows, err := tx.QueryContext(context.Background(), "select * from foo where name = " + os.Args[1])
if err != nil {
panic(err)
}
defer rows.Close()
if err := tx.Commit(); err != nil {
panic(err)
}
}`}, 1, gosec.NewConfig()}, {[]string{`
// false positive
package main
import (