| author | mo khan <mo@mokhan.ca> | 2025-06-08 00:05:46 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-08 00:05:46 -0600 |
| commit | 7bc2e617da68946bb87b18bdf5fbfe5000061671 | |
| tree | 10d40c46d7be6899697eb067415a0d6fd1c4ad0c | |
| parent | 1f78eec63681c6f20b1158947caecffe94cd848f | |
add shim for tools support (main)
| -rw-r--r-- | cmd/claude-proxy/main.go | 221 |
1 file changed, 136 insertions(+), 85 deletions(-)
diff --git a/cmd/claude-proxy/main.go b/cmd/claude-proxy/main.go
index 1741de7..765df94 100644
--- a/cmd/claude-proxy/main.go
+++ b/cmd/claude-proxy/main.go
@@ -3,113 +3,164 @@ package main
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"io"
 	"log"
 	"net/http"
+	"os/exec"
+	"strings"
+	"time"
 
-	"github.com/gin-gonic/gin"
 	"github.com/xlgmokha/x/pkg/env"
 )
 
-const ollamaURL = "http://localhost:11434/api/generate"
-
-type AnthropicMsg struct {
+type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
 }
-type AnthropicRequest struct {
-	Model       string         `json:"model"`
-	MaxTokens   int            `json:"max_tokens"`
-	Messages    []AnthropicMsg `json:"messages"`
-	Stream      bool           `json:"stream"`
-	Temperature float64        `json:"temperature,omitempty"`
+
+type Tool struct {
+	Name        string                 `json:"name"`
+	Description string                 `json:"description"`
+	InputSchema map[string]interface{} `json:"input_schema"`
+}
+
+type RequestBody struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Stream   bool      `json:"stream"`
+	Tools    []Tool    `json:"tools"`
 }
+
 type StreamChunk struct {
-	Content []struct {
-		Type string `json:"type"`
-		Text string `json:"text"`
-	} `json:"content"`
-	StopReason string `json:"stop_reason"`
+	Type    string `json:"type"`
+	Content string `json:"content"`
 }
 
-type OllamaRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
-	Stream bool   `json:"stream"`
+func executeShell(command string) string {
+	cmd := exec.Command("bash", "-c", command)
+	var out bytes.Buffer
+	cmd.Stdout = &out
+	cmd.Stderr = &out
+	err := cmd.Run()
+	if err != nil {
+		return fmt.Sprintf("Shell Error: %v\n%s", err, out.String())
+	}
+	return out.String()
 }
 
-func main() {
-	r := gin.Default()
+func runCodeSnippet(language, code string) string {
+	// Only Python supported for now
+	if language != "python" {
+		return "Only Python is supported."
+	}
+	cmd := exec.Command("python3", "-c", code)
+	var out bytes.Buffer
+	cmd.Stdout = &out
+	cmd.Stderr = &out
+	err := cmd.Run()
+	if err != nil {
+		return fmt.Sprintf("Code Error: %v\n%s", err, out.String())
+	}
+	return out.String()
+}
 
-	r.POST("/v1/messages", func(c *gin.Context) {
-		if c.GetHeader("anthropic-version") == "" {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "missing anthropic-version header"})
-			return
-		}
+func mockSearch(query string) string {
+	return fmt.Sprintf("[Mocked Web Search] Query: %s", query)
+}
 
-		var req AnthropicRequest
-		if err := c.ShouldBindJSON(&req); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON"})
-			return
-		}
+func parseFunctionCall(content string) (string, map[string]interface{}) {
+	start := strings.Index(content, "<function_call")
+	if start == -1 {
+		return "", nil
+	}
+	end := strings.Index(content, ">")
+	if end == -1 {
+		return "", nil
+	}
 
-		// Combine user messages into final prompt
-		fullPrompt := ""
-		for _, msg := range req.Messages {
-			if msg.Role == "user" {
-				fullPrompt += msg.Content + "\n"
-			}
-		}
+	tag := content[start : end+1]
+	nameStart := strings.Index(tag, "name=\"") + 6
+	nameEnd := strings.Index(tag[nameStart:], "\"") + nameStart
+	name := tag[nameStart:nameEnd]
 
-		ollamaReq := OllamaRequest{
-			Model:  req.Model,
-			Prompt: fullPrompt,
-			Stream: req.Stream,
-		}
+	argsStart := strings.Index(tag, "arguments='") + 10
+	argsEnd := strings.LastIndex(tag, "'")
+	argsJSON := tag[argsStart:argsEnd]
 
-		buf, _ := json.Marshal(ollamaReq)
-		resp, err := http.Post(ollamaURL, "application/json", bytes.NewReader(buf))
-		if err != nil {
-			c.JSON(http.StatusBadGateway, gin.H{"error": "failed contacting ollama"})
-			return
-		}
-		defer resp.Body.Close()
-
-		// Prepare chunked JSON for Claude
-		c.Header("Content-Type", "application/json")
-		c.Status(http.StatusOK)
-		encoder := json.NewEncoder(c.Writer)
-
-		if req.Stream {
-			decoder := json.NewDecoder(resp.Body)
-			for {
-				var chunk struct{ Response string }
-				if err := decoder.Decode(&chunk); err != nil {
-					break
-				}
-				sc := StreamChunk{
-					Content: []struct {
-						Type string `json:"type"`
-						Text string `json:"text"`
-					}{{Type: "text", Text: chunk.Response}},
-				}
-				encoder.Encode(sc)
-				c.Writer.Flush()
-			}
+	var args map[string]interface{}
+	_ = json.Unmarshal([]byte(argsJSON), &args)
+	return name, args
+}
+
+func encodeChunk(w io.Writer, chunkType, content string) {
+	chunk := StreamChunk{Type: chunkType, Content: content}
+	json.NewEncoder(w).Encode(chunk)
+	fmt.Fprint(w, "\n\n")
+}
+
+func handler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost || r.URL.Path != "/v1/messages" {
+		http.NotFound(w, r)
+		return
+	}
+
+	var req RequestBody
+	decoder := json.NewDecoder(r.Body)
+	if err := decoder.Decode(&req); err != nil {
+		http.Error(w, "Invalid JSON", http.StatusBadRequest)
+		return
+	}
+
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
+		return
+	}
+
+	userMsg := req.Messages[len(req.Messages)-1].Content
+	name, args := parseFunctionCall(userMsg)
+	result := ""
+
+	switch name {
+	case "bash":
+		result = executeShell(args["command"].(string))
+	case "code_execution":
+		result = runCodeSnippet(args["language"].(string), args["code"].(string))
+	case "computer_use":
+		action := args["action"].(string)
+		path := args["path"].(string)
+		if action == "list" {
+			result = executeShell("ls -l " + path)
 		} else {
-			bodyBytes, _ := io.ReadAll(resp.Body)
-			var all struct{ Response string }
-			json.Unmarshal(bodyBytes, &all)
-			sc := StreamChunk{
-				Content: []struct {
-					Type string `json:"type"`
-					Text string `json:"text"`
-				}{{Type: "text", Text: all.Response}},
-			}
-			encoder.Encode(sc)
+			result = executeShell("cat " + path)
 		}
-	})
+	case "text_editor":
+		result = fmt.Sprintf("[Mock] Applied edit '%s' to %s", args["edit"], args["path"])
+	case "web_search_20250305":
+		result = mockSearch(args["query"].(string))
+	default:
+		result = "Tool not supported or missing function_call"
+	}
+
+	encodeChunk(w, "tool_result", result)
+	flusher.Flush()
 
-	bindAddr := env.Fetch("BIND_ADDR", ":http")
-	log.Fatal(r.Run(bindAddr))
+	// Simulated final model reply
+	encodeChunk(w, "content_block", "Here's the result of your tool request.")
+	flusher.Flush()
+	time.Sleep(1 * time.Second)
+	fmt.Fprint(w, "event: done\ndata: [DONE]\n\n")
+	flusher.Flush()
+}
+
+func main() {
+	http.HandleFunc("/v1/messages", handler)
+	bindAddr := env.Fetch("BIND_ADDR", ":8080")
+	fmt.Println("Claude-compatible proxy running on", bindAddr)
+	log.Fatal(http.ListenAndServe(bindAddr, nil))
 }
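A minimal client sketch for exercising the new shim (not part of the commit; the address assumes the default BIND_ADDR of ":8080", and the model name is a placeholder). It POSTs a RequestBody-shaped payload to /v1/messages whose last message embeds the <function_call> tag that parseFunctionCall scans for, then prints whatever chunks the handler streams back:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Last message carries the <function_call> tag that the proxy's
	// parseFunctionCall scans for; the tool name and arguments follow
	// the shape handled in the handler's switch statement.
	payload := map[string]interface{}{
		"model":  "example-model", // placeholder; the shim does not inspect it
		"stream": true,
		"messages": []map[string]string{
			{"role": "user", "content": `<function_call name="bash" arguments='{"command":"uname -a"}'>`},
		},
	}
	body, _ := json.Marshal(payload)

	// Assumes the proxy is listening on the default BIND_ADDR (":8080").
	resp, err := http.Post("http://localhost:8080/v1/messages", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler writes newline-delimited StreamChunk JSON objects
	// ({"type":...,"content":...}) followed by an "event: done" trailer;
	// print each non-empty line as it arrives.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if line := scanner.Text(); line != "" {
			fmt.Println(line)
		}
	}
}

Note that the shim replies with newline-delimited StreamChunk JSON plus an "event: done" trailer rather than full Anthropic-style SSE events, so a line-oriented scanner is sufficient here.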
