The public-API go-client (search/popular/etc.) had no mirror failover while the agent control-plane client did — a primary-domain takedown broke public calls. Inject a MirrorRoundTripper that reuses the SAME MirrorPool type + IsTransient policy, rotating to cfg.Auth.Mirrors on a transient error/5xx. WithRetry(0) hands failover ownership to the transport (no nested retry).
172 lines
5.5 KiB
Go
172 lines
5.5 KiB
Go
package agent
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestMirrorRoundTripper_FailoverOn503(t *testing.T) {
|
|
var primaryHits, mirrorHits int
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
primaryHits++
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}))
|
|
defer primary.Close()
|
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
mirrorHits++
|
|
w.WriteHeader(http.StatusOK)
|
|
io.WriteString(w, "ok")
|
|
}))
|
|
defer mirror.Close()
|
|
|
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/api/v1/search", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
if primaryHits != 1 || mirrorHits != 1 {
|
|
t.Errorf("hits primary=%d mirror=%d, want 1/1", primaryHits, mirrorHits)
|
|
}
|
|
}
|
|
|
|
func TestMirrorRoundTripper_NoFailoverOn404(t *testing.T) {
|
|
var mirrorHits int
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer primary.Close()
|
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
mirrorHits++
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mirror.Close()
|
|
|
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("status = %d, want 404 (surfaced, not retried)", resp.StatusCode)
|
|
}
|
|
if mirrorHits != 0 {
|
|
t.Errorf("mirror hit %d times — must NOT fail over on 404", mirrorHits)
|
|
}
|
|
}
|
|
|
|
func TestMirrorRoundTripper_FailoverOnConnRefused(t *testing.T) {
|
|
dead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
|
deadURL := dead.URL
|
|
dead.Close() // port now refuses connections
|
|
|
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mirror.Close()
|
|
|
|
pool := NewMirrorPool(deadURL, []string{mirror.URL})
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
req, _ := http.NewRequest(http.MethodGet, deadURL+"/x", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip should have failed over, got: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %d, want 200 after failover", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestMirrorRoundTripper_ReplaysBodyOnFailover(t *testing.T) {
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}))
|
|
defer primary.Close()
|
|
var gotBody string
|
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
b, _ := io.ReadAll(r.Body)
|
|
gotBody = string(b)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mirror.Close()
|
|
|
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", strings.NewReader("payload"))
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if gotBody != "payload" {
|
|
t.Errorf("mirror received body %q, want \"payload\" (body must be replayed on failover)", gotBody)
|
|
}
|
|
}
|
|
|
|
func TestMirrorRoundTripper_NonReplayableBodyNoFailover(t *testing.T) {
|
|
var primaryHits, mirrorHits int
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
primaryHits++
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}))
|
|
defer primary.Close()
|
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
mirrorHits++
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mirror.Close()
|
|
|
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
// A body with no GetBody can't be replayed → must be sent once, no failover.
|
|
req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", io.NopCloser(strings.NewReader("payload")))
|
|
req.GetBody = nil
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Errorf("status = %d, want 503 (single attempt, no failover)", resp.StatusCode)
|
|
}
|
|
if primaryHits != 1 || mirrorHits != 0 {
|
|
t.Errorf("hits primary=%d mirror=%d, want 1/0 (non-replayable body must not fail over)", primaryHits, mirrorHits)
|
|
}
|
|
}
|
|
|
|
func TestMirrorRoundTripper_SingleMirrorSurfaces503(t *testing.T) {
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}))
|
|
defer primary.Close()
|
|
|
|
pool := NewMirrorPool(primary.URL, nil)
|
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("RoundTrip: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Errorf("status = %d, want 503 surfaced (no mirror to fail over to)", resp.StatusCode)
|
|
}
|
|
}
|