diff --git a/helpers.go b/helpers.go index bc01a94..638129b 100644 --- a/helpers.go +++ b/helpers.go @@ -281,3 +281,33 @@ func ConcatString(n *ast.BinaryExpr) (string, bool) { } return s, true } + +// FindVarIdentities returns array of all variable identities in a given binary expression +func FindVarIdentities(n *ast.BinaryExpr, c *Context) ([]*ast.Ident, bool) { + identities := []*ast.Ident{} + // sub expressions are found in X object, Y object is always the last term + if rightOperand, ok := n.Y.(*ast.Ident); ok { + obj := c.Info.ObjectOf(rightOperand) + if _, ok := obj.(*types.Var); ok && !TryResolve(rightOperand, c) { + identities = append(identities, rightOperand) + } + } + if leftOperand, ok := n.X.(*ast.BinaryExpr); ok { + if leftIdentities, ok := FindVarIdentities(leftOperand, c); ok { + identities = append(identities, leftIdentities...) + } + } else { + if leftOperand, ok := n.X.(*ast.Ident); ok { + obj := c.Info.ObjectOf(leftOperand) + if _, ok := obj.(*types.Var); ok && !TryResolve(leftOperand, c) { + identities = append(identities, leftOperand) + } + } + } + + if len(identities) > 0 { + return identities, true + } + // if nil or error, return false + return nil, false +} diff --git a/rules/readfile.go b/rules/readfile.go index 61e1c85..2b38852 100644 --- a/rules/readfile.go +++ b/rules/readfile.go @@ -24,6 +24,7 @@ import ( type readfile struct { gosec.MetaData gosec.CallList + pathJoin gosec.CallList } // ID returns the identifier for this rule @@ -31,10 +32,49 @@ func (r *readfile) ID() string { return r.MetaData.ID } +// isJoinFunc checks if there is a filepath.Join or other join function +func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool { + if call := r.pathJoin.ContainsCallExpr(n, c); call != nil { + for _, arg := range call.Args { + // edge case: check if one of the args is a BinaryExpr + if binExp, ok := arg.(*ast.BinaryExpr); ok { + // iterate and resolve all found identites from the BinaryExpr + if _, ok := gosec.FindVarIdentities(binExp, c); ok { + return true + } + } + + // try and resolve identity + if ident, ok := arg.(*ast.Ident); ok { + obj := c.Info.ObjectOf(ident) + if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) { + return true + } + } + } +} + return false +} + // Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile` func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { if node := r.ContainsCallExpr(n, c); node != nil { for _, arg := range node.Args { + // handles path joining functions in Arg + // eg. os.Open(filepath.Join("/tmp/", file)) + if callExpr, ok := arg.(*ast.CallExpr); ok { + if r.isJoinFunc(callExpr, c) { + return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil + } + } + // handles binary string concatenation eg. ioutil.Readfile("/tmp/" + file + "/blob") + if binExp, ok := arg.(*ast.BinaryExpr); ok { + // resolve all found identites from the BinaryExpr + if _, ok := gosec.FindVarIdentities(binExp, c); ok { + return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil + } + } + if ident, ok := arg.(*ast.Ident); ok { obj := c.Info.ObjectOf(ident) if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) { @@ -49,6 +89,7 @@ func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { // NewReadFile detects cases where we read files func NewReadFile(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { rule := &readfile{ + pathJoin: gosec.NewCallList(), CallList: gosec.NewCallList(), MetaData: gosec.MetaData{ ID: id, @@ -57,6 +98,8 @@ func NewReadFile(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { Confidence: gosec.High, }, } + rule.pathJoin.Add("path/filepath", "Join") + rule.pathJoin.Add("path", "Join") rule.Add("io/ioutil", "ReadFile") rule.Add("os", "Open") return rule, []ast.Node{(*ast.CallExpr)(nil)} diff --git a/testutils/source.go b/testutils/source.go index d3464b4..48e83b8 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -500,7 +500,7 @@ import ( func main() { http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { - title := r.URL.Query().Get("title") + title := r.URL.Query().Get("title") f, err := os.Open(title) if err != nil { fmt.Printf("Error: %v\n", err) @@ -512,6 +512,65 @@ func main() { fmt.Fprintf(w, "%s", body) }) log.Fatal(http.ListenAndServe(":3000", nil)) +}`, 1}, {` +package main + +import ( + "log" + "os" + "io/ioutil" +) + + func main() { + f2 := os.Getenv("tainted_file2") + body, err := ioutil.ReadFile("/tmp/" + f2) + if err != nil { + log.Printf("Error: %v\n", err) + } + log.Print(body) + }`, 1}, {` + package main + + import ( + "bufio" + "fmt" + "os" + "path/filepath" + ) + +func main() { + reader := bufio.NewReader(os.Stdin) + fmt.Print("Please enter file to read: ") + file, _ := reader.ReadString('\n') + file = file[:len(file)-1] + f, err := os.Open(filepath.Join("/tmp/service/", file)) + if err != nil { + fmt.Printf("Error: %v\n", err) + } + contents := make([]byte, 15) + if _, err = f.Read(contents); err != nil { + fmt.Printf("Error: %v\n", err) + } + fmt.Println(string(contents)) +}`, 1}, {` +package main + +import ( + "log" + "os" + "io/ioutil" + "path/filepath" +) + +func main() { + dir := os.Getenv("server_root") + f3 := os.Getenv("tainted_file3") + // edge case where both a binary expression and file Join are used. + body, err := ioutil.ReadFile(filepath.Join("/var/"+dir, f3)) + if err != nil { + log.Printf("Error: %v\n", err) + } + log.Print(body) }`, 1}} // SampleCodeG305 - File path traversal when extracting zip archives