copilot-proxy/main_test.go

128 lines
3.5 KiB
Go

package main
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// mockTokenResponse is the canned token-exchange payload shared by the tests
// below. ExpiresAt is evaluated once, when the test binary initializes, and
// lies one hour after that moment.
var mockTokenResponse = TokenResponse{
	ExpiresAt: time.Now().Add(time.Hour).Unix(),
	Token:     "mock-token",
}
// TestNewProxyServer verifies that the constructor keeps the supplied auth
// token and leaves the server with a non-nil HTTP client.
func TestNewProxyServer(t *testing.T) {
	const token = "test-auth-token"
	srv := NewProxyServer(token)

	if got := srv.authToken; got != token {
		t.Errorf("Expected authToken to be %s, got %s", token, got)
	}
	if srv.client == nil {
		t.Error("Expected http.Client to be initialized")
	}
}
// TestRefreshTokenIfNeeded_ValidToken checks that refreshTokenIfNeeded does
// nothing while the cached Copilot token still has an hour left. It must
// return no error, must not send any HTTP request, and must leave the cached
// token untouched.
func TestRefreshTokenIfNeeded_ValidToken(t *testing.T) {
	proxy := NewProxyServer("test-auth-token")
	proxy.proxyEndpoint = "http://mock-endpoint.com"
	proxy.copilotToken = mockTokenResponse.Token
	proxy.tokenExpiry = time.Now().Add(1 * time.Hour)
	proxy.client = &http.Client{
		// Any outbound request here means a refresh was attempted for a
		// still-valid token, so report it and answer with an empty error
		// response instead of a plausible-looking success.
		Transport: roundTripperFunc(func(req *http.Request) *http.Response {
			t.Errorf("unexpected request to %s while token is still valid", req.URL)
			return &http.Response{
				StatusCode: http.StatusInternalServerError,
				Body:       http.NoBody,
				Header:     make(http.Header),
			}
		}),
	}

	if err := proxy.refreshTokenIfNeeded(); err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if proxy.copilotToken != mockTokenResponse.Token {
		t.Errorf("Expected token to remain %s, got %s", mockTokenResponse.Token, proxy.copilotToken)
	}
}
// TestRefreshTokenIfNeeded_ExpiredToken checks that an expired (and empty)
// cached token triggers a request to tokenExchangeURL. The token from the
// exchange response must then be stored on the proxy.
func TestRefreshTokenIfNeeded_ExpiredToken(t *testing.T) {
	// Encode the canned exchange reply once, up front. An encoding failure
	// then aborts the test here instead of silently producing an empty body
	// that only surfaces later as a confusing decode error.
	respBody, err := json.Marshal(mockTokenResponse)
	if err != nil {
		t.Fatalf("marshal mock token response: %v", err)
	}

	proxy := NewProxyServer("test-auth-token")
	proxy.copilotToken = ""
	proxy.tokenExpiry = time.Now().Add(-1 * time.Hour)
	proxy.client = &http.Client{
		Transport: roundTripperFunc(func(req *http.Request) *http.Response {
			if req.URL.String() != tokenExchangeURL {
				t.Errorf("Expected URL %s, got %s", tokenExchangeURL, req.URL.String())
			}
			// A fresh reader per request, so each request sees the full body.
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewReader(respBody)),
				Header:     make(http.Header),
			}
		}),
	}

	// If the refresh failed, comparing the token below is meaningless.
	if err := proxy.refreshTokenIfNeeded(); err != nil {
		t.Fatalf("Expected no error, got %v", err)
	}
	if proxy.copilotToken != mockTokenResponse.Token {
		t.Errorf("Expected token to be %s, got %s", mockTokenResponse.Token, proxy.copilotToken)
	}
}
// TestProxyHandler drives proxyHandler with a valid cached token and a
// stubbed upstream. It expects the request to target proxyEndpoint (ignoring
// trailing slashes) and the upstream's 200 OK to reach the client.
func TestProxyHandler(t *testing.T) {
	srv := NewProxyServer("test-auth-token")
	srv.proxyEndpoint = "http://mock-endpoint.com"
	srv.copilotToken = mockTokenResponse.Token
	srv.tokenExpiry = time.Now().Add(1 * time.Hour)

	upstream := func(r *http.Request) *http.Response {
		// The endpoint is read at request time, and trailing slashes are
		// ignored on both sides.
		want := strings.TrimRight(srv.proxyEndpoint, "/")
		got := strings.TrimRight(r.URL.String(), "/")
		if got != want {
			t.Errorf("Expected URL %s, got %s", want, got)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Header:     make(http.Header),
			Body:       io.NopCloser(bytes.NewBufferString("mock-response")),
		}
	}
	srv.client = &http.Client{Transport: roundTripperFunc(upstream)}

	incoming := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test-body"))
	rec := httptest.NewRecorder()
	srv.proxyHandler(rec, incoming)

	if code := rec.Result().StatusCode; code != http.StatusOK {
		t.Errorf("Expected status code %d, got %d", http.StatusOK, code)
	}
}
// roundTripperFunc adapts a plain function to http.RoundTripper so tests can
// replace an http.Client's transport with an in-process stub. The stub can
// only produce responses; it never reports a transport-level error.
type roundTripperFunc func(req *http.Request) *http.Response

// RoundTrip satisfies http.RoundTripper. It hands the request to the wrapped
// function and returns that function's response with a nil error.
func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp := fn(r)
	return resp, nil
}