go-ai-agent/main.go

432 lines
10 KiB
Go

package main
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"path"
"path/filepath"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/invopop/jsonschema"
"gocv.io/x/gocv"
)
// main wires together the Anthropic client, a stdin-backed line reader and
// the tool set, then runs the interactive agent loop. Failures are reported
// on stderr and produce a non-zero exit status.
func main() {
	client := anthropic.NewClient()

	// getUserMessage yields one line of stdin per call; the bool is false
	// once the scanner cannot produce another line (EOF or read error).
	scanner := bufio.NewScanner(os.Stdin)
	getUserMessage := func() (string, bool) {
		if !scanner.Scan() {
			return "", false
		}
		return scanner.Text(), true
	}

	tools := []ToolDefinition{
		ReadFileDefinition,
		ListFilesDefinition,
		EditFileDefinition,
		Base64EncodeFileDefinition,
		WebcamDefinition,
	}

	agent := NewAgent(&client, getUserMessage, tools)

	// main owns the root context, so Background (not TODO) is appropriate.
	if err := agent.Run(context.Background()); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %s\n", err)
		os.Exit(1)
	}

	// Run returns nil when input stops; make sure that was a clean EOF and
	// not a scanner failure (for example an over-long line).
	if err := scanner.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Error: reading input: %s\n", err)
		os.Exit(1)
	}
}
// NewAgent builds an Agent from the given API client, user-input callback
// and tool definitions. The arguments are stored as-is; nothing is copied.
func NewAgent(client *anthropic.Client, getUserMessage func() (string, bool), tools []ToolDefinition) *Agent {
	agent := new(Agent)
	agent.client = client
	agent.getUserMessage = getUserMessage
	agent.tools = tools
	return agent
}
// Agent bundles what the chat loop needs: the API client, the source of user
// input, and the tools offered to the model.
type Agent struct {
// client is the Anthropic API client used to issue inference requests.
client *anthropic.Client
// getUserMessage returns the next user input; the bool is false when no
// further input is available, which ends the chat loop.
getUserMessage func() (string, bool)
// tools are sent with every inference request and are looked up by name
// when the model asks to use one.
tools []ToolDefinition
}
// Run drives the interactive chat loop.
//
// Each iteration optionally reads one user message, sends the whole
// conversation for inference, appends the reply, and executes every tool_use
// block found in it. When tools were executed their results are sent back as
// the next user turn without prompting the human; otherwise the loop prompts
// for fresh input.
//
// Run returns nil when getUserMessage reports that input is exhausted, and
// the inference error if a request fails.
func (a *Agent) Run(ctx context.Context) error {
	conversation := []anthropic.MessageParam{}

	fmt.Println("Chat with Claude (use 'ctrl-c' to quit)")

	readUserInput := true
	for {
		if readUserInput {
			fmt.Print("\u001b[94mYou\u001b[0m: ")
			userInput, ok := a.getUserMessage()
			if !ok {
				return nil
			}
			conversation = append(conversation, anthropic.NewUserMessage(anthropic.NewTextBlock(userInput)))
		}

		message, err := a.runInference(ctx, conversation)
		if err != nil {
			return err
		}
		conversation = append(conversation, message.ToParam())

		var toolResults []anthropic.ContentBlockParamUnion
		for _, content := range message.Content {
			switch content.Type {
			case "text":
				// The text itself was already streamed by runInference; only
				// terminate the line. fmt is used instead of the builtin
				// print so all chat output lands on stdout.
				fmt.Println()
			case "tool_use":
				toolResults = append(toolResults, a.executeTool(content.ID, content.Name, content.Input))
			}
		}

		// Only ask the human again when the model requested no tools.
		readUserInput = len(toolResults) == 0
		if !readUserInput {
			conversation = append(conversation, anthropic.NewUserMessage(toolResults...))
		}
	}
}
// executeTool runs the tool called name with the raw JSON input and wraps the
// outcome in a tool-result block tagged with id. An unknown tool name or a
// failing tool yields a result block flagged as an error; the error text is
// the block's content.
func (a *Agent) executeTool(id, name string, input json.RawMessage) anthropic.ContentBlockParamUnion {
	// Locate the first registered tool whose name matches.
	match := -1
	for i := range a.tools {
		if a.tools[i].Name == name {
			match = i
			break
		}
	}
	if match < 0 {
		return anthropic.NewToolResultBlock(id, "tool not found", true)
	}

	fmt.Printf("\n\u001b[92mtool\u001b[0m: %s(%s)\n", name, input)

	output, err := a.tools[match].Function(input)
	if err != nil {
		return anthropic.NewToolResultBlock(id, err.Error(), true)
	}
	return anthropic.NewToolResultBlock(id, output, false)
}
// runInference sends the conversation plus the agent's tool definitions to
// the model using the streaming API, echoes text deltas to stdout as they
// arrive, and returns the fully accumulated message.
//
// On any accumulation or stream error the returned message is nil.
func (a *Agent) runInference(ctx context.Context, conversation []anthropic.MessageParam) (*anthropic.Message, error) {
	anthropicTools := make([]anthropic.ToolUnionParam, 0, len(a.tools))
	for _, tool := range a.tools {
		anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{
			OfTool: &anthropic.ToolParam{
				Name:        tool.Name,
				Description: anthropic.String(tool.Description),
				InputSchema: tool.InputSchema,
			},
		})
	}

	stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
		Model:     anthropic.ModelClaude3_7SonnetLatest,
		MaxTokens: int64(1024),
		Messages:  conversation,
		Tools:     anthropicTools,
	})
	// Release the underlying response on every exit path, including the
	// early return when accumulation fails.
	defer stream.Close()

	// fmt rather than the builtin print: the builtin writes to stderr, which
	// would split the chat transcript across two streams.
	fmt.Print("\u001b[93mClaude\u001b[0m: ")

	message := anthropic.Message{}
	for stream.Next() {
		event := stream.Current()
		if err := message.Accumulate(event); err != nil {
			return nil, err
		}

		// Only text deltas are echoed live; other events are just accumulated.
		if blockDelta, ok := event.AsAny().(anthropic.ContentBlockDeltaEvent); ok {
			if textDelta, ok := blockDelta.Delta.AsAny().(anthropic.TextDelta); ok {
				fmt.Print(textDelta.Text)
			}
		}
	}
	if err := stream.Err(); err != nil {
		return nil, err
	}
	return &message, nil
}
// ToolDefinition describes one tool the model may call, together with the Go
// function that implements it.
type ToolDefinition struct {
// Name is sent to the model and is the key used to find the tool when the
// model requests it.
Name string `json:"name"`
// Description is the natural-language explanation sent to the model.
Description string `json:"description"`
// InputSchema is the schema of the tool's input, sent to the model.
InputSchema anthropic.ToolInputSchemaParam `json:"input_schema"`
// Function receives the raw JSON input for one call and returns the text
// result or an error.
Function func(input json.RawMessage) (string, error)
}
// ReadFileDefinition registers the "read_file" tool, implemented by ReadFile.
var ReadFileDefinition = ToolDefinition{
Name: "read_file",
Description: "Read the contents of a given relative file path. Use this when you want to see what's inside a file. Do not use this with directory names.",
InputSchema: ReadFileInputSchema,
Function: ReadFile,
}
// ReadFileInput is the JSON input accepted by the read_file tool.
type ReadFileInput struct {
// Path is the file to read; it is passed to os.ReadFile unchanged.
Path string `json:"path" jsonschema_description:"The relative path of a file in the working directory."`
}
// ReadFileInputSchema is the input schema derived from ReadFileInput.
var ReadFileInputSchema = GenerateSchema[ReadFileInput]()
// ReadFile implements the read_file tool: it decodes input as a
// ReadFileInput and returns the contents of the named file as a string.
//
// Malformed input is returned as an error rather than a panic, because the
// input is produced by the model and a bad tool call must not crash the
// agent. Errors from reading the file are returned unchanged.
func ReadFile(input json.RawMessage) (string, error) {
	var readFileInput ReadFileInput
	if err := json.Unmarshal(input, &readFileInput); err != nil {
		return "", fmt.Errorf("decoding read_file input: %w", err)
	}

	content, err := os.ReadFile(readFileInput.Path)
	if err != nil {
		return "", err
	}
	return string(content), nil
}
// GenerateSchema reflects the zero value of T into a JSON schema and returns
// a tool input schema carrying that schema's properties. References are
// inlined and additional properties are disallowed during reflection. Only
// the Properties field of the reflected schema is copied over.
func GenerateSchema[T any]() anthropic.ToolInputSchemaParam {
	var zero T
	r := jsonschema.Reflector{
		DoNotReference:            true,
		AllowAdditionalProperties: false,
	}
	return anthropic.ToolInputSchemaParam{Properties: r.Reflect(zero).Properties}
}
// ListFilesDefinition registers the "list_files" tool, implemented by
// ListFiles.
var ListFilesDefinition = ToolDefinition{
Name: "list_files",
Description: "List files and directories at a given path. If no path is provided, lists files in the current directory.",
InputSchema: ListFilesInputSchema,
Function: ListFiles,
}
// ListFilesInput is the JSON input accepted by the list_files tool.
type ListFilesInput struct {
// Path is the directory to list; when empty, ListFiles uses ".".
Path string `json:"path,omitempty" jsonschema_description:"Optional relative path to list files from. Defaults to current directory if not provided."`
}
// ListFilesInputSchema is the input schema derived from ListFilesInput.
var ListFilesInputSchema = GenerateSchema[ListFilesInput]()
// ListFiles implements the list_files tool. It walks the requested directory
// (default ".") recursively and returns a JSON array of paths relative to
// that directory; directories carry a trailing "/".
//
// Any path containing ".devenv" is omitted, and such directories are not
// descended into. An empty listing is encoded as "[]". Malformed input and
// walk errors are returned as errors.
func ListFiles(input json.RawMessage) (string, error) {
	var listFilesInput ListFilesInput
	if err := json.Unmarshal(input, &listFilesInput); err != nil {
		return "", fmt.Errorf("decoding list_files input: %w", err)
	}

	dir := "."
	if listFilesInput.Path != "" {
		dir = listFilesInput.Path
	}

	// Non-nil so an empty directory marshals as [] instead of null.
	files := []string{}

	err := filepath.Walk(dir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}

		// Ignore devenv folder because it blows up the list of files (too
		// many tokens). Prune the directory itself so its contents are never
		// visited; SkipDir must only be returned for directories, otherwise
		// the remaining siblings of a matching file would be skipped too.
		if strings.Contains(path, ".devenv") {
			if info.IsDir() {
				return filepath.SkipDir
			}
			return nil
		}

		relPath, relErr := filepath.Rel(dir, path)
		if relErr != nil {
			return relErr
		}
		if relPath == "." {
			// The root itself is not part of the listing.
			return nil
		}

		if info.IsDir() {
			relPath += "/"
		}
		files = append(files, relPath)
		return nil
	})
	if err != nil {
		return "", err
	}

	result, err := json.Marshal(files)
	if err != nil {
		return "", err
	}
	return string(result), nil
}
// EditFileDefinition registers the "edit_file" tool, implemented by EditFile.
var EditFileDefinition = ToolDefinition{
Name: "edit_file",
Description: `Make edits to a text file.
Replaces 'old_str' with 'new_str' in the given file. 'old_str' and 'new_str' MUST be different from each other.
If the file specified with path doesn't exist, it will be created.
`,
InputSchema: EditFileInputSchema,
Function: EditFile,
}
// EditFileInput is the JSON input accepted by the edit_file tool.
type EditFileInput struct {
// Path is the file to edit; EditFile rejects an empty path.
Path string `json:"path" jsonschema_description:"The path to the file"`
// OldStr is the text to search for.
// NOTE(review): the schema description asks for exactly one match, but
// EditFile replaces every occurrence it finds.
OldStr string `json:"old_str" jsonschema_description:"Text to search for - must match exactly and must only have one match exactly"`
// NewStr is the replacement text; EditFile rejects input where it equals
// OldStr.
NewStr string `json:"new_str" jsonschema_description:"Text to replace old_str with"`
}
// EditFileInputSchema is the input schema derived from EditFileInput.
var EditFileInputSchema = GenerateSchema[EditFileInput]()
// EditFile implements the edit_file tool.
//
// It decodes input as an EditFileInput and replaces every occurrence of
// old_str with new_str in the file at path, returning "OK" on success.
// If the file does not exist and old_str is empty, the file is created with
// new_str as its content via createNewFile.
//
// Errors are returned when: the input is malformed; path is empty or old_str
// equals new_str; the file cannot be read or written; old_str is empty while
// the existing file has content; or old_str does not occur in the file.
func EditFile(input json.RawMessage) (string, error) {
	var editFileInput EditFileInput
	if err := json.Unmarshal(input, &editFileInput); err != nil {
		return "", err
	}

	if editFileInput.Path == "" || editFileInput.OldStr == editFileInput.NewStr {
		return "", errors.New("invalid input parameters")
	}

	content, err := os.ReadFile(editFileInput.Path)
	if err != nil {
		if os.IsNotExist(err) && editFileInput.OldStr == "" {
			return createNewFile(editFileInput.Path, editFileInput.NewStr)
		}
		return "", err
	}

	// Replacing the empty string matches at every rune boundary, which would
	// scatter new_str through the whole file. Only allow it when the file is
	// empty, where it simply sets the content.
	if editFileInput.OldStr == "" && len(content) > 0 {
		return "", errors.New("old_str must not be empty when the file already has content")
	}

	oldContent := string(content)
	if !strings.Contains(oldContent, editFileInput.OldStr) {
		return "", errors.New("old_str not found in file")
	}

	newContent := strings.ReplaceAll(oldContent, editFileInput.OldStr, editFileInput.NewStr)
	if err := os.WriteFile(editFileInput.Path, []byte(newContent), 0644); err != nil {
		return "", err
	}
	return "OK", nil
}
// createNewFile writes content to filePath, first creating any missing
// parent directories, and returns a confirmation message naming the file.
//
// NOTE: the parent is computed with path.Dir, which splits on forward
// slashes only.
func createNewFile(filePath, content string) (string, error) {
	if parent := path.Dir(filePath); parent != "." {
		if mkErr := os.MkdirAll(parent, 0755); mkErr != nil {
			return "", fmt.Errorf("failed to create directory: %w", mkErr)
		}
	}

	if writeErr := os.WriteFile(filePath, []byte(content), 0644); writeErr != nil {
		return "", fmt.Errorf("failed to create file: %w", writeErr)
	}

	return "Successfully created file " + filePath, nil
}
// Base64EncodeFileDefinition registers the "base64_encode" tool, implemented
// by Base64EncodeFile.
var Base64EncodeFileDefinition = ToolDefinition{
Name: "base64_encode",
Description: `Generates a base64 encoding of a file.
This is especially useful when asked to describe an image file (you can use
this get a base64 encoded representation of the image file).
`,
InputSchema: Base64EncodeFileInputSchema,
Function: Base64EncodeFile,
}
// Base64EncodeFileInput is the JSON input accepted by the base64_encode tool.
type Base64EncodeFileInput struct {
// Path is the file to encode; it is passed to os.ReadFile unchanged.
Path string `json:"path" jsonschema_description:"The path to the image"`
}
// Base64EncodeFileInputSchema is the input schema for the base64_encode tool.
// It is derived from Base64EncodeFileInput so the model is offered only the
// "path" field (it was previously generated from EditFileInput, which
// wrongly advertised old_str and new_str for this tool).
var Base64EncodeFileInputSchema = GenerateSchema[Base64EncodeFileInput]()
// Base64EncodeFile implements the base64_encode tool: it decodes input as a
// Base64EncodeFileInput, reads the named file and returns its contents in
// standard base64 encoding.
//
// Malformed input is returned as an error rather than a panic, because the
// input is produced by the model and a bad tool call must not crash the
// agent. Errors from reading the file are returned unchanged.
func Base64EncodeFile(input json.RawMessage) (string, error) {
	var encodeInput Base64EncodeFileInput
	if err := json.Unmarshal(input, &encodeInput); err != nil {
		return "", fmt.Errorf("decoding base64_encode input: %w", err)
	}

	content, err := os.ReadFile(encodeInput.Path)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(content), nil
}
// WebcamDefinition registers the "webcam" tool, implemented by Webcam.
var WebcamDefinition = ToolDefinition{
Name: "webcam",
Description: `Take a picture using the computer's webcam.
This way you can see what the user sees and provide a description of what
you see.
`,
InputSchema: WebcamDefinitionInputSchema,
Function: Webcam,
}
// WebcamDefinitionInput is the (empty) input of the webcam tool; the tool
// takes no parameters.
type WebcamDefinitionInput struct{}
// WebcamDefinitionInputSchema is the input schema derived from
// WebcamDefinitionInput.
var WebcamDefinitionInputSchema = GenerateSchema[WebcamDefinitionInput]()
// Webcam implements the webcam tool: it opens video capture device 0,
// requests a 1280x720 frame size, grabs a single frame, encodes it as JPEG
// and returns the JPEG bytes in standard base64 encoding.
//
// The input argument is ignored; the tool takes no parameters. An error is
// returned if the device cannot be opened or read, if the captured frame is
// empty, or if JPEG encoding fails.
func Webcam(input json.RawMessage) (string, error) {
	webcam, err := gocv.OpenVideoCapture(0)
	if err != nil {
		return "", err
	}
	defer webcam.Close()

	webcam.Set(gocv.VideoCaptureFrameWidth, 1280.0)
	webcam.Set(gocv.VideoCaptureFrameHeight, 720.0)

	if !webcam.IsOpened() {
		return "", errors.New("Unable to open video capture device")
	}

	img := gocv.NewMat()
	defer img.Close()

	if ok := webcam.Read(&img); !ok {
		return "", errors.New("Cannot read from video capture device")
	}
	if img.Empty() {
		return "", errors.New("Capture image is empty")
	}

	jpegData, err := gocv.IMEncode(".jpg", img)
	if err != nil {
		return "", err
	}
	// The encoded buffer is natively allocated and must be released
	// explicitly. EncodeToString copies the bytes into a Go string before
	// the deferred Close runs.
	defer jpegData.Close()

	return base64.StdEncoding.EncodeToString(jpegData.GetBytes()), nil
}