From 96b23ed0518f95ab47d58c977246797b64550a4e Mon Sep 17 00:00:00 2001 From: Deivid Soto Date: Mon, 1 Jun 2026 15:53:00 +0200 Subject: [PATCH] feat(agent): give the public API client mirror failover MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The public-API go-client (search/popular/etc.) had no mirror failover while the agent control-plane client did — a primary-domain takedown broke public calls. Inject a MirrorRoundTripper that reuses the SAME MirrorPool type + IsTransient policy, rotating to cfg.Auth.Mirrors on a transient error/5xx. WithRetry(0) hands failover ownership to the transport (no nested retry). --- internal/agent/mirror_transport.go | 88 ++++++++++++ internal/agent/mirror_transport_test.go | 172 ++++++++++++++++++++++++ internal/cmd/root.go | 18 +++ 3 files changed, 278 insertions(+) create mode 100644 internal/agent/mirror_transport.go create mode 100644 internal/agent/mirror_transport_test.go diff --git a/internal/agent/mirror_transport.go b/internal/agent/mirror_transport.go new file mode 100644 index 0000000..906207c --- /dev/null +++ b/internal/agent/mirror_transport.go @@ -0,0 +1,88 @@ +package agent + +import ( + "fmt" + "net/http" + "net/url" +) + +// MirrorRoundTripper gives any *http.Client the same mirror failover the agent +// control-plane Client has: on a transient transport error or a retryable 5xx +// it rewrites the request to the next mirror in the shared MirrorPool and +// retries. It exists so the public-API go-client stops diverging from the agent +// client — both now survive a primary-domain takedown using the SAME pool and +// the SAME transient-error policy (IsTransient). +// +// Requests whose body cannot be replayed (Body != nil && GetBody == nil) are +// sent once with no failover, so a consumed body is never re-read. Standard +// library requests built with a *bytes.Reader/strings.Reader (and all GETs) set +// GetBody, so this only affects exotic streaming bodies the public API doesn't use. +type MirrorRoundTripper struct { + pool *MirrorPool + inner http.RoundTripper +} + +// NewMirrorRoundTripper wraps inner (defaults to http.DefaultTransport) with +// failover across pool's mirrors. +func NewMirrorRoundTripper(pool *MirrorPool, inner http.RoundTripper) *MirrorRoundTripper { + if inner == nil { + inner = http.DefaultTransport + } + return &MirrorRoundTripper{pool: pool, inner: inner} +} + +// RoundTrip points the request at the current mirror and, on a transient +// failure, rotates the pool and retries against the next one. A non-transient +// HTTP status (4xx, or a 5xx IsTransient doesn't retry) or a non-replayable body +// is returned to the caller unchanged. +func (m *MirrorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + attempts := 1 + if req.Body == nil || req.GetBody != nil { // replayable → may fail over + if n := m.pool.Len(); n > attempts { + attempts = n + } + } + + var lastErr error + for i := 0; i < attempts; i++ { + out := req.Clone(req.Context()) + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("mirror transport: rebuild body: %w", err) + } + out.Body = body + } + if base, err := url.Parse(m.pool.Current()); err == nil && base.Host != "" { + out.URL.Scheme = base.Scheme + out.URL.Host = base.Host + out.Host = base.Host + } + + resp, err := m.inner.RoundTrip(out) + last := i == attempts-1 + switch { + case err != nil: + if last || !IsTransient(err) { + return nil, err + } + lastErr = err + case resp.StatusCode >= 400 && IsTransient(&HTTPError{StatusCode: resp.StatusCode}): + if last { + return resp, nil // surface the real 5xx to the caller + } + resp.Body.Close() + lastErr = fmt.Errorf("mirror %s: HTTP %d", out.URL.Host, resp.StatusCode) + default: + return resp, nil // success, or a status we must not retry (4xx/auth) + } + + if _, rotated := m.pool.Rotate(); !rotated { + break + } + } + if lastErr == nil { + lastErr = fmt.Errorf("mirror transport: all mirrors failed") + } + return nil, lastErr +} diff --git a/internal/agent/mirror_transport_test.go b/internal/agent/mirror_transport_test.go new file mode 100644 index 0000000..3781b36 --- /dev/null +++ b/internal/agent/mirror_transport_test.go @@ -0,0 +1,172 @@ +package agent + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMirrorRoundTripper_FailoverOn503(t *testing.T) { + var primaryHits, mirrorHits int + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + primaryHits++ + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mirrorHits++ + w.WriteHeader(http.StatusOK) + io.WriteString(w, "ok") + })) + defer mirror.Close() + + pool := NewMirrorPool(primary.URL, []string{mirror.URL}) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + req, _ := http.NewRequest(http.MethodGet, primary.URL+"/api/v1/search", nil) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + if primaryHits != 1 || mirrorHits != 1 { + t.Errorf("hits primary=%d mirror=%d, want 1/1", primaryHits, mirrorHits) + } +} + +func TestMirrorRoundTripper_NoFailoverOn404(t *testing.T) { + var mirrorHits int + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer primary.Close() + mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mirrorHits++ + w.WriteHeader(http.StatusOK) + })) + defer mirror.Close() + + pool := NewMirrorPool(primary.URL, []string{mirror.URL}) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want 404 (surfaced, not retried)", resp.StatusCode) + } + if mirrorHits != 0 { + t.Errorf("mirror hit %d times — must NOT fail over on 404", mirrorHits) + } +} + +func TestMirrorRoundTripper_FailoverOnConnRefused(t *testing.T) { + dead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + deadURL := dead.URL + dead.Close() // port now refuses connections + + mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer mirror.Close() + + pool := NewMirrorPool(deadURL, []string{mirror.URL}) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + req, _ := http.NewRequest(http.MethodGet, deadURL+"/x", nil) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip should have failed over, got: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want 200 after failover", resp.StatusCode) + } +} + +func TestMirrorRoundTripper_ReplaysBodyOnFailover(t *testing.T) { + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer primary.Close() + var gotBody string + mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + gotBody = string(b) + w.WriteHeader(http.StatusOK) + })) + defer mirror.Close() + + pool := NewMirrorPool(primary.URL, []string{mirror.URL}) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", strings.NewReader("payload")) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + if gotBody != "payload" { + t.Errorf("mirror received body %q, want \"payload\" (body must be replayed on failover)", gotBody) + } +} + +func TestMirrorRoundTripper_NonReplayableBodyNoFailover(t *testing.T) { + var primaryHits, mirrorHits int + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + primaryHits++ + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mirrorHits++ + w.WriteHeader(http.StatusOK) + })) + defer mirror.Close() + + pool := NewMirrorPool(primary.URL, []string{mirror.URL}) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + // A body with no GetBody can't be replayed → must be sent once, no failover. + req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", io.NopCloser(strings.NewReader("payload"))) + req.GetBody = nil + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("status = %d, want 503 (single attempt, no failover)", resp.StatusCode) + } + if primaryHits != 1 || mirrorHits != 0 { + t.Errorf("hits primary=%d mirror=%d, want 1/0 (non-replayable body must not fail over)", primaryHits, mirrorHits) + } +} + +func TestMirrorRoundTripper_SingleMirrorSurfaces503(t *testing.T) { + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + pool := NewMirrorPool(primary.URL, nil) + rt := NewMirrorRoundTripper(pool, http.DefaultTransport) + req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("status = %d, want 503 surfaced (no mirror to fail over to)", resp.StatusCode) + } +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 375d8e9..e8ad752 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -2,11 +2,14 @@ package cmd import ( "fmt" + "net/http" "os" + "time" "github.com/fatih/color" "github.com/spf13/cobra" tc "github.com/torrentclaw/go-client" + "github.com/torrentclaw/unarr/internal/agent" "github.com/torrentclaw/unarr/internal/config" "github.com/torrentclaw/unarr/internal/sentry" "github.com/torrentclaw/unarr/internal/upgrade" @@ -235,6 +238,21 @@ func getClient() *tc.Client { opts = append(opts, tc.WithUserAgent("unarr/"+Version)) + // Mirror failover for the public-API client, matching the agent control-plane + // client's resilience: wrap the transport so search/popular/etc. rotate across + // cfg.Auth.Mirrors on a primary takedown, using the same MirrorPool TYPE + + // IsTransient policy the agent client uses (a fresh pool instance — the two + // clients fail over independently). WithRetry(0) disables the go-client's own + // retry loop so the transport owns failover exclusively (no nested + // retry×backoff on an outage). WithTimeout(30s) is set idiomatically and gives + // room for a couple of mirror attempts (go-client's bare default is 15s). + pool := agent.NewMirrorPool(cfg.Auth.APIURL, cfg.Auth.Mirrors) + opts = append(opts, + tc.WithHTTPClient(&http.Client{Transport: agent.NewMirrorRoundTripper(pool, nil)}), + tc.WithTimeout(30*time.Second), + tc.WithRetry(0, 0, 0), + ) + apiClient = tc.NewClient(opts...) return apiClient }