gosec/autofix/ai.go
Tran The Lam 56f943b802
Add support to generate auto fixes using LLM (AI) (#1177)
This feature adds support for generating auto fixes for Go scanning findings using an LLM (AI). Initially, it relies on the Gemini API to suggest a solution. This can later be extended to integrate other AI providers as well.

---------

Signed-off-by: Cosmin Cojocar <ccojocar@google.com>
Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Co-authored-by: Cosmin Cojocar <ccojocar@google.com>
2024-08-12 12:52:41 +02:00

142 lines
4 KiB
Go

package autofix
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
"github.com/securego/gosec/v2/issue"
)
const (
	// GeminiProvider is the provider name that GenerateSolution accepts to
	// select the Gemini API.
	GeminiProvider = "gemini"

	// GeminiModel is the name of the Gemini model used to generate autofixes.
	GeminiModel = "gemini-1.5-flash"

	// AIPrompt is the prompt template sent to the model. Its single %q verb is
	// filled with the quoted description of the issue being fixed.
	AIPrompt = `Provide a brief explanation and a solution to fix this security issue
in Go programming language: %q.
Answer in markdown format and keep the response limited to 200 words.`

	// timeout is the deadline applied to the contexts used when talking to
	// the AI provider.
	timeout = 30 * time.Second
)
// GenAIClient is the minimal generative-AI client surface needed to produce
// autofixes. genAIClientWrapper adapts *genai.Client to it.
type GenAIClient interface {
	// GenerativeModel returns a handle to the generative model identified
	// by name.
	GenerativeModel(name string) GenAIGenerativeModel
	// Close cleans up and closes the client, returning any error from
	// doing so.
	Close() error
}
// GenAIGenerativeModel is the single operation needed from a generative
// model: turning a text prompt into a text response.
type GenAIGenerativeModel interface {
	// GenerateContent returns the model's response to prompt, or an error
	// when no response could be produced.
	GenerateContent(ctx context.Context, prompt string) (string, error)
}
// genAIClientWrapper adapts a *genai.Client so that it satisfies GenAIClient.
type genAIClientWrapper struct {
	// client is the underlying Gemini API client.
	client *genai.Client
}

// Close delegates to the underlying client's Close and returns its error.
func (c *genAIClientWrapper) Close() error {
	err := c.client.Close()
	return err
}

// GenerativeModel looks up the named model on the underlying client and wraps
// it as a GenAIGenerativeModel.
func (c *genAIClientWrapper) GenerativeModel(name string) GenAIGenerativeModel {
	m := c.client.GenerativeModel(name)
	return &genAIGenerativeModelWrapper{model: m}
}
// genAIGenerativeModelWrapper adapts a *genai.GenerativeModel so that it
// satisfies GenAIGenerativeModel.
type genAIGenerativeModelWrapper struct {
	// model is the underlying generative model.
	model *genai.GenerativeModel
}

// GenerateContent sends prompt to the Gemini API and returns the first part of
// the first candidate, formatted with %+v.
//
// It returns an error when the request fails, when the response holds no
// candidate, or when the first candidate carries no content parts.
func (w *genAIGenerativeModelWrapper) GenerateContent(ctx context.Context, prompt string) (string, error) {
	resp, err := w.model.GenerateContent(ctx, genai.Text(prompt))
	if err != nil {
		return "", fmt.Errorf("generating autofix: %w", err)
	}
	if resp == nil || len(resp.Candidates) == 0 {
		return "", errors.New("no autofix returned by gemini")
	}
	// Candidates are pointers and Candidate.Content is itself a pointer;
	// guard both before indexing Parts so a candidate without content yields
	// an error instead of a nil-pointer panic.
	first := resp.Candidates[0]
	if first == nil || first.Content == nil || len(first.Content.Parts) == 0 {
		return "", errors.New("nothing found in the first autofix returned by gemini")
	}
	// Return the first part of the first candidate.
	return fmt.Sprintf("%+v", first.Content.Parts[0]), nil
}
// NewGenAIClient builds a Gemini API client authenticated with aiApiKey.
// A non-empty endpoint overrides the client's default API endpoint. The
// returned value wraps the underlying *genai.Client as a GenAIClient.
func NewGenAIClient(ctx context.Context, aiApiKey, endpoint string) (GenAIClient, error) {
	opts := make([]option.ClientOption, 0, 2)
	opts = append(opts, option.WithAPIKey(aiApiKey))
	if len(endpoint) > 0 {
		opts = append(opts, option.WithEndpoint(endpoint))
	}

	gc, err := genai.NewClient(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("calling gemini API: %w", err)
	}
	return &genAIClientWrapper{client: gc}, nil
}
// generateSolutionByGemini fills in the Autofix field of every issue using the
// GeminiModel model obtained from client.
//
// Issues sharing the same What description are sent to the model only once;
// later occurrences reuse the cached answer. Each request is bounded by its
// own timeout, so a long list of issues does not exhaust one shared deadline.
// Nil entries in issues are skipped. The first failed or empty response stops
// processing and is returned; issues handled before it keep their Autofix.
func generateSolutionByGemini(client GenAIClient, issues []*issue.Issue) error {
	model := client.GenerativeModel(GeminiModel)
	cachedAutofix := make(map[string]string)
	// iss (not "issue") avoids shadowing the imported issue package.
	for _, iss := range issues {
		if iss == nil {
			continue
		}
		if val, ok := cachedAutofix[iss.What]; ok {
			iss.Autofix = val
			continue
		}
		prompt := fmt.Sprintf(AIPrompt, iss.What)
		resp, err := generateContentWithTimeout(model, prompt)
		if err != nil {
			return fmt.Errorf("generating autofix with gemini: %w", err)
		}
		if resp == "" {
			return errors.New("no autofix returned by gemini")
		}
		iss.Autofix = resp
		cachedAutofix[iss.What] = resp
	}
	return nil
}

// generateContentWithTimeout sends one prompt to model under a fresh context
// that expires after timeout, releasing the context when the call returns.
func generateContentWithTimeout(model GenAIGenerativeModel, prompt string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return model.GenerateContent(ctx, prompt)
}
// GenerateSolution fills in the Autofix field of the given issues using the AI
// provider named by aiApiProvider. Only GeminiProvider is supported; any other
// value returns an error. aiApiKey authenticates against the provider, and a
// non-empty endpoint overrides its default API endpoint.
func GenerateSolution(aiApiProvider, aiApiKey, endpoint string, issues []*issue.Issue) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if aiApiProvider != GeminiProvider {
		return errors.New("ai provider not supported")
	}

	client, err := NewGenAIClient(ctx, aiApiKey, endpoint)
	if err != nil {
		return fmt.Errorf("generating autofix: %w", err)
	}
	// NOTE(review): the error returned by Close is discarded.
	defer client.Close()

	return generateSolutionByGemini(client, issues)
}