171 lines
4.1 KiB
Go
171 lines
4.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Upstream locations and client identification used by the proxy.
const (
	// userAgent is sent as the User-Agent header on the token-exchange request.
	userAgent = "curl/8.7.1"

	// tokenExchangeURL is the GitHub endpoint that accepts the configured
	// GitHub auth token and returns a Copilot token (see TokenResponse).
	tokenExchangeURL = "https://api.github.com/copilot_internal/v2/token"

	// defaultOpenAIEndpoint is the base URL that proxied requests are
	// forwarded to; the incoming request path is appended to it.
	defaultOpenAIEndpoint = "https://api.githubcopilot.com"
)
|
|
|
|
// TokenResponse is the JSON payload decoded from the token-exchange
// endpoint. Any other fields present in the response are ignored when
// decoding.
type TokenResponse struct {
	// Token is the Copilot credential later sent upstream as a Bearer token.
	Token string `json:"token"`

	// ExpiresAt is the token's expiry expressed as Unix seconds.
	ExpiresAt int64 `json:"expires_at"`
}
|
|
|
|
type ProxyServer struct {
|
|
authToken string
|
|
copilotToken string
|
|
tokenExpiry time.Time
|
|
proxyEndpoint string
|
|
mutex sync.RWMutex
|
|
client *http.Client
|
|
}
|
|
|
|
func NewProxyServer(authToken string) *ProxyServer {
|
|
return &ProxyServer{
|
|
authToken: authToken,
|
|
client: &http.Client{Timeout: 30 * time.Second},
|
|
}
|
|
}
|
|
|
|
func (p *ProxyServer) refreshTokenIfNeeded() error {
|
|
p.mutex.RLock()
|
|
tokenValid := p.copilotToken != "" && time.Now().Before(p.tokenExpiry)
|
|
p.mutex.RUnlock()
|
|
|
|
if tokenValid {
|
|
return nil
|
|
}
|
|
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
|
|
if p.copilotToken != "" && time.Now().Before(p.tokenExpiry) {
|
|
return nil
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", tokenExchangeURL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create token exchange request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", p.authToken))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("User-Agent", userAgent)
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("token exchange request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tokenResp TokenResponse
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read token response body: %w", err)
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return fmt.Errorf("failed to parse token response: %w", err)
|
|
}
|
|
|
|
p.copilotToken = tokenResp.Token
|
|
p.tokenExpiry = time.Unix(tokenResp.ExpiresAt, 0)
|
|
p.proxyEndpoint = defaultOpenAIEndpoint
|
|
|
|
log.Printf("token refreshed, valid until: %v", p.tokenExpiry)
|
|
return nil
|
|
}
|
|
|
|
// proxyHandler forwards the incoming request to the Copilot API and relays
// the upstream response back to the caller.
//
// Behavior:
//   - The method, path, query and body are preserved.
//   - Content-Type and Accept are carried over from the client.
//   - Authorization is replaced with the current Copilot token, and the
//     Copilot integration headers are added.
//   - The upstream status, headers and body are relayed back. The body is
//     flushed as it arrives so streamed responses reach the client
//     incrementally.
func (p *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) {
	if err := p.refreshTokenIfNeeded(); err != nil {
		log.Printf("error refreshing token: %v", err)
		http.Error(w, "failed to authenticate with copilot", http.StatusInternalServerError)
		return
	}

	// Snapshot the shared credentials under the read lock.
	// refreshTokenIfNeeded rewrites them under the write lock on behalf of
	// other concurrent requests, so unsynchronized reads here would race.
	p.mutex.RLock()
	token := p.copilotToken
	endpoint := p.proxyEndpoint
	p.mutex.RUnlock()

	// The server closes r.Body itself once the handler returns.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "error reading request body", http.StatusBadRequest)
		return
	}

	if endpoint == "" {
		log.Printf("warning: proxy endpoint is empty, using default openai endpoint")
		endpoint = defaultOpenAIEndpoint
	}
	targetURL := endpoint + r.URL.Path
	if r.URL.RawQuery != "" {
		targetURL += "?" + r.URL.RawQuery
	}

	log.Printf("proxying request to: %s", targetURL)
	// Tie the upstream call to the client's request context, so a client
	// that disconnects also cancels the upstream request.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, bytes.NewReader(body))
	if err != nil {
		http.Error(w, "error creating proxy request", http.StatusInternalServerError)
		return
	}

	// Without these, a JSON POST would reach upstream with no Content-Type.
	for _, name := range []string{"Content-Type", "Accept"} {
		if v := r.Header.Get(name); v != "" {
			proxyReq.Header.Set(name, v)
		}
	}
	proxyReq.Header.Set("Authorization", "Bearer "+token)
	proxyReq.Header.Set("Copilot-Integration-Id", "vscode-chat")
	proxyReq.Header.Set("Editor-Version", "Neovim/0.6.1")

	resp, err := p.client.Do(proxyReq)
	if err != nil {
		log.Printf("error forwarding request to copilot: %v", err)
		http.Error(w, "error forwarding request to copilot", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	for name, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(name, value)
		}
	}

	w.WriteHeader(resp.StatusCode)

	if err := copyAndFlush(w, resp.Body); err != nil {
		log.Printf("error copying response body: %v", err)
	}
}

// copyAndFlush copies src to w until EOF. When w implements http.Flusher,
// each chunk is flushed as soon as it has been written. This lets
// incremental responses (for example server-sent events) reach the client
// as they arrive instead of sitting in the server's write buffer. It returns
// the first read or write error encountered, or nil at EOF.
func copyAndFlush(w http.ResponseWriter, src io.Reader) error {
	flusher, ok := w.(http.Flusher)
	if !ok {
		_, err := io.Copy(w, src)
		return err
	}

	buf := make([]byte, 32*1024)
	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			if _, err := w.Write(buf[:n]); err != nil {
				return err
			}
			flusher.Flush()
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}
|
|
|
|
// main starts the proxy. It reads the GitHub token from GITHUB_AUTH_TOKEN
// (required) and the listen port from PORT (default 8080), then serves every
// path through ProxyServer.proxyHandler.
func main() {
	authToken := os.Getenv("GITHUB_AUTH_TOKEN")
	if authToken == "" {
		log.Fatal("GITHUB_AUTH_TOKEN environment variable is required")
	}

	proxy := NewProxyServer(authToken)

	http.HandleFunc("/", proxy.proxyHandler)

	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	// http.ListenAndServe has no timeouts at all, so a client that sends its
	// headers arbitrarily slowly can hold a connection open forever.
	// ReadHeaderTimeout bounds that phase. No WriteTimeout is set, because
	// it would cut off long streamed upstream responses. A nil Handler keeps
	// serving http.DefaultServeMux, where the route above is registered.
	srv := &http.Server{
		Addr:              ":" + port,
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("starting proxy server on port %s", port)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("failed to start server: %v", err)
	}
}
|