diff --git a/rules/rand.go b/rules/rand.go index 6c5aa50..45ddd67 100644 --- a/rules/rand.go +++ b/rules/rand.go @@ -17,6 +17,7 @@ package rules import ( "go/ast" "regexp" + "strings" gas "github.com/GoASTScanner/gas/core" ) @@ -24,25 +25,81 @@ import ( type WeakRand struct { gas.MetaData pattern *regexp.Regexp - packageName string packagePath string } -func (w *WeakRand) Match(n ast.Node, c *gas.Context) (*gas.Issue, error) { - if call := gas.MatchCall(n, w.pattern); call != nil { - for _, pkg := range c.Pkg.Imports() { - if pkg.Name() == w.packageName && pkg.Path() == w.packagePath { - return gas.NewIssue(c, n, w.What, w.Severity, w.Confidence), nil - } - } +type pkgFunc struct { + packagePath string + funcName string +} + +// pkgId takes an import line and returns the identifier used +// for that package in the rest of the file +func pkgId(i *ast.ImportSpec) string { + if i.Name != nil { + return i.Name.String() } + trim := strings.Trim(i.Path.Value, `"`) + a := strings.Split(trim, "/") + return a[len(a)-1] +} + +// importIds returns a map of import names to their full paths +func importIds(f *ast.File) map[string]string { + pkgs := make(map[string]string) + for _, v := range f.Imports { + pkgs[pkgId(v)] = strings.Trim(v.Path.Value, `"`) + } + return pkgs +} + +// matchPkgFunc will return package level function calls split +// by full package path and function name +func matchPkgFunc(n ast.Node, c *gas.Context) *pkgFunc { + call, ok := n.(*ast.CallExpr) + if !ok { + return nil + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil + } + + id, ok := sel.X.(*ast.Ident) + if !ok { + return nil + } + + if id.Obj != nil { + return nil + } + + i := importIds(c.Root) + v, ok := i[id.Name] + if !ok { + return nil + } + + return &pkgFunc{ + packagePath: v, + funcName: sel.Sel.String(), + } +} + +func (w *WeakRand) Match(n ast.Node, c *gas.Context) (*gas.Issue, error) { + call := matchPkgFunc(n, c) + + if call != nil && call.packagePath == w.packagePath && w.pattern.MatchString(call.funcName) { + return gas.NewIssue(c, n, w.What, w.Severity, w.Confidence), nil + } + return nil, nil } func NewWeakRandCheck(conf map[string]interface{}) (r gas.Rule, n ast.Node) { r = &WeakRand{ - pattern: regexp.MustCompile(`^rand\.Read$`), - packageName: "rand", + pattern: regexp.MustCompile(`^Read$`), packagePath: "math/rand", MetaData: gas.MetaData{ Severity: gas.High, diff --git a/rules/rand_test.go b/rules/rand_test.go index 150003f..8652e59 100644 --- a/rules/rand_test.go +++ b/rules/rand_test.go @@ -55,3 +55,26 @@ func TestRandBad(t *testing.T) { checkTestResults(t, issues, 1, "Use of weak random number generator (math/rand instead of crypto/rand)") } + +func TestRandRenamed(t *testing.T) { + config := map[string]interface{}{"ignoreNosec": false} + analyzer := gas.NewAnalyzer(config, nil) + analyzer.AddRule(NewWeakRandCheck(config)) + + issues := gasTestRunner( + ` + package samples + + import ( + "crypto/rand" + mrand "math/rand" + ) + + + func main() { + good, err := rand.Read(nil) + i := mrand.Int() + }`, analyzer) + + checkTestResults(t, issues, 0, "Not expected to match") +}