mirror of
https://github.com/securego/gosec.git
synced 2024-11-05 19:45:51 +00:00
Allow quoted strings to be used to format SQL queries (#240)
* Support stripping vendor paths when matching calls * Factor out matching of formatter string * Quoted strings are safe to use with SQL str formatted strings * Add test for allowing quoted strings with string formatters * Install the pq package for tests to pass
This commit is contained in:
parent
ec32ce68d8
commit
762ff3a709
14 changed files with 88 additions and 33 deletions
|
@ -12,6 +12,7 @@ install:
|
|||
- go get -u github.com/onsi/ginkgo/ginkgo
|
||||
- go get -u github.com/onsi/gomega
|
||||
- go get -u golang.org/x/crypto/ssh
|
||||
- go get -u github.com/lib/pq
|
||||
- go get -u github.com/securego/gosec/cmd/gosec/...
|
||||
- go get -v -t ./...
|
||||
- export PATH=$PATH:$HOME/gopath/bin
|
||||
|
|
22
call_list.go
22
call_list.go
|
@ -15,8 +15,11 @@ package gosec
|
|||
|
||||
import (
|
||||
"go/ast"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const vendorPath = "vendor/"
|
||||
|
||||
type set map[string]bool
|
||||
|
||||
// CallList is used to check for usage of specific packages
|
||||
|
@ -55,17 +58,27 @@ func (c CallList) Contains(selector, ident string) bool {
|
|||
|
||||
// ContainsCallExpr resolves the call expression name and type
|
||||
/// or package and determines if it exists within the CallList
|
||||
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
|
||||
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
|
||||
selector, ident, err := GetCallInfo(n, ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use only explicit path to reduce conflicts
|
||||
if path, ok := GetImportPath(selector, ctx); ok && c.Contains(path, ident) {
|
||||
return n.(*ast.CallExpr)
|
||||
// Use only explicit path (optionally strip vendor path prefix) to reduce conflicts
|
||||
path, ok := GetImportPath(selector, ctx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if stripVendor {
|
||||
if vendorIdx := strings.Index(path, vendorPath); vendorIdx >= 0 {
|
||||
path = path[vendorIdx+len(vendorPath):]
|
||||
}
|
||||
}
|
||||
if !c.Contains(path, ident) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return n.(*ast.CallExpr)
|
||||
/*
|
||||
// Try direct resolution
|
||||
if c.Contains(selector, ident) {
|
||||
|
@ -74,5 +87,4 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
|
|||
}
|
||||
*/
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ var _ = Describe("call list", func() {
|
|||
v := testutils.NewMockVisitor()
|
||||
v.Context = ctx
|
||||
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
|
||||
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil {
|
||||
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil {
|
||||
matched++
|
||||
}
|
||||
return true
|
||||
|
|
|
@ -19,7 +19,7 @@ func (a *archive) ID() string {
|
|||
|
||||
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
|
||||
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if node := a.calls.ContainsCallExpr(n, c); node != nil {
|
||||
if node := a.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||
for _, arg := range node.Args {
|
||||
var argType types.Type
|
||||
if selector, ok := arg.(*ast.SelectorExpr); ok {
|
||||
|
|
|
@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string {
|
|||
}
|
||||
|
||||
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
callExpr := r.calls.ContainsCallExpr(n, c)
|
||||
callExpr := r.calls.ContainsCallExpr(n, c, false)
|
||||
if callExpr == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
|
|||
switch stmt := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range stmt.Rhs {
|
||||
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil {
|
||||
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil {
|
||||
pos := returnsError(callExpr, ctx)
|
||||
if pos < 0 || pos >= len(stmt.Lhs) {
|
||||
return nil, nil
|
||||
|
@ -64,7 +64,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
|
|||
}
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil {
|
||||
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil {
|
||||
pos := returnsError(callExpr, ctx)
|
||||
if pos >= 0 {
|
||||
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
||||
|
|
|
@ -34,7 +34,7 @@ func (r *readfile) ID() string {
|
|||
|
||||
// isJoinFunc checks if there is a filepath.Join or other join function
|
||||
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
|
||||
if call := r.pathJoin.ContainsCallExpr(n, c); call != nil {
|
||||
if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil {
|
||||
for _, arg := range call.Args {
|
||||
// edge case: check if one of the args is a BinaryExpr
|
||||
if binExp, ok := arg.(*ast.BinaryExpr); ok {
|
||||
|
@ -44,21 +44,21 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// try and resolve identity
|
||||
if ident, ok := arg.(*ast.Ident); ok {
|
||||
obj := c.Info.ObjectOf(ident)
|
||||
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
|
||||
return true
|
||||
// try and resolve identity
|
||||
if ident, ok := arg.(*ast.Ident); ok {
|
||||
obj := c.Info.ObjectOf(ident)
|
||||
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
|
||||
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
||||
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||
for _, arg := range node.Args {
|
||||
// handles path joining functions in Arg
|
||||
// eg. os.Open(filepath.Join("/tmp/", file))
|
||||
|
|
|
@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string {
|
|||
}
|
||||
|
||||
func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if callExpr := w.calls.ContainsCallExpr(n, c); callExpr != nil {
|
||||
if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil {
|
||||
if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) {
|
||||
return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil
|
||||
}
|
||||
|
|
41
rules/sql.go
41
rules/sql.go
|
@ -98,8 +98,9 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
|||
|
||||
type sqlStrFormat struct {
|
||||
sqlStatement
|
||||
calls gosec.CallList
|
||||
noIssue gosec.CallList
|
||||
calls gosec.CallList
|
||||
noIssue gosec.CallList
|
||||
noIssueQuoted gosec.CallList
|
||||
}
|
||||
|
||||
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
||||
|
@ -109,7 +110,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
argIndex := 0
|
||||
|
||||
// TODO(gm) improve confidence if database/sql is being used
|
||||
if node := s.calls.ContainsCallExpr(n, c); node != nil {
|
||||
if node := s.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
|
||||
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
|
||||
if sel.Sel.Name == "Fprintf" {
|
||||
|
@ -125,17 +126,35 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
argIndex = 1
|
||||
}
|
||||
}
|
||||
|
||||
var formatter string
|
||||
|
||||
// concats callexpr arg strings together if needed before regex evaluation
|
||||
if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
|
||||
if fullStr, ok := gosec.ConcatString(argExpr); ok {
|
||||
if s.MatchPatterns(fullStr) {
|
||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence),
|
||||
nil
|
||||
}
|
||||
formatter = fullStr
|
||||
}
|
||||
} else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil {
|
||||
formatter = arg
|
||||
}
|
||||
if len(formatter) <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if arg, e := gosec.GetString(node.Args[argIndex]); s.MatchPatterns(arg) && e == nil {
|
||||
// If all formatter args are quoted, then the SQL construction is safe
|
||||
if argIndex+1 < len(node.Args) {
|
||||
allQuoted := true
|
||||
for _, arg := range node.Args[argIndex+1:] {
|
||||
if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil {
|
||||
allQuoted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allQuoted {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if s.MatchPatterns(formatter) {
|
||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
|
@ -145,8 +164,9 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
|||
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
|
||||
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
rule := &sqlStrFormat{
|
||||
calls: gosec.NewCallList(),
|
||||
noIssue: gosec.NewCallList(),
|
||||
calls: gosec.NewCallList(),
|
||||
noIssue: gosec.NewCallList(),
|
||||
noIssueQuoted: gosec.NewCallList(),
|
||||
sqlStatement: sqlStatement{
|
||||
patterns: []*regexp.Regexp{
|
||||
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
|
||||
|
@ -162,5 +182,6 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
|||
}
|
||||
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
||||
rule.noIssue.AddAll("os", "Stdout", "Stderr")
|
||||
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
|
||||
return rule, []ast.Node{(*ast.CallExpr)(nil)}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool {
|
|||
// Match inspects AST nodes to determine if certain net/http methods are called with variable input
|
||||
func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
// Call expression is using http package directly
|
||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
||||
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||
if r.ResolveVar(node, c) {
|
||||
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func (r *subprocess) ID() string {
|
|||
//
|
||||
// syscall.Exec("echo", "foobar" + tainted)
|
||||
func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
||||
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||
for _, arg := range node.Args {
|
||||
if ident, ok := arg.(*ast.Ident); ok {
|
||||
obj := c.Info.ObjectOf(ident)
|
||||
|
|
|
@ -32,7 +32,7 @@ func (t *badTempFile) ID() string {
|
|||
}
|
||||
|
||||
func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) {
|
||||
if node := t.calls.ContainsCallExpr(n, c); node != nil {
|
||||
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||
if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil {
|
||||
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func (t *templateCheck) ID() string {
|
|||
}
|
||||
|
||||
func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||
if node := t.calls.ContainsCallExpr(n, c); node != nil {
|
||||
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||
for _, arg := range node.Args {
|
||||
if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe
|
||||
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
||||
|
|
|
@ -292,6 +292,27 @@ func main(){
|
|||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`, 0}, {`
|
||||
// Format string false positive, quoted formatter argument.
|
||||
package main
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func main(){
|
||||
db, err := sql.Open("postgres", "localhost")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("SELECT * FROM %s where id = 1", pq.QuoteIdentifier(os.Args[1]))
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}`, 0}}
|
||||
|
||||
// SampleCodeG202 - SQL query string building via string concatenation
|
||||
|
|
Loading…
Reference in a new issue