Add db.Exec and db.Prepare to the sql rule (#763)

* Add db.Exec and db.Prepare to the sql rule

* add test cases for G201,G202
This commit is contained in:
kaiili 2022-01-17 20:50:37 +08:00 committed by GitHub
parent 742aa848f9
commit 1d909e2687
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 165 additions and 5 deletions

View file

@ -137,8 +137,8 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
}, },
} }
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext") rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext") rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
} }
@ -306,8 +306,8 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
}, },
}, },
} }
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext") rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext") rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr") rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier") rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")

View file

@ -1255,7 +1255,103 @@ func main() {
panic(err) panic(err)
} }
defer db.Close() defer db.Close()
}`}, 1, gosec.NewConfig()}, }`}, 1, gosec.NewConfig()}, {[]string{`
// SQLI by db.Prepare(some)
package main
import (
"database/sql"
"fmt"
"log"
"os"
)
const Table = "foo"
func main() {
var album string
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := fmt.Sprintf("SELECT name FROM users where '%s' = ?", os.Args[1])
stmt, err := db.Prepare(q)
if err != nil {
log.Fatal(err)
}
stmt.QueryRow(fmt.Sprintf("%s", os.Args[2])).Scan(&album)
if err != nil {
if err == sql.ErrNoRows {
log.Fatal(err)
}
}
defer stmt.Close()
}
`}, 1, gosec.NewConfig()}, {[]string{`
// SQLI by db.PrepareContext(some)
package main
import (
"context"
"database/sql"
"fmt"
"log"
"os"
)
const Table = "foo"
func main() {
var album string
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := fmt.Sprintf("SELECT name FROM users where '%s' = ?", os.Args[1])
stmt, err := db.PrepareContext(context.Background(), q)
if err != nil {
log.Fatal(err)
}
stmt.QueryRow(fmt.Sprintf("%s", os.Args[2])).Scan(&album)
if err != nil {
if err == sql.ErrNoRows {
log.Fatal(err)
}
}
defer stmt.Close()
}
`}, 1, gosec.NewConfig()}, {[]string{`
// false positive
package main
import (
"database/sql"
"fmt"
"log"
"os"
)
const Table = "foo"
func main() {
var album string
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
stmt, err := db.Prepare("SELECT * FROM album WHERE id = ?")
if err != nil {
log.Fatal(err)
}
stmt.QueryRow(fmt.Sprintf("%s", os.Args[1])).Scan(&album)
if err != nil {
if err == sql.ErrNoRows {
log.Fatal(err)
}
}
defer stmt.Close()
}
`}, 0, gosec.NewConfig()},
} }
// SampleCodeG202 - SQL query string building via string concatenation // SampleCodeG202 - SQL query string building via string concatenation
@ -1431,6 +1527,70 @@ func main(){
} }
defer rows.Close() defer rows.Close()
} }
`}, 0, gosec.NewConfig()}, {[]string{`
// ExecContext match
package main
import (
"context"
"database/sql"
"fmt"
"os"
)
func main() {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
result, err := db.ExecContext(context.Background(), "select * from foo where name = "+os.Args[1])
if err != nil {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
// Exec match
package main
import (
"database/sql"
"fmt"
"os"
)
func main() {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
result, err := db.Exec("select * from foo where name = " + os.Args[1])
if err != nil {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
package main
import (
"database/sql"
"fmt"
)
const gender = "M"
const age = "32"
var staticQuery = "SELECT * FROM foo WHERE age < "
func main() {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
result, err := db.Exec("SELECT * FROM foo WHERE gender = " + gender)
if err != nil {
panic(err)
}
fmt.Println(result)
}
`}, 0, gosec.NewConfig()}, `}, 0, gosec.NewConfig()},
} }