mirror of
https://github.com/securego/gosec.git
synced 2024-12-25 03:55:54 +00:00
Allow quoted strings to be used to format SQL queries (#240)
* Support stripping vendor paths when matching calls * Factor out matching of formatter string * Quoted strings are safe to use with SQL str formatted strings * Add test for allowing quoted strings with string formatters * Install the pq package for tests to pass
This commit is contained in:
parent
ec32ce68d8
commit
762ff3a709
14 changed files with 88 additions and 33 deletions
|
@ -12,6 +12,7 @@ install:
|
||||||
- go get -u github.com/onsi/ginkgo/ginkgo
|
- go get -u github.com/onsi/ginkgo/ginkgo
|
||||||
- go get -u github.com/onsi/gomega
|
- go get -u github.com/onsi/gomega
|
||||||
- go get -u golang.org/x/crypto/ssh
|
- go get -u golang.org/x/crypto/ssh
|
||||||
|
- go get -u github.com/lib/pq
|
||||||
- go get -u github.com/securego/gosec/cmd/gosec/...
|
- go get -u github.com/securego/gosec/cmd/gosec/...
|
||||||
- go get -v -t ./...
|
- go get -v -t ./...
|
||||||
- export PATH=$PATH:$HOME/gopath/bin
|
- export PATH=$PATH:$HOME/gopath/bin
|
||||||
|
|
22
call_list.go
22
call_list.go
|
@ -15,8 +15,11 @@ package gosec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"go/ast"
|
"go/ast"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const vendorPath = "vendor/"
|
||||||
|
|
||||||
type set map[string]bool
|
type set map[string]bool
|
||||||
|
|
||||||
// CallList is used to check for usage of specific packages
|
// CallList is used to check for usage of specific packages
|
||||||
|
@ -55,17 +58,27 @@ func (c CallList) Contains(selector, ident string) bool {
|
||||||
|
|
||||||
// ContainsCallExpr resolves the call expression name and type
|
// ContainsCallExpr resolves the call expression name and type
|
||||||
/// or package and determines if it exists within the CallList
|
/// or package and determines if it exists within the CallList
|
||||||
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
|
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
|
||||||
selector, ident, err := GetCallInfo(n, ctx)
|
selector, ident, err := GetCallInfo(n, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use only explicit path to reduce conflicts
|
// Use only explicit path (optionally strip vendor path prefix) to reduce conflicts
|
||||||
if path, ok := GetImportPath(selector, ctx); ok && c.Contains(path, ident) {
|
path, ok := GetImportPath(selector, ctx)
|
||||||
return n.(*ast.CallExpr)
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if stripVendor {
|
||||||
|
if vendorIdx := strings.Index(path, vendorPath); vendorIdx >= 0 {
|
||||||
|
path = path[vendorIdx+len(vendorPath):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !c.Contains(path, ident) {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return n.(*ast.CallExpr)
|
||||||
/*
|
/*
|
||||||
// Try direct resolution
|
// Try direct resolution
|
||||||
if c.Contains(selector, ident) {
|
if c.Contains(selector, ident) {
|
||||||
|
@ -74,5 +87,4 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ var _ = Describe("call list", func() {
|
||||||
v := testutils.NewMockVisitor()
|
v := testutils.NewMockVisitor()
|
||||||
v.Context = ctx
|
v.Context = ctx
|
||||||
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
|
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
|
||||||
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil {
|
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil {
|
||||||
matched++
|
matched++
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -19,7 +19,7 @@ func (a *archive) ID() string {
|
||||||
|
|
||||||
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
|
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
|
||||||
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
if node := a.calls.ContainsCallExpr(n, c); node != nil {
|
if node := a.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||||
for _, arg := range node.Args {
|
for _, arg := range node.Args {
|
||||||
var argType types.Type
|
var argType types.Type
|
||||||
if selector, ok := arg.(*ast.SelectorExpr); ok {
|
if selector, ok := arg.(*ast.SelectorExpr); ok {
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
callExpr := r.calls.ContainsCallExpr(n, c)
|
callExpr := r.calls.ContainsCallExpr(n, c, false)
|
||||||
if callExpr == nil {
|
if callExpr == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
|
||||||
switch stmt := n.(type) {
|
switch stmt := n.(type) {
|
||||||
case *ast.AssignStmt:
|
case *ast.AssignStmt:
|
||||||
for _, expr := range stmt.Rhs {
|
for _, expr := range stmt.Rhs {
|
||||||
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil {
|
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil {
|
||||||
pos := returnsError(callExpr, ctx)
|
pos := returnsError(callExpr, ctx)
|
||||||
if pos < 0 || pos >= len(stmt.Lhs) {
|
if pos < 0 || pos >= len(stmt.Lhs) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -64,7 +64,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *ast.ExprStmt:
|
case *ast.ExprStmt:
|
||||||
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil {
|
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil {
|
||||||
pos := returnsError(callExpr, ctx)
|
pos := returnsError(callExpr, ctx)
|
||||||
if pos >= 0 {
|
if pos >= 0 {
|
||||||
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
||||||
|
|
|
@ -34,7 +34,7 @@ func (r *readfile) ID() string {
|
||||||
|
|
||||||
// isJoinFunc checks if there is a filepath.Join or other join function
|
// isJoinFunc checks if there is a filepath.Join or other join function
|
||||||
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
|
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
|
||||||
if call := r.pathJoin.ContainsCallExpr(n, c); call != nil {
|
if call := r.pathJoin.ContainsCallExpr(n, c, false); call != nil {
|
||||||
for _, arg := range call.Args {
|
for _, arg := range call.Args {
|
||||||
// edge case: check if one of the args is a BinaryExpr
|
// edge case: check if one of the args is a BinaryExpr
|
||||||
if binExp, ok := arg.(*ast.BinaryExpr); ok {
|
if binExp, ok := arg.(*ast.BinaryExpr); ok {
|
||||||
|
@ -44,21 +44,21 @@ func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// try and resolve identity
|
// try and resolve identity
|
||||||
if ident, ok := arg.(*ast.Ident); ok {
|
if ident, ok := arg.(*ast.Ident); ok {
|
||||||
obj := c.Info.ObjectOf(ident)
|
obj := c.Info.ObjectOf(ident)
|
||||||
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
|
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
|
||||||
return true
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
|
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
|
||||||
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||||
for _, arg := range node.Args {
|
for _, arg := range node.Args {
|
||||||
// handles path joining functions in Arg
|
// handles path joining functions in Arg
|
||||||
// eg. os.Open(filepath.Join("/tmp/", file))
|
// eg. os.Open(filepath.Join("/tmp/", file))
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (w *weakKeyStrength) ID() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (w *weakKeyStrength) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
if callExpr := w.calls.ContainsCallExpr(n, c); callExpr != nil {
|
if callExpr := w.calls.ContainsCallExpr(n, c, false); callExpr != nil {
|
||||||
if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) {
|
if bits, err := gosec.GetInt(callExpr.Args[1]); err == nil && bits < (int64)(w.bits) {
|
||||||
return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil
|
return gosec.NewIssue(c, n, w.ID(), w.What, w.Severity, w.Confidence), nil
|
||||||
}
|
}
|
||||||
|
|
41
rules/sql.go
41
rules/sql.go
|
@ -98,8 +98,9 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||||
|
|
||||||
type sqlStrFormat struct {
|
type sqlStrFormat struct {
|
||||||
sqlStatement
|
sqlStatement
|
||||||
calls gosec.CallList
|
calls gosec.CallList
|
||||||
noIssue gosec.CallList
|
noIssue gosec.CallList
|
||||||
|
noIssueQuoted gosec.CallList
|
||||||
}
|
}
|
||||||
|
|
||||||
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
// Looks for "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
||||||
|
@ -109,7 +110,7 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
||||||
argIndex := 0
|
argIndex := 0
|
||||||
|
|
||||||
// TODO(gm) improve confidence if database/sql is being used
|
// TODO(gm) improve confidence if database/sql is being used
|
||||||
if node := s.calls.ContainsCallExpr(n, c); node != nil {
|
if node := s.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||||
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
|
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
|
||||||
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
|
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
|
||||||
if sel.Sel.Name == "Fprintf" {
|
if sel.Sel.Name == "Fprintf" {
|
||||||
|
@ -125,17 +126,35 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
||||||
argIndex = 1
|
argIndex = 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var formatter string
|
||||||
|
|
||||||
// concats callexpr arg strings together if needed before regex evaluation
|
// concats callexpr arg strings together if needed before regex evaluation
|
||||||
if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
|
if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
|
||||||
if fullStr, ok := gosec.ConcatString(argExpr); ok {
|
if fullStr, ok := gosec.ConcatString(argExpr); ok {
|
||||||
if s.MatchPatterns(fullStr) {
|
formatter = fullStr
|
||||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence),
|
|
||||||
nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil {
|
||||||
|
formatter = arg
|
||||||
|
}
|
||||||
|
if len(formatter) <= 0 {
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if arg, e := gosec.GetString(node.Args[argIndex]); s.MatchPatterns(arg) && e == nil {
|
// If all formatter args are quoted, then the SQL construction is safe
|
||||||
|
if argIndex+1 < len(node.Args) {
|
||||||
|
allQuoted := true
|
||||||
|
for _, arg := range node.Args[argIndex+1:] {
|
||||||
|
if n := s.noIssueQuoted.ContainsCallExpr(arg, c, true); n == nil {
|
||||||
|
allQuoted = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allQuoted {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.MatchPatterns(formatter) {
|
||||||
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
return gosec.NewIssue(c, n, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,8 +164,9 @@ func (s *sqlStrFormat) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error)
|
||||||
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
|
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
|
||||||
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||||
rule := &sqlStrFormat{
|
rule := &sqlStrFormat{
|
||||||
calls: gosec.NewCallList(),
|
calls: gosec.NewCallList(),
|
||||||
noIssue: gosec.NewCallList(),
|
noIssue: gosec.NewCallList(),
|
||||||
|
noIssueQuoted: gosec.NewCallList(),
|
||||||
sqlStatement: sqlStatement{
|
sqlStatement: sqlStatement{
|
||||||
patterns: []*regexp.Regexp{
|
patterns: []*regexp.Regexp{
|
||||||
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
|
regexp.MustCompile("(?)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
|
||||||
|
@ -162,5 +182,6 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
|
||||||
}
|
}
|
||||||
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
rule.calls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
||||||
rule.noIssue.AddAll("os", "Stdout", "Stderr")
|
rule.noIssue.AddAll("os", "Stdout", "Stderr")
|
||||||
|
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
|
||||||
return rule, []ast.Node{(*ast.CallExpr)(nil)}
|
return rule, []ast.Node{(*ast.CallExpr)(nil)}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func (r *ssrf) ResolveVar(n *ast.CallExpr, c *gosec.Context) bool {
|
||||||
// Match inspects AST nodes to determine if certain net/http methods are called with variable input
|
// Match inspects AST nodes to determine if certain net/http methods are called with variable input
|
||||||
func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (r *ssrf) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
// Call expression is using http package directly
|
// Call expression is using http package directly
|
||||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||||
if r.ResolveVar(node, c) {
|
if r.ResolveVar(node, c) {
|
||||||
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func (r *subprocess) ID() string {
|
||||||
//
|
//
|
||||||
// syscall.Exec("echo", "foobar" + tainted)
|
// syscall.Exec("echo", "foobar" + tainted)
|
||||||
func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (r *subprocess) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
if node := r.ContainsCallExpr(n, c); node != nil {
|
if node := r.ContainsCallExpr(n, c, false); node != nil {
|
||||||
for _, arg := range node.Args {
|
for _, arg := range node.Args {
|
||||||
if ident, ok := arg.(*ast.Ident); ok {
|
if ident, ok := arg.(*ast.Ident); ok {
|
||||||
obj := c.Info.ObjectOf(ident)
|
obj := c.Info.ObjectOf(ident)
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (t *badTempFile) ID() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) {
|
func (t *badTempFile) Match(n ast.Node, c *gosec.Context) (gi *gosec.Issue, err error) {
|
||||||
if node := t.calls.ContainsCallExpr(n, c); node != nil {
|
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||||
if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil {
|
if arg, e := gosec.GetString(node.Args[0]); t.args.MatchString(arg) && e == nil {
|
||||||
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ func (t *templateCheck) ID() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
func (t *templateCheck) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
|
||||||
if node := t.calls.ContainsCallExpr(n, c); node != nil {
|
if node := t.calls.ContainsCallExpr(n, c, false); node != nil {
|
||||||
for _, arg := range node.Args {
|
for _, arg := range node.Args {
|
||||||
if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe
|
if _, ok := arg.(*ast.BasicLit); !ok { // basic lits are safe
|
||||||
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
return gosec.NewIssue(c, n, t.ID(), t.What, t.Severity, t.Confidence), nil
|
||||||
|
|
|
@ -292,6 +292,27 @@ func main(){
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
}`, 0}, {`
|
||||||
|
// Format string false positive, quoted formatter argument.
|
||||||
|
package main
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main(){
|
||||||
|
db, err := sql.Open("postgres", "localhost")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
q := fmt.Sprintf("SELECT * FROM %s where id = 1", pq.QuoteIdentifier(os.Args[1]))
|
||||||
|
rows, err := db.Query(q)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
}`, 0}}
|
}`, 0}}
|
||||||
|
|
||||||
// SampleCodeG202 - SQL query string building via string concatenation
|
// SampleCodeG202 - SQL query string building via string concatenation
|
||||||
|
|
Loading…
Reference in a new issue