128 lines
3.5 KiB
Go
128 lines
3.5 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var mockTokenResponse = TokenResponse{
|
|
Token: "mock-token",
|
|
ExpiresAt: time.Now().Add(1 * time.Hour).Unix(),
|
|
}
|
|
|
|
func TestNewProxyServer(t *testing.T) {
|
|
authToken := "test-auth-token"
|
|
proxy := NewProxyServer(authToken)
|
|
|
|
if proxy.authToken != authToken {
|
|
t.Errorf("Expected authToken to be %s, got %s", authToken, proxy.authToken)
|
|
}
|
|
if proxy.client == nil {
|
|
t.Error("Expected http.Client to be initialized")
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenIfNeeded_ValidToken(t *testing.T) {
|
|
proxy := NewProxyServer("test-auth-token")
|
|
proxy.proxyEndpoint = "http://mock-endpoint.com"
|
|
proxy.copilotToken = mockTokenResponse.Token
|
|
|
|
proxy.client = &http.Client{
|
|
Transport: roundTripperFunc(func(req *http.Request) *http.Response {
|
|
expectedURL := proxy.proxyEndpoint
|
|
actualURL := req.URL.String()
|
|
|
|
// Normalize both URLs by trimming trailing slashes
|
|
expectedURL = strings.TrimRight(expectedURL, "/")
|
|
actualURL = strings.TrimRight(actualURL, "/")
|
|
|
|
if actualURL != expectedURL {
|
|
t.Errorf("Expected URL %s, got %s", expectedURL, actualURL)
|
|
}
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewBufferString("mock-response")),
|
|
Header: make(http.Header),
|
|
}
|
|
}),
|
|
}
|
|
proxy.tokenExpiry = time.Now().Add(1 * time.Hour)
|
|
|
|
err := proxy.refreshTokenIfNeeded()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokenIfNeeded_ExpiredToken(t *testing.T) {
|
|
proxy := NewProxyServer("test-auth-token")
|
|
proxy.copilotToken = ""
|
|
proxy.tokenExpiry = time.Now().Add(-1 * time.Hour)
|
|
|
|
proxy.client = &http.Client{
|
|
Transport: roundTripperFunc(func(req *http.Request) *http.Response {
|
|
if req.URL.String() != tokenExchangeURL {
|
|
t.Errorf("Expected URL %s, got %s", tokenExchangeURL, req.URL.String())
|
|
}
|
|
respBody, _ := json.Marshal(mockTokenResponse)
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
|
Header: make(http.Header),
|
|
}
|
|
}),
|
|
}
|
|
|
|
err := proxy.refreshTokenIfNeeded()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if proxy.copilotToken != mockTokenResponse.Token {
|
|
t.Errorf("Expected token to be %s, got %s", mockTokenResponse.Token, proxy.copilotToken)
|
|
}
|
|
}
|
|
|
|
func TestProxyHandler(t *testing.T) {
|
|
proxy := NewProxyServer("test-auth-token")
|
|
proxy.proxyEndpoint = "http://mock-endpoint.com"
|
|
proxy.copilotToken = mockTokenResponse.Token
|
|
proxy.tokenExpiry = time.Now().Add(1 * time.Hour)
|
|
|
|
proxy.client = &http.Client{
|
|
Transport: roundTripperFunc(func(req *http.Request) *http.Response {
|
|
expectedURL := strings.TrimRight(proxy.proxyEndpoint, "/")
|
|
actualURL := strings.TrimRight(req.URL.String(), "/")
|
|
|
|
if actualURL != expectedURL {
|
|
t.Errorf("Expected URL %s, got %s", expectedURL, actualURL)
|
|
}
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewBufferString("mock-response")),
|
|
Header: make(http.Header),
|
|
}
|
|
}),
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test-body"))
|
|
w := httptest.NewRecorder()
|
|
|
|
proxy.proxyHandler(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
type roundTripperFunc func(req *http.Request) *http.Response
|
|
|
|
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return f(req), nil
|
|
}
|