Allow quoted strings to be used to format SQL queries (#240)

* Support stripping vendor paths when matching calls

* Factor out matching of formatter string

* Quoted strings are safe to use with SQL str formatted strings

* Add test for allowing quoted strings with string formatters

* Install the pq package for tests to pass
This commit is contained in:
Dale Hui 2018-09-25 00:40:05 -07:00 committed by Cosmin Cojocar
parent ec32ce68d8
commit 762ff3a709
14 changed files with 88 additions and 33 deletions

View file

@ -12,6 +12,7 @@ install:
- go get -u github.com/onsi/ginkgo/ginkgo
- go get -u github.com/onsi/gomega
- go get -u golang.org/x/crypto/ssh
- go get -u github.com/lib/pq
- go get -u github.com/securego/gosec/cmd/gosec/...
- go get -v -t ./...
- export PATH=$PATH:$HOME/gopath/bin

View file

@ -15,8 +15,11 @@ package gosec
import (
"go/ast"
"strings"
)
const vendorPath = "vendor/"
type set map[string]bool
// CallList is used to check for usage of specific packages
@ -55,17 +58,27 @@ func (c CallList) Contains(selector, ident string) bool {
// ContainsCallExpr resolves the call expression name and type
/// or package and determines if it exists within the CallList
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return nil
}
// Use only explicit path to reduce conflicts
if path, ok := GetImportPath(selector, ctx); ok && c.Contains(path, ident) {
return n.(*ast.CallExpr)
// Use only explicit path (optionally strip vendor path prefix) to reduce conflicts
path, ok := GetImportPath(selector, ctx)
if !ok {
return nil
}
if stripVendor {
if vendorIdx := strings.Index(path, vendorPath); vendorIdx >= 0 {
path = path[vendorIdx+len(vendorPath):]
}
}
if !c.Contains(path, ident) {
return nil
}
return n.(*ast.CallExpr)
/*
// Try direct resolution
if c.Contains(selector, ident) {
@ -74,5 +87,4 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
}
*/
return nil
}

View file

@ -73,7 +73,7 @@ var _ = Describe("call list", func() {
v := testutils.NewMockVisitor()
v.Context = ctx
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil {
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil {
matched++
}
return true

View file

@ -19,7 +19,7 @@ func (a *archive) ID() string {
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := a.calls.ContainsCallExpr(n, c); node != nil {
if node := a.calls.ContainsCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
var argType types.Type
if selector, ok := arg.(*ast.SelectorExpr); ok {

View file

@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string {
}
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
callExpr := r.calls.ContainsCallExpr(n, c)
callExpr := r.calls.ContainsCallExpr(n, c, false)
if callExpr == nil {
return nil, nil
}

View file

@ -53,7 +53,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
switch stmt := n.(type) {
case *ast.AssignStmt:
for _, expr := range stmt.Rhs {
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil {
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil {
pos := returnsError(callExpr, ctx)
if pos < 0 || pos >= len(stmt.Lhs) {
return nil, nil
@ -64,7 +64,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
}
}
case *ast.ExprStmt:
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil {
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil {
pos := returnsError(callExpr, ctx)
if pos >= 0 {
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil

View file

@ -34,7 +34,7 @@ func (r *readfile) ID() string {
// isJoinFunc checks if there is a filepath.Join or other join function
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
if call := r.pathJoin.ContainsCallExpr(n, c); call != nil {
if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil {
for _, arg := range call.Args {
// edge case: check if one of the args is a BinaryExpr
if binExp, ok := arg.(*ast.BinaryExpr); ok {
@ -44,21 +44,21 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
}
}
// try and resolve identity
if ident, ok := arg.(*ast.Ident); ok {
obj := c.Info.ObjectOf(ident)
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
return true
// try and resolve identity
if ident, ok := arg.(*ast.Ident); ok {
obj := c.Info.ObjectOf(ident)
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
return true
}
}
}
}
}
return false
}
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := r.ContainsCallExpr(n, c); node != nil {
if node := r.ContainsCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
// handles path joining functions in Arg
// eg. os.Open(filepath.Join("/tmp/", file))

View file

@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string {
}
func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if callExpr := w.calls.ContainsCallExpr(n, c); callExpr != nil {
if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil {
if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) {
return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil
}

View file

@ -98,8 +98,9 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
type sqlStrFormat struct {
sqlStatement
calls gosec.CallList
noIssue gosec.CallList
calls gosec.CallList
noIssue gosec.CallList
noIssueQuoted gosec.CallList
}
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
@ -109,7 +110,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
argIndex := 0
// TODO(gm) improve confidence if database/sql is being used
if node := s.calls.ContainsCallExpr(n, c); node != nil {
if node := s.calls.ContainsCallExpr(n, c, false); node != nil {
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
if sel.Sel.Name == "Fprintf" {
@ -125,17 +126,35 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
argIndex = 1
}
}
var formatter string
// concats callexpr arg strings together if needed before regex evaluation
if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
if fullStr, ok := gosec.ConcatString(argExpr); ok {
if s.MatchPatterns(fullStr) {
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence),
nil
}
formatter = fullStr
}
} else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil {
formatter = arg
}
if len(formatter) <= 0 {
return nil, nil
}
if arg, e := gosec.GetString(node.Args[argIndex]); s.MatchPatterns(arg) && e == nil {
// If all formatter args are quoted, then the SQL construction is safe
if argIndex+1 < len(node.Args) {
allQuoted := true
for _, arg := range node.Args[argIndex+1:] {
if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil {
allQuoted = false
break
}
}
if allQuoted {
return nil, nil
}
}
if s.MatchPatterns(formatter) {
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
@ -145,8 +164,9 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
rule := &sqlStrFormat{
calls: gosec.NewCallList(),
noIssue: gosec.NewCallList(),
calls: gosec.NewCallList(),
noIssue: gosec.NewCallList(),
noIssueQuoted: gosec.NewCallList(),
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
@ -162,5 +182,6 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
}
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
return rule, []ast.Node{(*ast.CallExpr)(nil)}
}

View file

@ -35,7 +35,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool {
// Match inspects AST nodes to determine if certain net/http methods are called with variable input
func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
// Call expression is using http package directly
if node := r.ContainsCallExpr(n, c); node != nil {
if node := r.ContainsCallExpr(n, c, false); node != nil {
if r.ResolveVar(node, c) {
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
}

View file

@ -40,7 +40,7 @@ func (r *subprocess) ID() string {
//
// syscall.Exec("echo", "foobar" + tainted)
func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := r.ContainsCallExpr(n, c); node != nil {
if node := r.ContainsCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
if ident, ok := arg.(*ast.Ident); ok {
obj := c.Info.ObjectOf(ident)

View file

@ -32,7 +32,7 @@ func (t *badTempFile) ID() string {
}
func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) {
if node := t.calls.ContainsCallExpr(n, c); node != nil {
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil {
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
}

View file

@ -30,7 +30,7 @@ func (t *templateCheck) ID() string {
}
func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
if node := t.calls.ContainsCallExpr(n, c); node != nil {
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
for _, arg := range node.Args {
if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil

View file

@ -292,6 +292,27 @@ func main(){
panic(err)
}
defer rows.Close()
}`, 0}, {`
// Format string false positive, quoted formatter argument.
package main
import (
"database/sql"
"fmt"
"os"
"github.com/lib/pq"
)
func main(){
db, err := sql.Open("postgres", "localhost")
if err != nil {
panic(err)
}
q := fmt.Sprintf("SELECT * FROM %s where id = 1", pq.QuoteIdentifier(os.Args[1]))
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`, 0}}
// SampleCodeG202 - SQL query string building via string concatenation