mirror of
https://github.com/securego/gosec.git
synced 2024-12-24 11:35:52 +00:00
Extend helpers and call list
- Update call list to work directly with call expression - Add call list test cases - Extend helpers to add GetCallInfo to resolve call name and package or type if it's a var. - Add test cases to ensure correct behaviour
This commit is contained in:
parent
d29c64800e
commit
5242a2c1df
4 changed files with 139 additions and 51 deletions
|
@ -13,16 +13,13 @@
|
|||
|
||||
package core
|
||||
|
||||
type set map[string]bool
|
||||
import "go/ast"
|
||||
|
||||
type calls struct {
|
||||
matchAny bool
|
||||
functions set
|
||||
}
|
||||
type set map[string]bool
|
||||
|
||||
/// CallList is used to check for usage of specific packages
|
||||
/// and functions.
|
||||
type CallList map[string]*calls
|
||||
type CallList map[string]set
|
||||
|
||||
/// NewCallList creates a new empty CallList
|
||||
func NewCallList() CallList {
|
||||
|
@ -30,36 +27,39 @@ func NewCallList() CallList {
|
|||
}
|
||||
|
||||
/// NewCallListFor createse a call list using the package path
|
||||
func NewCallListFor(pkg string, funcs ...string) CallList {
|
||||
func NewCallListFor(selector string, idents ...string) CallList {
|
||||
c := NewCallList()
|
||||
if len(funcs) == 0 {
|
||||
c[pkg] = &calls{true, make(set)}
|
||||
} else {
|
||||
for _, fn := range funcs {
|
||||
c.Add(pkg, fn)
|
||||
}
|
||||
c[selector] = make(set)
|
||||
for _, ident := range idents {
|
||||
c.Add(selector, ident)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
/// Add a new package and function to the call list
|
||||
func (c CallList) Add(pkg, fn string) {
|
||||
if cl, ok := c[pkg]; ok {
|
||||
if cl.matchAny {
|
||||
cl.matchAny = false
|
||||
}
|
||||
} else {
|
||||
c[pkg] = &calls{false, make(set)}
|
||||
/// Add a selector and call to the call list
|
||||
func (c CallList) Add(selector, ident string) {
|
||||
if _, ok := c[selector]; !ok {
|
||||
c[selector] = make(set)
|
||||
}
|
||||
c[pkg].functions[fn] = true
|
||||
c[selector][ident] = true
|
||||
}
|
||||
|
||||
/// Contains returns true if the package and function are
|
||||
/// members of this call list.
|
||||
func (c CallList) Contains(pkg, fn string) bool {
|
||||
if funcs, ok := c[pkg]; ok {
|
||||
_, ok = funcs.functions[fn]
|
||||
return ok || funcs.matchAny
|
||||
func (c CallList) Contains(selector, ident string) bool {
|
||||
if idents, ok := c[selector]; ok {
|
||||
_, found := idents[ident]
|
||||
return found
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/// ContainsCallExpr resolves the call expression name and type
|
||||
/// or package and determines if it exists within the CallList
|
||||
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool {
|
||||
selector, ident, err := GetCallInfo(n, ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return c.Contains(selector, ident)
|
||||
}
|
||||
|
|
58
core/call_list_test.go
Normal file
58
core/call_list_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type callListRule struct {
|
||||
MetaData
|
||||
callList CallList
|
||||
matched int
|
||||
}
|
||||
|
||||
func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
|
||||
if r.callList.ContainsCallExpr(n, c) {
|
||||
r.matched += 1
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestCallListContainsCallExpr(t *testing.T) {
|
||||
config := map[string]interface{}{"ignoreNosec": false}
|
||||
analyzer := NewAnalyzer(config, nil)
|
||||
rule := &callListRule{
|
||||
MetaData: MetaData{
|
||||
Severity: Low,
|
||||
Confidence: Low,
|
||||
What: "A dummy rule",
|
||||
},
|
||||
callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"),
|
||||
matched: 0,
|
||||
}
|
||||
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
|
||||
source := `
|
||||
package main
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
func main() {
|
||||
var b bytes.Buffer
|
||||
b.Write([]byte("Hello "))
|
||||
fmt.Fprintf(&b, "world!")
|
||||
}`
|
||||
|
||||
analyzer.ProcessSource("dummy.go", source)
|
||||
if rule.matched != 1 {
|
||||
t.Errorf("Expected to match a bytes.Buffer.Write call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallListContains(t *testing.T) {
|
||||
callList := NewCallList()
|
||||
callList.Add("fmt", "Printf")
|
||||
if !callList.Contains("fmt", "Printf") {
|
||||
t.Errorf("Expected call list to contain fmt.Printf")
|
||||
}
|
||||
}
|
|
@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
|
|||
importName = alias
|
||||
}
|
||||
|
||||
switch node := n.(type) {
|
||||
case *ast.CallExpr:
|
||||
switch fn := node.Fun.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
switch expr := fn.X.(type) {
|
||||
case *ast.Ident:
|
||||
if expr.Name == importName {
|
||||
for _, name := range names {
|
||||
if fn.Sel.Name == name {
|
||||
return node, true
|
||||
}
|
||||
}
|
||||
if callExpr, ok := n.(*ast.CallExpr); ok {
|
||||
packageName, callName, err := GetCallInfo(callExpr, c)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if packageName == importName {
|
||||
for _, name := range names {
|
||||
if callName == name {
|
||||
return callExpr, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
|
|||
// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write")
|
||||
//
|
||||
func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) {
|
||||
switch callExpr := n.(type) {
|
||||
case *ast.CallExpr:
|
||||
switch fn := callExpr.Fun.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
switch expr := fn.X.(type) {
|
||||
case *ast.Ident:
|
||||
t := ctx.Info.TypeOf(expr)
|
||||
if t != nil && t.String() == requiredType {
|
||||
for _, call := range calls {
|
||||
if fn.Sel.Name == call {
|
||||
return callExpr, true
|
||||
}
|
||||
}
|
||||
if callExpr, ok := n.(*ast.CallExpr); ok {
|
||||
typeName, callName, err := GetCallInfo(callExpr, ctx)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if typeName == requiredType {
|
||||
for _, call := range calls {
|
||||
if call == callName {
|
||||
return callExpr, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) {
|
|||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetCallInfo returns the package or type and name associated with a
|
||||
// call expression.
|
||||
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
|
||||
switch node := n.(type) {
|
||||
case *ast.CallExpr:
|
||||
switch fn := node.Fun.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
switch expr := fn.X.(type) {
|
||||
case *ast.Ident:
|
||||
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
|
||||
t := ctx.Info.TypeOf(expr)
|
||||
if t != nil {
|
||||
return t.String(), fn.Sel.Name, nil
|
||||
} else {
|
||||
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
|
||||
}
|
||||
} else {
|
||||
return expr.Name, fn.Sel.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", "", fmt.Errorf("unable to determine call info")
|
||||
}
|
||||
|
|
|
@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) {
|
|||
if rule.matched != 1 || len(rule.callExpr) != 1 {
|
||||
t.Errorf("Expected to match a bytes.Buffer.Write call")
|
||||
}
|
||||
|
||||
typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context)
|
||||
if err != nil {
|
||||
t.Errorf("Unable to resolve call info: %v\n", err)
|
||||
}
|
||||
if typeName != "bytes.Buffer" {
|
||||
t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName)
|
||||
}
|
||||
if callName != "Write" {
|
||||
t.Errorf("Expected: %s, Got: %s\n", "Write", callName)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue