feat(agent): give the public API client mirror failover
The public-API go-client (search/popular/etc.) had no mirror failover while the agent control-plane client did — a primary-domain takedown broke public calls. Inject a MirrorRoundTripper that reuses the SAME MirrorPool type + IsTransient policy, rotating to cfg.Auth.Mirrors on a transient error/5xx. WithRetry(0) hands failover ownership to the transport (no nested retry).
This commit is contained in:
parent
3d51013935
commit
96b23ed051
3 changed files with 278 additions and 0 deletions
88
internal/agent/mirror_transport.go
Normal file
88
internal/agent/mirror_transport.go
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MirrorRoundTripper gives any *http.Client the same mirror failover the agent
|
||||||
|
// control-plane Client has: on a transient transport error or a retryable 5xx
|
||||||
|
// it rewrites the request to the next mirror in the shared MirrorPool and
|
||||||
|
// retries. It exists so the public-API go-client stops diverging from the agent
|
||||||
|
// client — both now survive a primary-domain takedown using the SAME pool and
|
||||||
|
// the SAME transient-error policy (IsTransient).
|
||||||
|
//
|
||||||
|
// Requests whose body cannot be replayed (Body != nil && GetBody == nil) are
|
||||||
|
// sent once with no failover, so a consumed body is never re-read. Standard
|
||||||
|
// library requests built with a *bytes.Reader/strings.Reader (and all GETs) set
|
||||||
|
// GetBody, so this only affects exotic streaming bodies the public API doesn't use.
|
||||||
|
type MirrorRoundTripper struct {
|
||||||
|
pool *MirrorPool
|
||||||
|
inner http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMirrorRoundTripper wraps inner (defaults to http.DefaultTransport) with
|
||||||
|
// failover across pool's mirrors.
|
||||||
|
func NewMirrorRoundTripper(pool *MirrorPool, inner http.RoundTripper) *MirrorRoundTripper {
|
||||||
|
if inner == nil {
|
||||||
|
inner = http.DefaultTransport
|
||||||
|
}
|
||||||
|
return &MirrorRoundTripper{pool: pool, inner: inner}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip points the request at the current mirror and, on a transient
|
||||||
|
// failure, rotates the pool and retries against the next one. A non-transient
|
||||||
|
// HTTP status (4xx, or a 5xx IsTransient doesn't retry) or a non-replayable body
|
||||||
|
// is returned to the caller unchanged.
|
||||||
|
func (m *MirrorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
attempts := 1
|
||||||
|
if req.Body == nil || req.GetBody != nil { // replayable → may fail over
|
||||||
|
if n := m.pool.Len(); n > attempts {
|
||||||
|
attempts = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for i := 0; i < attempts; i++ {
|
||||||
|
out := req.Clone(req.Context())
|
||||||
|
if req.GetBody != nil {
|
||||||
|
body, err := req.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("mirror transport: rebuild body: %w", err)
|
||||||
|
}
|
||||||
|
out.Body = body
|
||||||
|
}
|
||||||
|
if base, err := url.Parse(m.pool.Current()); err == nil && base.Host != "" {
|
||||||
|
out.URL.Scheme = base.Scheme
|
||||||
|
out.URL.Host = base.Host
|
||||||
|
out.Host = base.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := m.inner.RoundTrip(out)
|
||||||
|
last := i == attempts-1
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if last || !IsTransient(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
case resp.StatusCode >= 400 && IsTransient(&HTTPError{StatusCode: resp.StatusCode}):
|
||||||
|
if last {
|
||||||
|
return resp, nil // surface the real 5xx to the caller
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
lastErr = fmt.Errorf("mirror %s: HTTP %d", out.URL.Host, resp.StatusCode)
|
||||||
|
default:
|
||||||
|
return resp, nil // success, or a status we must not retry (4xx/auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, rotated := m.pool.Rotate(); !rotated {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastErr == nil {
|
||||||
|
lastErr = fmt.Errorf("mirror transport: all mirrors failed")
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
172
internal/agent/mirror_transport_test.go
Normal file
172
internal/agent/mirror_transport_test.go
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_FailoverOn503(t *testing.T) {
|
||||||
|
var primaryHits, mirrorHits int
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
primaryHits++
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
mirrorHits++
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
io.WriteString(w, "ok")
|
||||||
|
}))
|
||||||
|
defer mirror.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/api/v1/search", nil)
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if primaryHits != 1 || mirrorHits != 1 {
|
||||||
|
t.Errorf("hits primary=%d mirror=%d, want 1/1", primaryHits, mirrorHits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_NoFailoverOn404(t *testing.T) {
|
||||||
|
var mirrorHits int
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
mirrorHits++
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer mirror.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil)
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404 (surfaced, not retried)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if mirrorHits != 0 {
|
||||||
|
t.Errorf("mirror hit %d times — must NOT fail over on 404", mirrorHits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_FailoverOnConnRefused(t *testing.T) {
|
||||||
|
dead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||||
|
deadURL := dead.URL
|
||||||
|
dead.Close() // port now refuses connections
|
||||||
|
|
||||||
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer mirror.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(deadURL, []string{mirror.URL})
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, deadURL+"/x", nil)
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip should have failed over, got: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200 after failover", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_ReplaysBodyOnFailover(t *testing.T) {
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
var gotBody string
|
||||||
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
b, _ := io.ReadAll(r.Body)
|
||||||
|
gotBody = string(b)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer mirror.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", strings.NewReader("payload"))
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if gotBody != "payload" {
|
||||||
|
t.Errorf("mirror received body %q, want \"payload\" (body must be replayed on failover)", gotBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_NonReplayableBodyNoFailover(t *testing.T) {
|
||||||
|
var primaryHits, mirrorHits int
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
primaryHits++
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
mirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
mirrorHits++
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer mirror.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(primary.URL, []string{mirror.URL})
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
// A body with no GetBody can't be replayed → must be sent once, no failover.
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, primary.URL+"/x", io.NopCloser(strings.NewReader("payload")))
|
||||||
|
req.GetBody = nil
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want 503 (single attempt, no failover)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if primaryHits != 1 || mirrorHits != 0 {
|
||||||
|
t.Errorf("hits primary=%d mirror=%d, want 1/0 (non-replayable body must not fail over)", primaryHits, mirrorHits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMirrorRoundTripper_SingleMirrorSurfaces503(t *testing.T) {
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
|
||||||
|
pool := NewMirrorPool(primary.URL, nil)
|
||||||
|
rt := NewMirrorRoundTripper(pool, http.DefaultTransport)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, primary.URL+"/x", nil)
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want 503 surfaced (no mirror to fail over to)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,11 +2,14 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
tc "github.com/torrentclaw/go-client"
|
tc "github.com/torrentclaw/go-client"
|
||||||
|
"github.com/torrentclaw/unarr/internal/agent"
|
||||||
"github.com/torrentclaw/unarr/internal/config"
|
"github.com/torrentclaw/unarr/internal/config"
|
||||||
"github.com/torrentclaw/unarr/internal/sentry"
|
"github.com/torrentclaw/unarr/internal/sentry"
|
||||||
"github.com/torrentclaw/unarr/internal/upgrade"
|
"github.com/torrentclaw/unarr/internal/upgrade"
|
||||||
|
|
@ -235,6 +238,21 @@ func getClient() *tc.Client {
|
||||||
|
|
||||||
opts = append(opts, tc.WithUserAgent("unarr/"+Version))
|
opts = append(opts, tc.WithUserAgent("unarr/"+Version))
|
||||||
|
|
||||||
|
// Mirror failover for the public-API client, matching the agent control-plane
|
||||||
|
// client's resilience: wrap the transport so search/popular/etc. rotate across
|
||||||
|
// cfg.Auth.Mirrors on a primary takedown, using the same MirrorPool TYPE +
|
||||||
|
// IsTransient policy the agent client uses (a fresh pool instance — the two
|
||||||
|
// clients fail over independently). WithRetry(0) disables the go-client's own
|
||||||
|
// retry loop so the transport owns failover exclusively (no nested
|
||||||
|
// retry×backoff on an outage). WithTimeout(30s) is set idiomatically and gives
|
||||||
|
// room for a couple of mirror attempts (go-client's bare default is 15s).
|
||||||
|
pool := agent.NewMirrorPool(cfg.Auth.APIURL, cfg.Auth.Mirrors)
|
||||||
|
opts = append(opts,
|
||||||
|
tc.WithHTTPClient(&http.Client{Transport: agent.NewMirrorRoundTripper(pool, nil)}),
|
||||||
|
tc.WithTimeout(30*time.Second),
|
||||||
|
tc.WithRetry(0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
apiClient = tc.NewClient(opts...)
|
apiClient = tc.NewClient(opts...)
|
||||||
return apiClient
|
return apiClient
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue