mirror of
https://github.com/securego/gosec.git
synced 2024-12-24 11:35:52 +00:00
Improve the SQL concatenation and string formatting rules to be applied only in the database/sql context
In addition makes pattern matching used by the rules cases insensitive. Signed-off-by: Cosmin Cojocar <cosmin.cojocar@gmx.ch>
This commit is contained in:
parent
32be4a5cc6
commit
68bce94323
4 changed files with 257 additions and 37 deletions
|
@ -1,6 +1,7 @@
|
|||
package gosec_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -91,21 +92,23 @@ var _ = Describe("Issue", func() {
|
|||
})
|
||||
|
||||
It("should provide accurate line and file information for multi-line statements", func() {
|
||||
var target *ast.BinaryExpr
|
||||
|
||||
source := `package main
|
||||
import "os"
|
||||
func main(){`
|
||||
source += "q := `SELECT * FROM table WHERE` + \n os.Args[1] + `= ?` // nolint: gosec\n"
|
||||
source += `println(q)}`
|
||||
|
||||
var target *ast.CallExpr
|
||||
source := `
|
||||
package main
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
func main() {
|
||||
_, _ := net.Listen("tcp", "0.0.0.0:2000")
|
||||
}
|
||||
`
|
||||
pkg := testutils.NewTestPackage()
|
||||
defer pkg.Close()
|
||||
pkg.AddFile("foo.go", source)
|
||||
ctx := pkg.CreateContext("foo.go")
|
||||
v := testutils.NewMockVisitor()
|
||||
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
|
||||
if node, ok := n.(*ast.BinaryExpr); ok {
|
||||
if node, ok := n.(*ast.CallExpr); ok {
|
||||
target = node
|
||||
}
|
||||
return true
|
||||
|
@ -114,15 +117,16 @@ var _ = Describe("Issue", func() {
|
|||
ast.Walk(v, ctx.Root)
|
||||
Expect(target).ShouldNot(BeNil())
|
||||
|
||||
// Use SQL rule to check binary expr
|
||||
fmt.Printf("target: %v\n", target)
|
||||
|
||||
cfg := gosec.NewConfig()
|
||||
rule, _ := rules.NewSQLStrConcat("TEST", cfg)
|
||||
rule, _ := rules.NewBindsToAllNetworkInterfaces("TEST", cfg)
|
||||
issue, err := rule.Match(target, ctx)
|
||||
Expect(err).ShouldNot(HaveOccurred())
|
||||
Expect(issue).ShouldNot(BeNil())
|
||||
Expect(issue.File).Should(MatchRegexp("foo.go"))
|
||||
Expect(issue.Line).Should(MatchRegexp("3-4"))
|
||||
Expect(issue.Col).Should(Equal("21"))
|
||||
Expect(issue.Line).Should(MatchRegexp("7"))
|
||||
Expect(issue.Col).Should(Equal("10"))
|
||||
})
|
||||
|
||||
It("should maintain the provided severity score", func() {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package rules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
|
||||
|
@ -99,7 +98,7 @@ func NewImplicitAliasing(id string, conf gosec.Config) (gosec.Rule, []ast.Node)
|
|||
ID: id,
|
||||
Severity: gosec.Medium,
|
||||
Confidence: gosec.Medium,
|
||||
What: fmt.Sprintf("Implicit memory aliasing in for loop."),
|
||||
What: "Implicit memory aliasing in for loop.",
|
||||
},
|
||||
}, []ast.Node{(*ast.RangeStmt)(nil), (*ast.UnaryExpr)(nil), (*ast.ReturnStmt)(nil)}
|
||||
}
|
||||
|
|
130
rules/sql.go
130
rules/sql.go
|
@ -16,13 +16,16 @@ package rules
|
|||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/securego/gosec/v2"
|
||||
)
|
||||
|
||||
type sqlStatement struct {
|
||||
gosec.MetaData
|
||||
gosec.CallList
|
||||
|
||||
// Contains a list of patterns which must all match for the rule to match.
|
||||
patterns []*regexp.Regexp
|
||||
|
@ -65,33 +68,66 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Look for "SELECT * FROM table WHERE " + " ' OR 1=1"
|
||||
func (s *sqlStrConcat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if node, ok := n.(*ast.BinaryExpr); ok {
|
||||
if start, ok := node.X.(*ast.BasicLit); ok {
|
||||
// checkQuery verifies if the query parameters is a string concatenation
|
||||
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
|
||||
_, fnName, err := gosec.GetCallInfo(call, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var query ast.Node
|
||||
if strings.HasSuffix(fnName, "Context") {
|
||||
query = call.Args[1]
|
||||
} else {
|
||||
query = call.Args[0]
|
||||
}
|
||||
|
||||
if be, ok := query.(*ast.BinaryExpr); ok {
|
||||
// Skip all operations which aren't concatenation
|
||||
if be.Op != token.ADD {
|
||||
return nil, nil
|
||||
}
|
||||
if start, ok := be.X.(*ast.BasicLit); ok {
|
||||
if str, e := gosec.GetString(start); e == nil {
|
||||
if !s.MatchPatterns(str) {
|
||||
return nil, nil
|
||||
}
|
||||
if _, ok := node.Y.(*ast.BasicLit); ok {
|
||||
if _, ok := be.Y.(*ast.BasicLit); ok {
|
||||
return nil, nil // string cat OK
|
||||
}
|
||||
if second, ok := node.Y.(*ast.Ident); ok && s.checkObject(second, c) {
|
||||
if second, ok := be.Y.(*ast.Ident); ok && s.checkObject(second, ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
return gosec.NewIssue(ctx, be, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Checks SQL query concatenation issues such as "SELECT * FROM table WHERE " + " ' OR 1=1"
|
||||
func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
|
||||
switch stmt := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range stmt.Rhs {
|
||||
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
}
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// NewSQLStrConcat looks for cases where we are building SQL strings via concatenation
|
||||
func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
return &sqlStrConcat{
|
||||
rule := &sqlStrConcat{
|
||||
sqlStatement: sqlStatement{
|
||||
patterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
|
||||
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
|
||||
},
|
||||
MetaData: gosec.MetaData{
|
||||
ID: id,
|
||||
|
@ -99,13 +135,19 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
|||
Confidence: gosec.High,
|
||||
What: "SQL string concatenation",
|
||||
},
|
||||
CallList: gosec.NewCallList(),
|
||||
},
|
||||
}, []ast.Node{(*ast.BinaryExpr)(nil)}
|
||||
}
|
||||
|
||||
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
|
||||
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
|
||||
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
|
||||
}
|
||||
|
||||
type sqlStrFormat struct {
|
||||
gosec.CallList
|
||||
sqlStatement
|
||||
calls gosec.CallList
|
||||
fmtCalls gosec.CallList
|
||||
noIssue gosec.CallList
|
||||
noIssueQuoted gosec.CallList
|
||||
}
|
||||
|
@ -130,14 +172,37 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
||||
func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
|
||||
_, fnName, err := gosec.GetCallInfo(call, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var query ast.Node
|
||||
if strings.HasSuffix(fnName, "Context") {
|
||||
query = call.Args[1]
|
||||
} else {
|
||||
query = call.Args[0]
|
||||
}
|
||||
|
||||
if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
|
||||
decl := ident.Obj.Decl
|
||||
if assign, ok := decl.(*ast.AssignStmt); ok {
|
||||
for _, expr := range assign.Rhs {
|
||||
issue, err := s.checkFormatting(expr, ctx)
|
||||
if issue != nil {
|
||||
return issue, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
|
||||
// argIndex changes the function argument which gets matched to the regex
|
||||
argIndex := 0
|
||||
|
||||
// TODO(gm) improve confidence if database/sql is being used
|
||||
if node := s.calls.ContainsPkgCallExpr(n, c, false); node != nil {
|
||||
if node := s.fmtCalls.ContainsPkgCallExpr(n, ctx, false); node != nil {
|
||||
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
|
||||
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
|
||||
if sel.Sel.Name == "Fprintf" {
|
||||
|
@ -177,7 +242,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
if argIndex+1 < len(node.Args) {
|
||||
allSafe := true
|
||||
for _, arg := range node.Args[argIndex+1:] {
|
||||
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, c, true); n == nil && !s.constObject(arg, c) {
|
||||
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true); n == nil && !s.constObject(arg, ctx) {
|
||||
allSafe = false
|
||||
break
|
||||
}
|
||||
|
@ -187,7 +252,24 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
}
|
||||
}
|
||||
if s.MatchPatterns(formatter) {
|
||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
return gosec.NewIssue(ctx, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check SQL query formatting issues such as "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
||||
func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
|
||||
switch stmt := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range stmt.Rhs {
|
||||
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
}
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
|
@ -196,12 +278,13 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
|
||||
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
rule := &sqlStrFormat{
|
||||
calls: gosec.NewCallList(),
|
||||
CallList: gosec.NewCallList(),
|
||||
fmtCalls: gosec.NewCallList(),
|
||||
noIssue: gosec.NewCallList(),
|
||||
noIssueQuoted: gosec.NewCallList(),
|
||||
sqlStatement: sqlStatement{
|
||||
patterns: []*regexp.Regexp{
|
||||
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
|
||||
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
|
||||
regexp.MustCompile("%[^bdoxXfFp]"),
|
||||
},
|
||||
MetaData: gosec.MetaData{
|
||||
|
@ -212,8 +295,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
|||
},
|
||||
},
|
||||
}
|
||||
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
||||
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
|
||||
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
|
||||
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
||||
rule.noIssue.AddAll("os", "Stdout", "Stderr")
|
||||
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
|
||||
return rule, []ast.Node{(*ast.CallExpr)(nil)}
|
||||
|
||||
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
|
||||
}
|
||||
|
|
|
@ -789,6 +789,76 @@ func main(){
|
|||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// Format string without proper quoting case insensitive
|
||||
package main
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// Format string without proper quoting with context
|
||||
package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
|
||||
rows, err := db.QueryContext(context.Background(), q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// Format string without proper quoting with transation
|
||||
package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
q := fmt.Sprintf("select * from foo where name = '%s'", os.Args[1])
|
||||
rows, err := tx.QueryContext(context.Background(), q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := tx.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// Format string false positive, safe string spec.
|
||||
package main
|
||||
import (
|
||||
|
@ -895,6 +965,67 @@ func main(){
|
|||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// case insensitive match
|
||||
package main
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rows, err := db.Query("select * from foo where name = " + os.Args[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// context match
|
||||
package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rows, err := db.QueryContext(context.Background(), "select * from foo where name = " + os.Args[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// DB transation check
|
||||
package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
func main(){
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
rows, err := tx.QueryContext(context.Background(), "select * from foo where name = " + os.Args[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
if err := tx.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}`}, 1, gosec.NewConfig()}, {[]string{`
|
||||
// false positive
|
||||
package main
|
||||
import (
|
||||
|
|
Loading…
Reference in a new issue