package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sync"
	"time"
)

const (
	// userAgent is the User-Agent header value sent on token-exchange requests.
	userAgent = "curl/8.7.1"

	// tokenExchangeURL is the GitHub endpoint that trades the GitHub auth token
	// for a Copilot API token carrying an expiry timestamp.
	tokenExchangeURL = "https://api.github.com/copilot_internal/v2/token"

	// defaultOpenAIEndpoint is the base URL that proxied requests are sent to;
	// each incoming request's path and query are appended to it.
	defaultOpenAIEndpoint = "https://api.githubcopilot.com"
)

// TokenResponse holds the fields the proxy decodes from the JSON body returned
// by the token-exchange endpoint (tokenExchangeURL).
//
// Token is the Copilot API token; it is sent upstream as "Bearer <Token>".
// ExpiresAt is the token's expiry, interpreted as Unix seconds
// (converted with time.Unix(ExpiresAt, 0)).
type TokenResponse struct {
	Token     string `json:"token"`
	ExpiresAt int64  `json:"expires_at"`
}

// ProxyServer forwards HTTP requests to the Copilot API, authenticating them
// with a Copilot token obtained by exchanging a GitHub auth token.
//
// It embeds a sync.RWMutex, so it must not be copied after first use; handle it
// through the *ProxyServer that NewProxyServer returns.
type ProxyServer struct {
	// authToken is the GitHub token presented to the token-exchange endpoint.
	authToken string
	// copilotToken is the cached Copilot bearer token ("" until first refresh).
	// Written under mutex by refreshTokenIfNeeded; readers should hold mutex.RLock.
	copilotToken string
	// tokenExpiry is when copilotToken expires. Same locking as copilotToken.
	tokenExpiry time.Time
	// proxyEndpoint is the base URL proxied requests are sent to. Same locking
	// as copilotToken.
	proxyEndpoint string
	// mutex guards copilotToken, tokenExpiry and proxyEndpoint.
	mutex sync.RWMutex
	// client is shared by the token exchange and the proxied requests.
	client *http.Client
}

// NewProxyServer returns a ProxyServer that exchanges authToken for Copilot
// API tokens on demand. No token is fetched here; the cached token starts out
// empty, so the first handled request triggers the exchange.
//
// All outbound calls (token exchange and proxied requests) share a single
// http.Client with a 30-second Timeout. Per the net/http documentation, that
// limit covers the whole exchange, including reading the response body.
// NOTE(review): long-running proxied responses are therefore cut off after 30s
// if the upstream streams its output — confirm this is intended.
func NewProxyServer(authToken string) *ProxyServer {
	const outboundTimeout = 30 * time.Second

	srv := new(ProxyServer)
	srv.authToken = authToken
	srv.client = &http.Client{Timeout: outboundTimeout}
	return srv
}

// tokenRefreshMargin is how long before its reported expiry a cached Copilot
// token stops being reused. This keeps a request from starting with a token
// that lapses while the request is still in flight.
const tokenRefreshMargin = time.Minute

// refreshTokenIfNeeded ensures p caches a Copilot token that stays valid for at
// least tokenRefreshMargin. When the cached token is missing or close to expiry,
// it exchanges p.authToken for a new one at tokenExchangeURL.
//
// The fast path checks the cache under the read lock. On a miss it takes the
// write lock and checks again, so concurrent callers that all saw a stale token
// perform a single exchange. The write lock is held for the whole HTTP call, so
// other callers block until the call finishes.
//
// On success, copilotToken, tokenExpiry and proxyEndpoint are updated under the
// write lock. An error is returned if the request cannot be built or sent, if
// the endpoint answers with a non-200 status, or if the body cannot be read,
// cannot be parsed, or carries no token. On error the cached fields are left
// unchanged.
func (p *ProxyServer) refreshTokenIfNeeded() error {
	p.mutex.RLock()
	tokenValid := p.tokenValidLocked()
	p.mutex.RUnlock()

	if tokenValid {
		return nil
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Re-check: another goroutine may have refreshed the token while this one
	// was waiting for the write lock.
	if p.tokenValidLocked() {
		return nil
	}

	req, err := http.NewRequest(http.MethodGet, tokenExchangeURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create token exchange request: %w", err)
	}

	req.Header.Set("Authorization", fmt.Sprintf("token %s", p.authToken))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", userAgent)

	resp, err := p.client.Do(req)
	if err != nil {
		return fmt.Errorf("token exchange request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort diagnostic: include only a bounded prefix of the error
		// body, and ignore read errors since the status already explains the
		// failure.
		const maxErrorBody = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read token response body: %w", err)
	}

	var tokenResp TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return fmt.Errorf("failed to parse token response: %w", err)
	}
	// A 200 without a token would otherwise be cached and forwarded upstream as
	// an empty bearer credential.
	if tokenResp.Token == "" {
		return fmt.Errorf("token exchange response contained no token")
	}

	p.copilotToken = tokenResp.Token
	p.tokenExpiry = time.Unix(tokenResp.ExpiresAt, 0)
	p.proxyEndpoint = defaultOpenAIEndpoint

	log.Printf("token refreshed, valid until: %v", p.tokenExpiry)
	return nil
}

// tokenValidLocked reports whether a Copilot token is cached and will not
// expire within tokenRefreshMargin. The caller must hold p.mutex (read or
// write).
func (p *ProxyServer) tokenValidLocked() bool {
	return p.copilotToken != "" && time.Now().Add(tokenRefreshMargin).Before(p.tokenExpiry)
}

// proxyHandler relays an incoming request to the Copilot API and streams the
// upstream response back to the caller. It proceeds as follows:
//
//   - Ensure a usable Copilot token (refreshTokenIfNeeded).
//   - Read the whole request body.
//   - Build the target URL from the cached proxy endpoint plus the request's
//     escaped path and raw query.
//   - Send the request upstream with the Copilot bearer token, the integration
//     headers, and the client's Content-Type and Accept headers.
//   - Copy the upstream status, end-to-end headers and body to w, flushing
//     after every chunk.
//
// Failures before an upstream response exists produce 500 (token refresh or
// request construction), 400 (unreadable request body) or 502 (upstream call
// failed). Errors while copying the response body are only logged, because
// the status line has already been sent.
func (p *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) {
	if err := p.refreshTokenIfNeeded(); err != nil {
		log.Printf("error refreshing token: %v", err)
		http.Error(w, "failed to authenticate with copilot", http.StatusInternalServerError)
		return
	}

	// refreshTokenIfNeeded writes these fields under p.mutex. Read them under
	// the read lock so a concurrent refresh is not a data race.
	p.mutex.RLock()
	token, endpoint := p.copilotToken, p.proxyEndpoint
	p.mutex.RUnlock()

	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "error reading request body", http.StatusBadRequest)
		return
	}

	if endpoint == "" {
		log.Printf("warning: proxy endpoint is empty, using default openai endpoint")
		endpoint = defaultOpenAIEndpoint
	}
	// EscapedPath keeps percent-encoded octets (e.g. %2F, %3F) as the client
	// sent them. r.URL.Path is decoded, and an embedded '?' or '/' would
	// change the meaning of the upstream URL.
	targetURL := endpoint + r.URL.EscapedPath()
	if r.URL.RawQuery != "" {
		targetURL += "?" + r.URL.RawQuery
	}

	log.Printf("proxying request to: %s", targetURL)
	// Bind the upstream call to the client's request context so the call is
	// cancelled if the client's connection closes.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, bytes.NewReader(body))
	if err != nil {
		http.Error(w, "error creating proxy request", http.StatusInternalServerError)
		return
	}

	// Carry over how the client described its payload and what it accepts.
	// Otherwise the upstream receives the body with no media type.
	for _, name := range []string{"Content-Type", "Accept"} {
		for _, value := range r.Header.Values(name) {
			proxyReq.Header.Add(name, value)
		}
	}
	proxyReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
	proxyReq.Header.Set("Copilot-Integration-Id", "vscode-chat")
	proxyReq.Header.Set("Editor-Version", "Neovim/0.6.1")

	resp, err := p.client.Do(proxyReq)
	if err != nil {
		log.Printf("error forwarding request to copilot: %v", err)
		http.Error(w, "error forwarding request to copilot", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	for name, values := range resp.Header {
		// Hop-by-hop headers describe the upstream connection, not this one.
		if isHopByHopHeader(name) {
			continue
		}
		for _, value := range values {
			w.Header().Add(name, value)
		}
	}

	w.WriteHeader(resp.StatusCode)

	if err := copyFlushing(w, resp.Body); err != nil {
		log.Printf("error copying response body: %v", err)
	}
}

// isHopByHopHeader reports whether name is a hop-by-hop header (RFC 7230
// §6.1). Such headers apply only to a single connection and must not be
// relayed by a proxy.
func isHopByHopHeader(name string) bool {
	switch http.CanonicalHeaderKey(name) {
	case "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
		"Proxy-Connection", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
		return true
	}
	return false
}

// copyFlushing copies src to w until EOF. When w implements http.Flusher, it
// flushes after every chunk, so incremental upstream output (e.g. server-sent
// events) reaches the client as it arrives. Without this, the output would
// wait until the response ends or a write buffer fills. It returns the first
// read or write error, or nil at EOF.
func copyFlushing(w http.ResponseWriter, src io.Reader) error {
	flusher, canFlush := w.(http.Flusher)
	buf := make([]byte, 32*1024)
	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			if _, err := w.Write(buf[:n]); err != nil {
				return err
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

// main configures the proxy from the environment and serves it until the
// listener fails.
//
// Environment:
//
//	GITHUB_AUTH_TOKEN  required; GitHub token exchanged for Copilot tokens
//	PORT               optional; TCP port to listen on (default "8080")
func main() {
	authToken := os.Getenv("GITHUB_AUTH_TOKEN")
	if authToken == "" {
		log.Fatal("GITHUB_AUTH_TOKEN environment variable is required")
	}

	proxy := NewProxyServer(authToken)

	http.HandleFunc("/", proxy.proxyHandler)

	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	// http.ListenAndServe uses a Server with no timeouts at all, so a client
	// that trickles its request headers can hold a connection open forever.
	// Bound only the header phase here. A whole-request Read/WriteTimeout
	// would also cut off long upstream responses that proxyHandler copies to
	// the client.
	srv := &http.Server{
		Addr:              ":" + port,
		Handler:           nil, // nil selects http.DefaultServeMux, where the handler is registered above
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("starting proxy server on port %s", port)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("failed to start server: %v", err)
	}
}