diff --git a/analyzer.go b/analyzer.go index 0f9fef2..5f77879 100644 --- a/analyzer.go +++ b/analyzer.go @@ -172,9 +172,9 @@ func (gosec *Analyzer) Process(buildTags []string, packagePaths ...string) error for { select { case s := <-j: - packages, err := gosec.load(s, config) + pkgs, err := gosec.load(s, config) select { - case r <- result{pkgPath: s, pkgs: packages, err: err}: + case r <- result{pkgPath: s, pkgs: pkgs, err: err}: case <-quit: // we've been told to stop, probably an error while // processing a previous result. @@ -296,7 +296,6 @@ func (gosec *Analyzer) Check(pkg *packages.Package) { gosec.context.Pkg = pkg.Types gosec.context.PkgFiles = pkg.Syntax gosec.context.Imports = NewImportTracker() - gosec.context.Imports.TrackFile(file) gosec.context.PassedValues = make(map[string]interface{}) ast.Walk(gosec, file) gosec.stats.NumFiles++ @@ -434,6 +433,12 @@ func (gosec *Analyzer) Visit(n ast.Node) ast.Visitor { } return gosec } + switch i := n.(type) { + case *ast.File: + // Using ast.File instead of ast.ImportSpec, so that we can track + // all imports at once. + gosec.context.Imports.TrackFile(i) + } // Get any new rule exclusions. ignoredRules := gosec.ignore(n) @@ -453,9 +458,6 @@ func (gosec *Analyzer) Visit(n ast.Node) ast.Visitor { // Push the new set onto the stack. gosec.context.Ignores = append([]map[string][]SuppressionInfo{ignores}, gosec.context.Ignores...) - // Track aliased and initialization imports - gosec.context.Imports.TrackImport(n) - for _, rule := range gosec.ruleset.RegisteredFor(n) { // Check if all rules are ignored. generalSuppressions, generalIgnored := ignores[aliasOfAllRules] diff --git a/helpers.go b/helpers.go index bd6aff7..62ede05 100644 --- a/helpers.go +++ b/helpers.go @@ -37,12 +37,9 @@ import ( // // node, matched := MatchCallByPackage(n, ctx, "math/rand", "Read") func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*ast.CallExpr, bool) { - importedName, found := GetAliasedName(pkg, c) + importedNames, found := GetImportedNames(pkg, c) if !found { - importedName, found = GetImportedName(pkg, c) - if !found { - return nil, false - } + return nil, false } if callExpr, ok := n.(*ast.CallExpr); ok { @@ -50,7 +47,10 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a if err != nil { return nil, false } - if packageName == importedName { + for _, in := range importedNames { + if packageName != in { + continue + } for _, name := range names { if callName == name { return callExpr, true @@ -247,48 +247,23 @@ func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node { return result } -// GetImportedName returns the name used for the package within the -// code. It will ignore initialization only imports. -func GetImportedName(path string, ctx *Context) (string, bool) { - importName, imported := ctx.Imports.Imported[path] - if !imported { - return "", false - } - - if _, initonly := ctx.Imports.InitOnly[path]; initonly { - return "", false - } - - return importName, true -} - -// GetAliasedName returns the aliased name used for the package within the -// code. It will ignore initialization only imports. -func GetAliasedName(path string, ctx *Context) (string, bool) { - importName, imported := ctx.Imports.Aliased[path] - if !imported { - return "", false - } - - if _, initonly := ctx.Imports.InitOnly[path]; initonly { - return "", false - } - - return importName, true +// GetImportedNames returns the name(s)/alias(es) used for the package within +// the code. It ignores initialization-only imports. +func GetImportedNames(path string, ctx *Context) (names []string, found bool) { + importNames, imported := ctx.Imports.Imported[path] + return importNames, imported } // GetImportPath resolves the full import path of an identifier based on // the imports in the current context(including aliases). func GetImportPath(name string, ctx *Context) (string, bool) { for path := range ctx.Imports.Imported { - if imported, ok := GetImportedName(path, ctx); ok && imported == name { - return path, true - } - } - - for path := range ctx.Imports.Aliased { - if imported, ok := GetAliasedName(path, ctx); ok && imported == name { - return path, true + if imported, ok := GetImportedNames(path, ctx); ok { + for _, n := range imported { + if n == name { + return path, true + } + } } } diff --git a/import_tracker.go b/import_tracker.go index cbb8c55..30e7c00 100644 --- a/import_tracker.go +++ b/import_tracker.go @@ -22,54 +22,51 @@ import ( // by a source file. It is able to differentiate between plain imports, aliased // imports and init only imports. type ImportTracker struct { - Imported map[string]string - Aliased map[string]string - InitOnly map[string]bool + // Imported is a map of Imported with their associated names/aliases. + Imported map[string][]string } // NewImportTracker creates an empty Import tracker instance func NewImportTracker() *ImportTracker { return &ImportTracker{ - make(map[string]string), - make(map[string]string), - make(map[string]bool), + Imported: make(map[string][]string), } } // TrackFile track all the imports used by the supplied file func (t *ImportTracker) TrackFile(file *ast.File) { for _, imp := range file.Imports { - path := strings.Trim(imp.Path.Value, `"`) - parts := strings.Split(path, "/") - if len(parts) > 0 { - name := parts[len(parts)-1] - t.Imported[path] = name - } + t.TrackImport(imp) } } // TrackPackages tracks all the imports used by the supplied packages func (t *ImportTracker) TrackPackages(pkgs ...*types.Package) { for _, pkg := range pkgs { - t.Imported[pkg.Path()] = pkg.Name() + t.Imported[pkg.Path()] = []string{pkg.Name()} } } -// TrackImport tracks imports and handles the 'unsafe' import -func (t *ImportTracker) TrackImport(n ast.Node) { - if imported, ok := n.(*ast.ImportSpec); ok { - path := strings.Trim(imported.Path.Value, `"`) - if imported.Name != nil { - if imported.Name.Name == "_" { - // Initialization only import - t.InitOnly[path] = true - } else { - // Aliased import - t.Aliased[path] = imported.Name.Name - } - } - if path == "unsafe" { - t.Imported[path] = path +// TrackImport tracks imports. +func (t *ImportTracker) TrackImport(imported *ast.ImportSpec) { + importPath := strings.Trim(imported.Path.Value, `"`) + if imported.Name != nil { + if imported.Name.Name == "_" { + // Initialization only import + } else { + // Aliased import + t.Imported[importPath] = append(t.Imported[importPath], imported.Name.String()) } + } else { + t.Imported[importPath] = append(t.Imported[importPath], importName(importPath)) } } + +func importName(importPath string) string { + parts := strings.Split(importPath, "/") + name := importPath + if len(parts) > 0 { + name = parts[len(parts)-1] + } + return name +} diff --git a/import_tracker_test.go b/import_tracker_test.go index b060b34..4837312 100644 --- a/import_tracker_test.go +++ b/import_tracker_test.go @@ -27,7 +27,7 @@ var _ = Describe("Import Tracker", func() { files := pkgs[0].Syntax Expect(files).Should(HaveLen(1)) tracker.TrackFile(files[0]) - Expect(tracker.Imported).Should(Equal(map[string]string{"fmt": "fmt"})) + Expect(tracker.Imported).Should(Equal(map[string][]string{"fmt": {"fmt"}})) }) It("should parse the named imports from file", func() { tracker := gosec.NewImportTracker() @@ -47,7 +47,7 @@ var _ = Describe("Import Tracker", func() { files := pkgs[0].Syntax Expect(files).Should(HaveLen(1)) tracker.TrackFile(files[0]) - Expect(tracker.Imported).Should(Equal(map[string]string{"fmt": "fmt"})) + Expect(tracker.Imported).Should(Equal(map[string][]string{"fmt": {"fm"}})) }) }) }) diff --git a/testutils/source.go b/testutils/source.go index 3db02e2..31d90fe 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -3196,6 +3196,25 @@ func main() { println(bad) } `}, 1, gosec.NewConfig()}, + {[]string{` +package main + +import ( + crand "crypto/rand" + "math/big" + "math/rand" + rand2 "math/rand" + rand3 "math/rand" +) + +func main() { + _, _ = crand.Int(crand.Reader, big.NewInt(int64(2))) // good + + _ = rand.Intn(2) // bad + _ = rand2.Intn(2) // bad + _ = rand3.Intn(2) // bad +} +`}, 3, gosec.NewConfig()}, } // SampleCodeG501 - Blocklisted import MD5