package main import ( "bytes" "encoding/json" "io" "net/http" "net/http/httptest" "strings" "testing" "time" ) var mockTokenResponse = TokenResponse{ Token: "mock-token", ExpiresAt: time.Now().Add(1 * time.Hour).Unix(), } func TestNewProxyServer(t *testing.T) { authToken := "test-auth-token" proxy := NewProxyServer(authToken) if proxy.authToken != authToken { t.Errorf("Expected authToken to be %s, got %s", authToken, proxy.authToken) } if proxy.client == nil { t.Error("Expected http.Client to be initialized") } } func TestRefreshTokenIfNeeded_ValidToken(t *testing.T) { proxy := NewProxyServer("test-auth-token") proxy.proxyEndpoint = "http://mock-endpoint.com" proxy.copilotToken = mockTokenResponse.Token proxy.client = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) *http.Response { expectedURL := proxy.proxyEndpoint actualURL := req.URL.String() // Normalize both URLs by trimming trailing slashes expectedURL = strings.TrimRight(expectedURL, "/") actualURL = strings.TrimRight(actualURL, "/") if actualURL != expectedURL { t.Errorf("Expected URL %s, got %s", expectedURL, actualURL) } return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString("mock-response")), Header: make(http.Header), } }), } proxy.tokenExpiry = time.Now().Add(1 * time.Hour) err := proxy.refreshTokenIfNeeded() if err != nil { t.Errorf("Expected no error, got %v", err) } } func TestRefreshTokenIfNeeded_ExpiredToken(t *testing.T) { proxy := NewProxyServer("test-auth-token") proxy.copilotToken = "" proxy.tokenExpiry = time.Now().Add(-1 * time.Hour) proxy.client = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) *http.Response { if req.URL.String() != tokenExchangeURL { t.Errorf("Expected URL %s, got %s", tokenExchangeURL, req.URL.String()) } respBody, _ := json.Marshal(mockTokenResponse) return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(respBody)), Header: make(http.Header), } }), } err := proxy.refreshTokenIfNeeded() if err != nil { t.Errorf("Expected no error, got %v", err) } if proxy.copilotToken != mockTokenResponse.Token { t.Errorf("Expected token to be %s, got %s", mockTokenResponse.Token, proxy.copilotToken) } } func TestProxyHandler(t *testing.T) { proxy := NewProxyServer("test-auth-token") proxy.proxyEndpoint = "http://mock-endpoint.com" proxy.copilotToken = mockTokenResponse.Token proxy.tokenExpiry = time.Now().Add(1 * time.Hour) proxy.client = &http.Client{ Transport: roundTripperFunc(func(req *http.Request) *http.Response { expectedURL := strings.TrimRight(proxy.proxyEndpoint, "/") actualURL := strings.TrimRight(req.URL.String(), "/") if actualURL != expectedURL { t.Errorf("Expected URL %s, got %s", expectedURL, actualURL) } return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString("mock-response")), Header: make(http.Header), } }), } req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test-body")) w := httptest.NewRecorder() proxy.proxyHandler(w, req) resp := w.Result() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } } type roundTripperFunc func(req *http.Request) *http.Response func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil }