feat(sync): replace WS+DO transport with unified HTTP sync
Replace the WebSocket + Cloudflare Durable Object architecture with a single POST /sync endpoint. The CLI now operates autonomously with local state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP polling (3s watching, 60s idle). - Remove transport_ws, transport_hybrid, transport_http (~2,600 lines) - Add SyncClient with adaptive interval loop - Add LocalState for CLI-side task persistence - Add TaskStateFromUpdate() helper (DRY) - Extract finalize() to deduplicate processTask/processTaskRetry - Consolidate shortID() into agent.ShortID (was in 3 packages) - Wire GetActiveCount so `unarr status` shows active tasks - Remove poll_interval, heartbeat_interval, ws_url from config - Simplify ProgressReporter (sync replaces direct HTTP reporting)
This commit is contained in:
parent
2398707cc1
commit
5d4a67c7a2
26 changed files with 1320 additions and 3400 deletions
|
|
@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
// Heartbeat sends a periodic keep-alive signal and returns server directives.
|
||||
func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
var resp HeartbeatResponse
|
||||
if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil {
|
||||
return nil, fmt.Errorf("heartbeat: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ClaimTasks polls for pending download tasks and claims them atomically.
|
||||
// Also returns any stream requests for completed downloads.
|
||||
func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID)
|
||||
var resp TasksResponse
|
||||
if err := c.doGet(ctx, url, &resp); err != nil {
|
||||
return nil, fmt.Errorf("claim tasks: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ReportStatus reports download progress or completion for a task.
|
||||
// Deregister notifies the server that the agent is shutting down.
|
||||
func (c *Client) Deregister(ctx context.Context, agentID string) error {
|
||||
req := struct {
|
||||
|
|
@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate)
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
// Sync sends the CLI's full state and receives all pending server actions.
|
||||
// This is the single endpoint for bidirectional state synchronization.
|
||||
func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) {
|
||||
var resp SyncResponse
|
||||
if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil {
|
||||
return nil, fmt.Errorf("sync: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Usenet endpoints
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -72,70 +72,6 @@ func TestRegister(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHeartbeat(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/internal/agent/heartbeat" {
|
||||
t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path)
|
||||
}
|
||||
var req HeartbeatRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
if req.AgentID != "agent-123" {
|
||||
t.Errorf("agentId = %q, want agent-123", req.AgentID)
|
||||
}
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimTasks(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("method = %s, want GET", r.Method)
|
||||
}
|
||||
if r.URL.Query().Get("agentId") != "agent-123" {
|
||||
t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(TasksResponse{
|
||||
Tasks: []Task{
|
||||
{
|
||||
ID: "task-uuid-1",
|
||||
InfoHash: "abc123def456abc123def456abc123def456abc1",
|
||||
Title: "The Matrix (1999)",
|
||||
PreferredMethod: "auto",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(resp.Tasks) != 1 {
|
||||
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
|
||||
}
|
||||
if resp.Tasks[0].ID != "task-uuid-1" {
|
||||
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
|
||||
}
|
||||
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
|
||||
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
|
||||
}
|
||||
if resp.Tasks[0].PreferredMethod != "auto" {
|
||||
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportStatus(t *testing.T) {
|
||||
var received StatusUpdate
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClaimTasksEmpty(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.ClaimTasks(context.Background(), "agent-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTasks failed: %v", err)
|
||||
}
|
||||
if len(resp.Tasks) != 0 {
|
||||
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
|
@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) {
|
|||
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
|
||||
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
json.NewEncoder(w).Encode(RegisterResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
|
||||
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
|
||||
}
|
||||
|
||||
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{
|
||||
Success: true,
|
||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade == nil {
|
||||
t.Fatal("expected upgrade signal, got nil")
|
||||
}
|
||||
if resp.Upgrade.Version != "2.0.0" {
|
||||
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "test-key", "unarr-test")
|
||||
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
if resp.Upgrade != nil {
|
||||
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
|
||||
}
|
||||
c.Register(context.Background(), RegisterRequest{AgentID: "x"})
|
||||
}
|
||||
|
||||
func TestDeregister(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -14,75 +14,62 @@ import (
|
|||
|
||||
// DaemonConfig holds daemon runtime settings.
|
||||
type DaemonConfig struct {
|
||||
AgentID string
|
||||
AgentName string
|
||||
Version string
|
||||
DownloadDir string
|
||||
PollInterval time.Duration
|
||||
HeartbeatInterval time.Duration
|
||||
StreamPort int // port for the HTTP stream server (reported in heartbeat)
|
||||
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
|
||||
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
|
||||
AgentID string
|
||||
AgentName string
|
||||
Version string
|
||||
DownloadDir string
|
||||
StreamPort int // port for the HTTP stream server
|
||||
LanIP string // LAN IP (reported in sync for stream URL resolution)
|
||||
TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution)
|
||||
}
|
||||
|
||||
// Daemon manages the main loop: register, heartbeat, poll tasks.
|
||||
// Daemon manages agent registration and the sync loop.
|
||||
type Daemon struct {
|
||||
cfg DaemonConfig
|
||||
transport Transport
|
||||
cfg DaemonConfig
|
||||
client *Client
|
||||
sync *SyncClient
|
||||
state *LocalState
|
||||
|
||||
// Callbacks
|
||||
// Callbacks — set by cmd/daemon.go before calling Run.
|
||||
OnTasksClaimed func(tasks []Task)
|
||||
OnStreamRequested func(req StreamRequest)
|
||||
OnControlAction func(action, taskID string)
|
||||
OnControlAction func(action, taskID string, deleteFiles bool)
|
||||
GetActiveCount func() int // returns number of active downloads (wired from manager)
|
||||
|
||||
// State
|
||||
User UserInfo
|
||||
Features FeatureFlags
|
||||
Info AgentInfo
|
||||
State DaemonState
|
||||
heartbeatFailures int
|
||||
lastNotifiedVersion string
|
||||
|
||||
// Callbacks for state tracking (set by cmd/daemon.go)
|
||||
GetActiveCount func() int
|
||||
GetCleanableBytes func() int64
|
||||
|
||||
// Watching tracks whether a user is viewing download progress in the web UI.
|
||||
// When false, the progress reporter skips detailed updates (only sends final states).
|
||||
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
|
||||
Watching atomic.Bool
|
||||
|
||||
// Exposed tickers for hot-reload
|
||||
PollTicker *time.Ticker
|
||||
HeartbeatTicker *time.Ticker
|
||||
|
||||
// pollNow triggers an immediate poll (e.g. on resume)
|
||||
pollNow chan struct{}
|
||||
|
||||
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
|
||||
// ScanNow triggers an immediate library scan.
|
||||
ScanNow chan struct{}
|
||||
}
|
||||
|
||||
// NewDaemon creates a daemon with the given transport.
|
||||
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
|
||||
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = 30 * time.Second
|
||||
}
|
||||
if cfg.HeartbeatInterval == 0 {
|
||||
cfg.HeartbeatInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
// NewDaemon creates a daemon with an HTTP client for sync-based communication.
|
||||
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
|
||||
state := NewLocalState()
|
||||
return &Daemon{
|
||||
cfg: cfg,
|
||||
transport: transport,
|
||||
pollNow: make(chan struct{}, 1),
|
||||
ScanNow: make(chan struct{}, 1),
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
state: state,
|
||||
sync: NewSyncClient(client, cfg, state),
|
||||
ScanNow: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the configured transport.
|
||||
func (d *Daemon) Transport() Transport { return d.transport }
|
||||
// SyncClient returns the sync client for external wiring.
|
||||
func (d *Daemon) SyncClient() *SyncClient { return d.sync }
|
||||
|
||||
// UpdateStreamPort updates the stream port reported in sync requests.
|
||||
func (d *Daemon) UpdateStreamPort(port int) {
|
||||
d.cfg.StreamPort = port
|
||||
d.sync.cfg.StreamPort = port
|
||||
}
|
||||
|
||||
// Register registers the agent and fetches user info + features.
|
||||
// Retries with exponential backoff on transient errors (429, 5xx, network).
|
||||
|
|
@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error {
|
|||
var resp *RegisterResponse
|
||||
var err error
|
||||
for attempt := range maxRetries {
|
||||
resp, err = d.transport.Register(ctx, req)
|
||||
resp, err = d.client.Register(ctx, req)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
// Only retry on transient errors (429, 5xx, network failures)
|
||||
if !isTransientError(err) {
|
||||
return fmt.Errorf("register: %w", err)
|
||||
}
|
||||
|
|
@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Run connects the transport, registers the agent, and starts the main loop.
|
||||
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
|
||||
// Run registers the agent and starts the sync loop.
|
||||
// Blocks until ctx is cancelled.
|
||||
func (d *Daemon) Run(ctx context.Context) error {
|
||||
// Connect transport (establishes WebSocket if available, falls back to HTTP)
|
||||
if err := d.transport.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("connect transport: %w", err)
|
||||
}
|
||||
|
||||
// Register
|
||||
if err := d.Register(ctx); err != nil {
|
||||
return err
|
||||
|
|
@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error {
|
|||
|
||||
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
|
||||
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
|
||||
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
|
||||
|
||||
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
|
||||
defer d.HeartbeatTicker.Stop()
|
||||
|
||||
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
|
||||
defer d.PollTicker.Stop()
|
||||
|
||||
heartbeatTicker := d.HeartbeatTicker
|
||||
pollTicker := d.PollTicker
|
||||
|
||||
// Initial poll immediately
|
||||
d.poll(ctx)
|
||||
|
||||
eventsCh := d.transport.Events()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Daemon shutting down...")
|
||||
d.deregister()
|
||||
return nil
|
||||
|
||||
case event := <-eventsCh:
|
||||
d.handleEvent(event)
|
||||
|
||||
case <-heartbeatTicker.C:
|
||||
d.heartbeat(ctx)
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Only poll in HTTP mode — WS mode receives tasks via Events
|
||||
if d.transport.Mode() == "http" {
|
||||
d.poll(ctx)
|
||||
}
|
||||
|
||||
case <-d.pollNow:
|
||||
d.poll(ctx)
|
||||
// Wire sync callbacks
|
||||
d.sync.OnNewTasks = func(tasks []Task) {
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(tasks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) heartbeat(ctx context.Context) {
|
||||
req := HeartbeatRequest{
|
||||
AgentID: d.cfg.AgentID,
|
||||
Name: d.cfg.AgentName,
|
||||
Version: d.cfg.Version,
|
||||
OS: runtime.GOOS,
|
||||
DownloadDir: d.cfg.DownloadDir,
|
||||
StreamPort: d.cfg.StreamPort,
|
||||
LanIP: d.cfg.LanIP,
|
||||
TailscaleIP: d.cfg.TailscaleIP,
|
||||
}
|
||||
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
|
||||
req.DiskFreeBytes = free
|
||||
req.DiskTotalBytes = total
|
||||
}
|
||||
|
||||
resp, err := d.transport.SendHeartbeat(ctx, req)
|
||||
if err != nil {
|
||||
d.heartbeatFailures++
|
||||
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
|
||||
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
|
||||
} else {
|
||||
log.Printf("Heartbeat failed: %v", err)
|
||||
d.sync.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||
if d.OnControlAction != nil {
|
||||
d.OnControlAction(action, taskID, deleteFiles)
|
||||
}
|
||||
return
|
||||
}
|
||||
if d.heartbeatFailures > 0 {
|
||||
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
|
||||
d.heartbeatFailures = 0
|
||||
d.sync.OnStreamRequest = func(req StreamRequest) {
|
||||
if d.OnStreamRequested != nil {
|
||||
d.OnStreamRequested(req)
|
||||
}
|
||||
}
|
||||
|
||||
// Update watching flag and state file
|
||||
d.Watching.Store(resp.Watching)
|
||||
d.State.LastHeartbeat = time.Now()
|
||||
if d.GetActiveCount != nil {
|
||||
d.State.ActiveTasks = d.GetActiveCount()
|
||||
d.sync.OnUpgrade = func(version string) {
|
||||
if version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version)
|
||||
}
|
||||
}
|
||||
WriteState(&d.State)
|
||||
|
||||
// Trigger library scan if requested
|
||||
if resp.Scan {
|
||||
d.sync.OnScan = func() {
|
||||
log.Printf("Library scan requested by server")
|
||||
select {
|
||||
case d.ScanNow <- struct{}{}:
|
||||
default: // scan already pending
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Log once per version when server suggests an upgrade
|
||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = resp.Upgrade.Version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
|
||||
d.sync.OnWatchingChange = func(watching bool) {
|
||||
d.Watching.Store(watching)
|
||||
}
|
||||
}
|
||||
|
||||
// handleEvent processes a server-initiated event from the WebSocket transport.
|
||||
func (d *Daemon) handleEvent(event ServerEvent) {
|
||||
switch event.Type {
|
||||
case "tasks":
|
||||
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
|
||||
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(event.Tasks.Tasks)
|
||||
}
|
||||
d.sync.OnSyncSuccess = func() {
|
||||
d.State.LastHeartbeat = time.Now()
|
||||
if d.GetActiveCount != nil {
|
||||
d.State.ActiveTasks = d.GetActiveCount()
|
||||
}
|
||||
if event.Tasks != nil && d.OnStreamRequested != nil {
|
||||
for _, sr := range event.Tasks.StreamRequests {
|
||||
d.OnStreamRequested(sr)
|
||||
}
|
||||
}
|
||||
|
||||
case "upgrade":
|
||||
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
|
||||
d.lastNotifiedVersion = event.Upgrade.Version
|
||||
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
|
||||
}
|
||||
|
||||
case "control":
|
||||
if event.Control != nil {
|
||||
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
|
||||
if event.Control.Action == "scan" {
|
||||
select {
|
||||
case d.ScanNow <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if d.OnControlAction != nil {
|
||||
d.OnControlAction(event.Control.Action, event.Control.TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
case "disconnected":
|
||||
log.Println("WebSocket disconnected, switching to HTTP polling")
|
||||
WriteState(&d.State)
|
||||
}
|
||||
|
||||
// Start sync loop (blocks)
|
||||
return d.sync.Run(ctx)
|
||||
}
|
||||
|
||||
// UpdateStreamPort updates the stream port reported in heartbeats.
|
||||
// Called after the persistent stream server binds (actual port may differ from configured).
|
||||
func (d *Daemon) UpdateStreamPort(port int) {
|
||||
d.cfg.StreamPort = port
|
||||
// TriggerSync requests an immediate sync cycle.
|
||||
func (d *Daemon) TriggerSync() {
|
||||
d.sync.TriggerSync()
|
||||
}
|
||||
|
||||
// TriggerPoll requests an immediate task poll cycle.
|
||||
// Used when a resume event is received to pick up re-pending tasks faster.
|
||||
func (d *Daemon) TriggerPoll() {
|
||||
select {
|
||||
case d.pollNow <- struct{}{}:
|
||||
default: // already pending
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) deregister() {
|
||||
// Deregister notifies the server of graceful shutdown.
|
||||
func (d *Daemon) Deregister() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := d.transport.Deregister(ctx, d.cfg.AgentID)
|
||||
if err != nil {
|
||||
if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil {
|
||||
log.Printf("Deregister failed: %v", err)
|
||||
} else {
|
||||
log.Println("Agent deregistered")
|
||||
|
|
@ -338,12 +217,10 @@ func isTransientError(err error) bool {
|
|||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Structured check: HTTPError carries the status code directly
|
||||
var httpErr *HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
|
||||
}
|
||||
// Fallback: network-level errors (no HTTP response received)
|
||||
lower := strings.ToLower(err.Error())
|
||||
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
|
||||
if strings.Contains(lower, keyword) {
|
||||
|
|
@ -352,27 +229,3 @@ func isTransientError(err error) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Daemon) poll(ctx context.Context) {
|
||||
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
|
||||
if err != nil {
|
||||
log.Printf("Poll failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
d.Info.LastPollAt = time.Now()
|
||||
|
||||
if len(resp.Tasks) > 0 {
|
||||
log.Printf("Claimed %d task(s)", len(resp.Tasks))
|
||||
if d.OnTasksClaimed != nil {
|
||||
d.OnTasksClaimed(resp.Tasks)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle stream requests for completed downloads
|
||||
if d.OnStreamRequested != nil {
|
||||
for _, sr := range resp.StreamRequests {
|
||||
d.OnStreamRequested(sr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
195
internal/agent/sync.go
Normal file
195
internal/agent/sync.go
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// SyncIntervalWatching is the sync interval when someone is viewing the web UI.
|
||||
SyncIntervalWatching = 3 * time.Second
|
||||
// SyncIntervalIdle is the sync interval when nobody is watching.
|
||||
SyncIntervalIdle = 60 * time.Second
|
||||
)
|
||||
|
||||
// SyncClient handles bidirectional state synchronization between the CLI and server.
|
||||
// It sends the CLI's full execution state and receives all pending server actions
|
||||
// in a single HTTP round-trip, at an adaptive interval.
|
||||
type SyncClient struct {
|
||||
client *Client
|
||||
cfg DaemonConfig
|
||||
state *LocalState
|
||||
|
||||
// Callbacks — set by the daemon before calling Run.
|
||||
OnNewTasks func(tasks []Task)
|
||||
OnControl func(action, taskID string, deleteFiles bool)
|
||||
OnStreamRequest func(req StreamRequest)
|
||||
OnUpgrade func(version string)
|
||||
OnScan func()
|
||||
OnWatchingChange func(watching bool)
|
||||
OnSyncSuccess func() // called after each successful sync (e.g. to update state file)
|
||||
GetFreeSlots func() int
|
||||
GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks
|
||||
|
||||
// SyncNow triggers an immediate sync (e.g., on task completion).
|
||||
SyncNow chan struct{}
|
||||
|
||||
watching atomic.Bool
|
||||
interval atomic.Int64 // stored as nanoseconds
|
||||
}
|
||||
|
||||
// NewSyncClient creates a sync client.
|
||||
func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient {
|
||||
sc := &SyncClient{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
state: state,
|
||||
SyncNow: make(chan struct{}, 1),
|
||||
}
|
||||
sc.interval.Store(int64(SyncIntervalIdle))
|
||||
return sc
|
||||
}
|
||||
|
||||
// Watching returns whether someone is viewing the web UI.
|
||||
func (sc *SyncClient) Watching() bool {
|
||||
return sc.watching.Load()
|
||||
}
|
||||
|
||||
// TriggerSync requests an immediate sync cycle.
|
||||
func (sc *SyncClient) TriggerSync() {
|
||||
select {
|
||||
case sc.SyncNow <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the adaptive sync loop. Blocks until ctx is cancelled.
|
||||
func (sc *SyncClient) Run(ctx context.Context) error {
|
||||
// Initial sync immediately
|
||||
sc.doSync(ctx)
|
||||
|
||||
ticker := time.NewTicker(sc.currentInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Final sync to report latest state
|
||||
finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
sc.doSync(finalCtx)
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
sc.doSync(ctx)
|
||||
ticker.Reset(sc.currentInterval())
|
||||
|
||||
case <-sc.SyncNow:
|
||||
sc.doSync(ctx)
|
||||
ticker.Reset(sc.currentInterval())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) currentInterval() time.Duration {
|
||||
return time.Duration(sc.interval.Load())
|
||||
}
|
||||
|
||||
func (sc *SyncClient) doSync(ctx context.Context) {
|
||||
req := sc.buildRequest()
|
||||
resp, err := sc.client.Sync(ctx, req)
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
log.Printf("sync failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
sc.processResponse(resp)
|
||||
sc.adjustInterval(resp.Watching)
|
||||
if sc.OnSyncSuccess != nil {
|
||||
sc.OnSyncSuccess()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) buildRequest() SyncRequest {
|
||||
req := SyncRequest{
|
||||
AgentID: sc.cfg.AgentID,
|
||||
Name: sc.cfg.AgentName,
|
||||
Version: sc.cfg.Version,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
DownloadDir: sc.cfg.DownloadDir,
|
||||
StreamPort: sc.cfg.StreamPort,
|
||||
LanIP: sc.cfg.LanIP,
|
||||
TailscaleIP: sc.cfg.TailscaleIP,
|
||||
}
|
||||
if sc.GetTaskStates != nil {
|
||||
req.Tasks = sc.GetTaskStates()
|
||||
} else {
|
||||
req.Tasks = sc.state.Snapshot()
|
||||
}
|
||||
if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil {
|
||||
req.DiskFreeBytes = free
|
||||
req.DiskTotalBytes = total
|
||||
}
|
||||
if sc.GetFreeSlots != nil {
|
||||
req.FreeSlots = sc.GetFreeSlots()
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func (sc *SyncClient) processResponse(resp *SyncResponse) {
|
||||
// New tasks
|
||||
if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil {
|
||||
log.Printf("sync: received %d new task(s)", len(resp.NewTasks))
|
||||
sc.OnNewTasks(resp.NewTasks)
|
||||
}
|
||||
|
||||
// Control signals
|
||||
for _, ctrl := range resp.Controls {
|
||||
log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID))
|
||||
if sc.OnControl != nil {
|
||||
sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream requests
|
||||
for _, sr := range resp.StreamRequests {
|
||||
if sc.OnStreamRequest != nil {
|
||||
sc.OnStreamRequest(sr)
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade
|
||||
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
|
||||
sc.OnUpgrade(resp.Upgrade.Version)
|
||||
}
|
||||
|
||||
// Scan
|
||||
if resp.Scan && sc.OnScan != nil {
|
||||
sc.OnScan()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SyncClient) adjustInterval(watching bool) {
|
||||
prev := sc.watching.Load()
|
||||
sc.watching.Store(watching)
|
||||
|
||||
var newInterval time.Duration
|
||||
if watching {
|
||||
newInterval = SyncIntervalWatching
|
||||
} else {
|
||||
newInterval = SyncIntervalIdle
|
||||
}
|
||||
|
||||
if sc.interval.Swap(int64(newInterval)) != int64(newInterval) {
|
||||
log.Printf("sync: interval=%s (watching=%v)", newInterval, watching)
|
||||
}
|
||||
|
||||
if prev != watching && sc.OnWatchingChange != nil {
|
||||
sc.OnWatchingChange(watching)
|
||||
}
|
||||
}
|
||||
362
internal/agent/sync_test.go
Normal file
362
internal/agent/sync_test.go
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestSyncClient(url string) (*SyncClient, *Client) {
|
||||
client := NewClient(url, "test-key", "test-agent/1.0")
|
||||
cfg := DaemonConfig{
|
||||
AgentID: "test-agent",
|
||||
AgentName: "Test",
|
||||
Version: "1.0.0",
|
||||
DownloadDir: "/tmp/downloads",
|
||||
}
|
||||
state := NewLocalState()
|
||||
sc := NewSyncClient(client, cfg, state)
|
||||
return sc, client
|
||||
}
|
||||
|
||||
func TestSyncClient_NewDefaults(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
if sc.Watching() {
|
||||
t.Error("should not be watching initially")
|
||||
}
|
||||
if sc.currentInterval() != SyncIntervalIdle {
|
||||
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_Watching(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
sc.adjustInterval(true)
|
||||
|
||||
if sc.currentInterval() != SyncIntervalWatching {
|
||||
t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval())
|
||||
}
|
||||
if !sc.Watching() {
|
||||
t.Error("expected watching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
// First set watching, then unset
|
||||
sc.adjustInterval(true)
|
||||
sc.adjustInterval(false)
|
||||
|
||||
if sc.currentInterval() != SyncIntervalIdle {
|
||||
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
|
||||
}
|
||||
if sc.Watching() {
|
||||
t.Error("expected watching=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var changes []bool
|
||||
sc.OnWatchingChange = func(w bool) { changes = append(changes, w) }
|
||||
|
||||
sc.adjustInterval(true)
|
||||
sc.adjustInterval(true) // no change
|
||||
sc.adjustInterval(false) // change
|
||||
|
||||
if len(changes) != 2 {
|
||||
t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes)
|
||||
}
|
||||
if !changes[0] {
|
||||
t.Error("first change should be true")
|
||||
}
|
||||
if changes[1] {
|
||||
t.Error("second change should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
// Fill the channel
|
||||
sc.TriggerSync()
|
||||
// Should not block
|
||||
sc.TriggerSync()
|
||||
sc.TriggerSync()
|
||||
|
||||
// Drain
|
||||
select {
|
||||
case <-sc.SyncNow:
|
||||
default:
|
||||
t.Error("expected a sync trigger in channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var received []Task
|
||||
sc.OnNewTasks = func(tasks []Task) { received = tasks }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
NewTasks: []Task{
|
||||
{ID: "t1", Title: "Movie 1", InfoHash: "abc"},
|
||||
{ID: "t2", Title: "Movie 2", InfoHash: "def"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(received) != 2 {
|
||||
t.Fatalf("expected 2 tasks, got %d", len(received))
|
||||
}
|
||||
if received[0].Title != "Movie 1" {
|
||||
t.Errorf("expected Movie 1, got %s", received[0].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnNewTasks = func(tasks []Task) { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{NewTasks: nil})
|
||||
|
||||
if called {
|
||||
t.Error("OnNewTasks should not be called with empty tasks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Controls(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var actions []string
|
||||
var taskIDs []string
|
||||
sc.OnControl = func(action, taskID string, deleteFiles bool) {
|
||||
actions = append(actions, action)
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Controls: []ControlAction{
|
||||
{Action: "cancel", TaskID: "task-1234-5678"},
|
||||
{Action: "pause", TaskID: "task-abcd-efgh"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(actions) != 2 {
|
||||
t.Fatalf("expected 2 controls, got %d", len(actions))
|
||||
}
|
||||
if actions[0] != "cancel" {
|
||||
t.Errorf("expected cancel, got %s", actions[0])
|
||||
}
|
||||
if actions[1] != "pause" {
|
||||
t.Errorf("expected pause, got %s", actions[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var version string
|
||||
sc.OnUpgrade = func(v string) { version = v }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Upgrade: &UpgradeSignal{Version: "2.0.0"},
|
||||
})
|
||||
|
||||
if version != "2.0.0" {
|
||||
t.Errorf("expected 2.0.0, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnUpgrade = func(v string) { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
Upgrade: &UpgradeSignal{Version: ""},
|
||||
})
|
||||
|
||||
if called {
|
||||
t.Error("OnUpgrade should not be called with empty version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_Scan(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var called bool
|
||||
sc.OnScan = func() { called = true }
|
||||
|
||||
sc.processResponse(&SyncResponse{Scan: true})
|
||||
|
||||
if !called {
|
||||
t.Error("OnScan should have been called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
var received []StreamRequest
|
||||
sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) }
|
||||
|
||||
sc.processResponse(&SyncResponse{
|
||||
StreamRequests: []StreamRequest{
|
||||
{TaskID: "t1", FilePath: "/tmp/movie.mkv"},
|
||||
},
|
||||
})
|
||||
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 stream request, got %d", len(received))
|
||||
}
|
||||
if received[0].FilePath != "/tmp/movie.mkv" {
|
||||
t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) {
|
||||
sc, _ := newTestSyncClient("http://localhost")
|
||||
|
||||
sc.GetTaskStates = func() []TaskState {
|
||||
return []TaskState{
|
||||
{TaskID: "t1", Status: "downloading", Progress: 50},
|
||||
}
|
||||
}
|
||||
sc.GetFreeSlots = func() int { return 2 }
|
||||
|
||||
req := sc.buildRequest()
|
||||
|
||||
if req.AgentID != "test-agent" {
|
||||
t.Errorf("expected test-agent, got %s", req.AgentID)
|
||||
}
|
||||
if len(req.Tasks) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(req.Tasks))
|
||||
}
|
||||
if req.Tasks[0].Progress != 50 {
|
||||
t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress)
|
||||
}
|
||||
if req.FreeSlots != 2 {
|
||||
t.Errorf("expected 2 free slots, got %d", req.FreeSlots)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) {
|
||||
client := NewClient("http://localhost", "key", "ua")
|
||||
state := NewLocalState()
|
||||
state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100})
|
||||
|
||||
sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state)
|
||||
// GetTaskStates is nil — should fall back to state.Snapshot()
|
||||
|
||||
req := sc.buildRequest()
|
||||
if len(req.Tasks) != 1 {
|
||||
t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_DoSync_Success(t *testing.T) {
|
||||
var syncCount atomic.Int32
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
syncCount.Add(1)
|
||||
json.NewEncoder(w).Encode(SyncResponse{
|
||||
Watching: true,
|
||||
NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
var tasksReceived []Task
|
||||
sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks }
|
||||
|
||||
sc.doSync(context.Background())
|
||||
|
||||
if syncCount.Load() != 1 {
|
||||
t.Errorf("expected 1 sync call, got %d", syncCount.Load())
|
||||
}
|
||||
if len(tasksReceived) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(tasksReceived))
|
||||
}
|
||||
if !sc.Watching() {
|
||||
t.Error("expected watching=true after sync")
|
||||
}
|
||||
if sc.currentInterval() != SyncIntervalWatching {
|
||||
t.Errorf("expected watching interval after sync")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_DoSync_Error(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
// Should not panic on error
|
||||
sc.doSync(context.Background())
|
||||
}
|
||||
|
||||
func TestSyncClient_Run_CancelStopsLoop(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(SyncResponse{})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := sc.Run(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) {
|
||||
var syncCount atomic.Int32
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
syncCount.Add(1)
|
||||
json.NewEncoder(w).Encode(SyncResponse{})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc, _ := newTestSyncClient(srv.URL)
|
||||
// Set interval to something long so only triggers cause syncs
|
||||
sc.interval.Store(int64(10 * time.Second))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
// Wait for initial sync, then trigger 2 more
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sc.TriggerSync()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sc.TriggerSync()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
sc.Run(ctx)
|
||||
|
||||
// Initial sync (1) + 2 triggers + final sync = 4
|
||||
count := syncCount.Load()
|
||||
if count < 3 {
|
||||
t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count)
|
||||
}
|
||||
}
|
||||
136
internal/agent/taskstate.go
Normal file
136
internal/agent/taskstate.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/torrentclaw/unarr/internal/config"
|
||||
)
|
||||
|
||||
// TaskState represents the execution state of a single download task.
|
||||
// Written by the Task Engine, read by the Sync goroutine.
|
||||
type TaskState struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed
|
||||
Progress int `json:"progress"`
|
||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||
ETA int `json:"eta,omitempty"`
|
||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
StreamURL string `json:"streamUrl,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// LocalState holds the CLI's local execution state (tasks.json).
|
||||
// This is the CLI's source of truth for what it's doing right now.
|
||||
type LocalState struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskState
|
||||
}
|
||||
|
||||
// NewLocalState creates an empty local state.
|
||||
func NewLocalState() *LocalState {
|
||||
return &LocalState{
|
||||
tasks: make(map[string]*TaskState),
|
||||
}
|
||||
}
|
||||
|
||||
// Update adds or updates a task in local state.
|
||||
func (s *LocalState) Update(ts TaskState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
ts.UpdatedAt = time.Now().Unix()
|
||||
copied := ts
|
||||
s.tasks[ts.TaskID] = &copied
|
||||
}
|
||||
|
||||
// Remove removes a task from local state.
|
||||
func (s *LocalState) Remove(taskID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tasks, taskID)
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of all current task states.
|
||||
func (s *LocalState) Snapshot() []TaskState {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]TaskState, 0, len(s.tasks))
|
||||
for _, ts := range s.tasks {
|
||||
result = append(result, *ts)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TaskStateFromUpdate converts a StatusUpdate into a TaskState.
|
||||
func TaskStateFromUpdate(u StatusUpdate) TaskState {
|
||||
return TaskState{
|
||||
TaskID: u.TaskID,
|
||||
Status: u.Status,
|
||||
Progress: u.Progress,
|
||||
DownloadedBytes: u.DownloadedBytes,
|
||||
TotalBytes: u.TotalBytes,
|
||||
SpeedBps: u.SpeedBps,
|
||||
ETA: u.ETA,
|
||||
ResolvedMethod: u.ResolvedMethod,
|
||||
FileName: u.FileName,
|
||||
FilePath: u.FilePath,
|
||||
StreamURL: u.StreamURL,
|
||||
ErrorMessage: u.ErrorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
// ShortID returns the first 8 characters of an ID, or the full ID if shorter.
|
||||
func ShortID(id string) string {
|
||||
if len(id) > 8 {
|
||||
return id[:8]
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// taskStateFilePathFn is overridable for testing.
|
||||
var taskStateFilePathFn = func() string {
|
||||
return filepath.Join(config.DataDir(), "tasks.json")
|
||||
}
|
||||
|
||||
// WriteToDisk persists local state to disk atomically (best-effort).
|
||||
func (s *LocalState) WriteToDisk() {
|
||||
tasks := s.Snapshot()
|
||||
data, err := json.MarshalIndent(tasks, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := taskStateFilePathFn()
|
||||
dir := filepath.Dir(path)
|
||||
os.MkdirAll(dir, 0o755)
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
os.Rename(tmp, path)
|
||||
}
|
||||
|
||||
// ReadFromDisk loads local state from disk. Returns empty state on error.
|
||||
func (s *LocalState) ReadFromDisk() {
|
||||
data, err := os.ReadFile(taskStateFilePathFn())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var tasks []TaskState
|
||||
if json.Unmarshal(data, &tasks) != nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tasks = make(map[string]*TaskState, len(tasks))
|
||||
for i := range tasks {
|
||||
s.tasks[tasks[i].TaskID] = &tasks[i]
|
||||
}
|
||||
}
|
||||
217
internal/agent/taskstate_test.go
Normal file
217
internal/agent/taskstate_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLocalState_UpdateAndSnapshot(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 2 {
|
||||
t.Fatalf("expected 2 tasks, got %d", len(snap))
|
||||
}
|
||||
|
||||
byID := make(map[string]TaskState, len(snap))
|
||||
for _, ts := range snap {
|
||||
byID[ts.TaskID] = ts
|
||||
}
|
||||
|
||||
if byID["t1"].Progress != 50 {
|
||||
t.Errorf("expected progress 50, got %d", byID["t1"].Progress)
|
||||
}
|
||||
if byID["t2"].Status != "completed" {
|
||||
t.Errorf("expected completed, got %s", byID["t2"].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_UpdateOverwrites(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30})
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||
}
|
||||
if snap[0].Progress != 70 {
|
||||
t.Errorf("expected progress 70, got %d", snap[0].Progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_Remove(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "downloading"})
|
||||
s.Remove("t1")
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 1 {
|
||||
t.Fatalf("expected 1 task, got %d", len(snap))
|
||||
}
|
||||
if snap[0].TaskID != "t2" {
|
||||
t.Errorf("expected t2, got %s", snap[0].TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_RemoveNonExistent(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Remove("nonexistent") // should not panic
|
||||
}
|
||||
|
||||
func TestLocalState_SnapshotIsACopy(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
|
||||
|
||||
snap := s.Snapshot()
|
||||
snap[0].Progress = 999
|
||||
|
||||
snap2 := s.Snapshot()
|
||||
if snap2[0].Progress != 50 {
|
||||
t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_UpdateSetsTimestamp(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
|
||||
snap := s.Snapshot()
|
||||
if snap[0].UpdatedAt == 0 {
|
||||
t.Error("expected non-zero UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ConcurrentAccess(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
taskID := "t" + string(rune('0'+n%10))
|
||||
s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n})
|
||||
s.Snapshot()
|
||||
if n%3 == 0 {
|
||||
s.Remove(taskID)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// No race condition = test passes
|
||||
}
|
||||
|
||||
func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
// Override the file path for testing
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45})
|
||||
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"})
|
||||
s.WriteToDisk()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Fatal("tasks.json was not created")
|
||||
}
|
||||
|
||||
// Read into a new LocalState
|
||||
s2 := NewLocalState()
|
||||
s2.ReadFromDisk()
|
||||
|
||||
snap := s2.Snapshot()
|
||||
if len(snap) != 2 {
|
||||
t.Fatalf("expected 2 tasks after read, got %d", len(snap))
|
||||
}
|
||||
|
||||
byID := make(map[string]TaskState, len(snap))
|
||||
for _, ts := range snap {
|
||||
byID[ts.TaskID] = ts
|
||||
}
|
||||
|
||||
if byID["t1"].Progress != 45 {
|
||||
t.Errorf("expected progress 45, got %d", byID["t1"].Progress)
|
||||
}
|
||||
if byID["t2"].FilePath != "/tmp/movie.mkv" {
|
||||
t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
// Write corrupted JSON
|
||||
os.WriteFile(path, []byte("{invalid json"), 0o644)
|
||||
|
||||
s := NewLocalState()
|
||||
s.ReadFromDisk() // should not panic
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) {
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.ReadFromDisk() // should not panic
|
||||
|
||||
snap := s.Snapshot()
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_AtomicWrite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tasks.json")
|
||||
|
||||
orig := taskStateFilePathFn
|
||||
taskStateFilePathFn = func() string { return path }
|
||||
defer func() { taskStateFilePathFn = orig }()
|
||||
|
||||
s := NewLocalState()
|
||||
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
|
||||
s.WriteToDisk()
|
||||
|
||||
// Verify no .tmp file remains
|
||||
tmpPath := path + ".tmp"
|
||||
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
|
||||
t.Error("temp file should not exist after write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalState_EmptySnapshot(t *testing.T) {
|
||||
s := NewLocalState()
|
||||
snap := s.Snapshot()
|
||||
if snap == nil {
|
||||
t.Error("snapshot should be non-nil empty slice")
|
||||
}
|
||||
if len(snap) != 0 {
|
||||
t.Errorf("expected 0 tasks, got %d", len(snap))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// Transport abstracts the communication protocol between the agent and server.
|
||||
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
|
||||
type Transport interface {
|
||||
// Connect establishes the transport connection.
|
||||
// Called internally by Daemon.Run — callers must NOT call Connect separately.
|
||||
Connect(ctx context.Context) error
|
||||
|
||||
// Close tears down the connection gracefully.
|
||||
Close() error
|
||||
|
||||
// Mode returns the current transport mode ("ws" or "http").
|
||||
Mode() string
|
||||
|
||||
// Register sends agent registration and returns user info + features.
|
||||
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
|
||||
|
||||
// SendHeartbeat sends a periodic keep-alive.
|
||||
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
|
||||
|
||||
// SendProgress reports download progress for a task.
|
||||
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
|
||||
|
||||
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
|
||||
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
|
||||
|
||||
// Deregister notifies the server of graceful shutdown.
|
||||
Deregister(ctx context.Context, agentID string) error
|
||||
|
||||
// Events returns a channel that emits server-initiated events.
|
||||
// In HTTP mode this channel is never written to (polling handles it).
|
||||
// In WS mode, tasks/upgrade/control arrive here.
|
||||
Events() <-chan ServerEvent
|
||||
}
|
||||
|
||||
// ServerEvent represents a server-initiated message received via WebSocket.
|
||||
type ServerEvent struct {
|
||||
Type string // "tasks", "upgrade", "control", "disconnected"
|
||||
Tasks *TasksResponse // populated when Type == "tasks"
|
||||
Upgrade *UpgradeSignal // populated when Type == "upgrade"
|
||||
Control *ControlAction // populated when Type == "control"
|
||||
}
|
||||
|
||||
// ControlAction represents a server push for task control.
|
||||
type ControlAction struct {
|
||||
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||
TaskID string `json:"taskId"`
|
||||
}
|
||||
|
|
@ -1,285 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestE2EFullLifecycle tests the full lifecycle:
|
||||
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
|
||||
func TestE2EFullLifecycle(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedMessages []map[string]interface{}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var parsed map[string]interface{}
|
||||
json.Unmarshal(msg, &parsed)
|
||||
|
||||
mu.Lock()
|
||||
receivedMessages = append(receivedMessages, parsed)
|
||||
mu.Unlock()
|
||||
|
||||
msgType, _ := parsed["type"].(string)
|
||||
switch msgType {
|
||||
case "auth":
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
|
||||
Features: FeatureFlags{Torrent: true, Debrid: true},
|
||||
})
|
||||
|
||||
case "heartbeat":
|
||||
// No response in WS mode
|
||||
|
||||
case "progress":
|
||||
// Simulate server-side cancel after progress
|
||||
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
|
||||
conn.WriteJSON(map[string]string{
|
||||
"type": "control",
|
||||
"action": "cancel",
|
||||
"taskId": parsed["taskId"].(string),
|
||||
})
|
||||
}
|
||||
|
||||
case "upgrade-result":
|
||||
// Acknowledged
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Connect
|
||||
if err := tr.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
// 2. Auth
|
||||
resp, err := tr.Register(ctx, RegisterRequest{
|
||||
AgentID: "e2e-agent",
|
||||
Name: "E2E Test Agent",
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
if resp.User.Name != "E2E User" {
|
||||
t.Errorf("expected E2E User, got %s", resp.User.Name)
|
||||
}
|
||||
if !resp.Features.Debrid {
|
||||
t.Error("expected debrid feature")
|
||||
}
|
||||
|
||||
// 3. Send heartbeat
|
||||
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
|
||||
AgentID: "e2e-agent",
|
||||
DiskFreeBytes: 1000000000,
|
||||
DiskTotalBytes: 5000000000,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat: %v", err)
|
||||
}
|
||||
|
||||
// 4. Send progress (50% → should trigger cancel control)
|
||||
_, err = tr.SendProgress(ctx, StatusUpdate{
|
||||
TaskID: "task-e2e-1",
|
||||
Status: "downloading",
|
||||
Progress: 50,
|
||||
DownloadedBytes: 500,
|
||||
TotalBytes: 1000,
|
||||
SpeedBps: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendProgress: %v", err)
|
||||
}
|
||||
|
||||
// 5. Wait for control event (cancel)
|
||||
select {
|
||||
case event := <-tr.Events():
|
||||
if event.Type != "control" {
|
||||
t.Errorf("expected control event, got %s", event.Type)
|
||||
}
|
||||
if event.Control.Action != "cancel" {
|
||||
t.Errorf("expected cancel, got %s", event.Control.Action)
|
||||
}
|
||||
if event.Control.TaskID != "task-e2e-1" {
|
||||
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for cancel control")
|
||||
}
|
||||
|
||||
// Verify server received all messages
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if len(receivedMessages) < 3 {
|
||||
t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages))
|
||||
}
|
||||
|
||||
types := make([]string, len(receivedMessages))
|
||||
for i, m := range receivedMessages {
|
||||
types[i], _ = m["type"].(string)
|
||||
}
|
||||
|
||||
expected := []string{"auth", "heartbeat", "progress"}
|
||||
for _, exp := range expected {
|
||||
found := false
|
||||
for _, got := range types {
|
||||
if got == exp {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("missing message type %q in %v", exp, types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2EHybridFailover tests the full failover scenario:
|
||||
// WS connect → download → WS disconnect → switch to HTTP → continue working
|
||||
func TestE2EHybridFailover(t *testing.T) {
|
||||
connectionCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
connectionCount++
|
||||
connNum := connectionCount
|
||||
mu.Unlock()
|
||||
|
||||
// Read auth
|
||||
conn.ReadMessage()
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "Failover User"},
|
||||
})
|
||||
|
||||
if connNum == 1 {
|
||||
// First connection: push tasks then disconnect after 200ms
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsTasksMessage{
|
||||
Type: "tasks",
|
||||
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
|
||||
})
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
conn.Close()
|
||||
} else {
|
||||
// Second connection (after reconnect): push upgrade
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
conn.Close()
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||
|
||||
// HTTP mock for fallback
|
||||
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple heartbeat response
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer httpSrv.Close()
|
||||
|
||||
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
|
||||
h := NewHybridTransport(wsT, httpT)
|
||||
|
||||
ctx := context.Background()
|
||||
err := h.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer h.Close()
|
||||
|
||||
// Should start in WS mode
|
||||
if h.Mode() != "ws" {
|
||||
t.Fatalf("expected ws mode, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Register via WS
|
||||
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
// Receive tasks via WS
|
||||
var tasksReceived bool
|
||||
var disconnected bool
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case event := <-h.Events():
|
||||
switch event.Type {
|
||||
case "tasks":
|
||||
tasksReceived = true
|
||||
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
|
||||
t.Errorf("unexpected tasks: %+v", event.Tasks)
|
||||
}
|
||||
case "disconnected":
|
||||
disconnected = true
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
break
|
||||
}
|
||||
if disconnected {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !tasksReceived {
|
||||
t.Error("did not receive tasks before disconnect")
|
||||
}
|
||||
if !disconnected {
|
||||
t.Error("did not receive disconnect event")
|
||||
}
|
||||
|
||||
// Should now be in HTTP mode
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if h.Mode() != "http" {
|
||||
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Heartbeat should work via HTTP fallback
|
||||
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
|
||||
}
|
||||
if !hbResp.Success {
|
||||
t.Error("expected heartbeat success")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// HTTPTransport wraps the existing Client to implement Transport.
|
||||
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
|
||||
type HTTPTransport struct {
|
||||
client *Client
|
||||
events chan ServerEvent
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP-based transport.
|
||||
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
|
||||
return &HTTPTransport{
|
||||
client: NewClient(baseURL, apiKey, userAgent),
|
||||
events: make(chan ServerEvent, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
|
||||
func (t *HTTPTransport) Close() error { return nil }
|
||||
func (t *HTTPTransport) Mode() string { return "http" }
|
||||
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
|
||||
|
||||
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
return t.client.Register(ctx, req)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
return t.client.Heartbeat(ctx, req)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
return t.client.ReportStatus(ctx, update)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) {
|
||||
return t.client.BatchReportStatus(ctx, updates)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
return t.client.ClaimTasks(ctx, agentID)
|
||||
}
|
||||
|
||||
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
|
||||
return t.client.Deregister(ctx, agentID)
|
||||
}
|
||||
|
||||
// Client returns the underlying HTTP client for direct use if needed.
|
||||
func (t *HTTPTransport) Client() *Client { return t.client }
|
||||
|
|
@ -1,214 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
|
||||
// Automatically reconnects WS in the background.
|
||||
type HybridTransport struct {
|
||||
ws *WSTransport
|
||||
http *HTTPTransport
|
||||
|
||||
mode atomic.Value // "ws" or "http"
|
||||
events chan ServerEvent
|
||||
|
||||
reconnectMu sync.Mutex
|
||||
reconnectRunning bool
|
||||
reconnectStop chan struct{}
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
|
||||
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
|
||||
h := &HybridTransport{
|
||||
ws: ws,
|
||||
http: http,
|
||||
events: make(chan ServerEvent, 50),
|
||||
reconnectStop: make(chan struct{}),
|
||||
}
|
||||
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
|
||||
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
|
||||
|
||||
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
|
||||
func (h *HybridTransport) Connect(ctx context.Context) error {
|
||||
// Try WebSocket first
|
||||
if err := h.ws.Connect(ctx); err != nil {
|
||||
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
|
||||
h.mode.Store("http")
|
||||
h.startReconnectLoop()
|
||||
return h.http.Connect(ctx)
|
||||
}
|
||||
|
||||
h.mode.Store("ws")
|
||||
log.Println("[transport] Connected via WebSocket")
|
||||
|
||||
// Forward WS events to unified channel + watch for disconnection
|
||||
go h.forwardWSEvents()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down both transports and stops reconnection.
|
||||
func (h *HybridTransport) Close() error {
|
||||
h.closed.Store(true)
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
default:
|
||||
close(h.reconnectStop)
|
||||
}
|
||||
_ = h.ws.Close()
|
||||
return h.http.Close()
|
||||
}
|
||||
|
||||
// Register delegates to the active transport.
|
||||
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.Register(ctx, req)
|
||||
}
|
||||
return h.http.Register(ctx, req)
|
||||
}
|
||||
|
||||
// SendHeartbeat delegates to the active transport.
|
||||
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
resp, err := h.ws.SendHeartbeat(ctx, req)
|
||||
if err != nil {
|
||||
// WS write failed — switch to HTTP
|
||||
h.switchToHTTP()
|
||||
return h.http.SendHeartbeat(ctx, req)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return h.http.SendHeartbeat(ctx, req)
|
||||
}
|
||||
|
||||
// SendProgress delegates to the active transport.
|
||||
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
resp, err := h.ws.SendProgress(ctx, update)
|
||||
if err != nil {
|
||||
h.switchToHTTP()
|
||||
return h.http.SendProgress(ctx, update)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return h.http.SendProgress(ctx, update)
|
||||
}
|
||||
|
||||
// ClaimTasks delegates to the active transport.
|
||||
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
|
||||
}
|
||||
return h.http.ClaimTasks(ctx, agentID)
|
||||
}
|
||||
|
||||
// Deregister delegates to the active transport.
|
||||
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
|
||||
if h.mode.Load() == "ws" {
|
||||
return h.ws.Deregister(ctx, agentID)
|
||||
}
|
||||
return h.http.Deregister(ctx, agentID)
|
||||
}
|
||||
|
||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *HybridTransport) switchToHTTP() {
|
||||
if h.mode.Load() == "http" {
|
||||
return
|
||||
}
|
||||
log.Println("[transport] Switching to HTTP fallback")
|
||||
h.mode.Store("http")
|
||||
_ = h.ws.Close()
|
||||
h.startReconnectLoop()
|
||||
}
|
||||
|
||||
func (h *HybridTransport) forwardWSEvents() {
|
||||
for {
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
return
|
||||
case event, ok := <-h.ws.Events():
|
||||
if !ok {
|
||||
return // channel closed
|
||||
}
|
||||
if event.Type == "disconnected" {
|
||||
h.switchToHTTP()
|
||||
select {
|
||||
case h.events <- event:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case h.events <- event:
|
||||
default:
|
||||
log.Printf("[transport] events channel full, dropping %s event", event.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HybridTransport) startReconnectLoop() {
|
||||
h.reconnectMu.Lock()
|
||||
defer h.reconnectMu.Unlock()
|
||||
if h.reconnectRunning {
|
||||
return
|
||||
}
|
||||
h.reconnectRunning = true
|
||||
go h.reconnectLoop()
|
||||
}
|
||||
|
||||
func (h *HybridTransport) reconnectLoop() {
|
||||
backoff := 5 * time.Second
|
||||
maxBackoff := 60 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.reconnectStop:
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
|
||||
if h.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
// Already on WS? (someone else reconnected)
|
||||
if h.mode.Load() == "ws" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
err := h.ws.Connect(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// WS reconnected — switch back
|
||||
log.Println("[transport] WebSocket reconnected")
|
||||
h.mode.Store("ws")
|
||||
|
||||
// Reset reconnect flag so loop can start again if WS drops
|
||||
h.reconnectMu.Lock()
|
||||
h.reconnectRunning = false
|
||||
h.reconnectMu.Unlock()
|
||||
|
||||
// Forward events from new WS connection
|
||||
go h.forwardWSEvents()
|
||||
return
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,395 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
|
||||
type WSTransport struct {
|
||||
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
|
||||
apiKey string
|
||||
agentID string
|
||||
userAgent string
|
||||
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
events chan ServerEvent
|
||||
closed atomic.Bool
|
||||
|
||||
// Cached auth response from the DO
|
||||
authResp *RegisterResponse
|
||||
authMu sync.Mutex
|
||||
authDone chan struct{}
|
||||
authDoneOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWSTransport creates a WebSocket-based transport.
|
||||
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
|
||||
return &WSTransport{
|
||||
wsURL: wsURL,
|
||||
apiKey: apiKey,
|
||||
agentID: agentID,
|
||||
userAgent: userAgent,
|
||||
events: make(chan ServerEvent, 50),
|
||||
authDone: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WSTransport) Mode() string { return "ws" }
|
||||
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
|
||||
|
||||
// Connect dials the WebSocket server and starts the read loop.
|
||||
func (t *WSTransport) Connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("User-Agent", t.userAgent)
|
||||
|
||||
// Append API key as query param for auth on WS upgrade
|
||||
wsURLWithKey := t.wsURL
|
||||
if t.apiKey != "" {
|
||||
sep := "?"
|
||||
if strings.Contains(wsURLWithKey, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
wsURLWithKey += sep + "key=" + t.apiKey
|
||||
}
|
||||
|
||||
conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header)
|
||||
if wsResp != nil && wsResp.Body != nil {
|
||||
defer wsResp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws dial: %w", err)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = conn
|
||||
t.closed.Store(false)
|
||||
t.authDone = make(chan struct{})
|
||||
t.authDoneOnce = sync.Once{}
|
||||
t.mu.Unlock()
|
||||
|
||||
go t.readLoop(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends a close frame and shuts down the connection.
|
||||
func (t *WSTransport) Close() error {
|
||||
t.closed.Store(true)
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.WriteMessage(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||
)
|
||||
err := t.conn.Close()
|
||||
t.conn = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register sends auth message and waits for the registered response.
|
||||
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
|
||||
msg := wsAuthMessage{
|
||||
Type: "auth",
|
||||
APIKey: t.apiKey,
|
||||
AgentID: req.AgentID,
|
||||
Name: req.Name,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
Version: req.Version,
|
||||
DownloadDir: req.DownloadDir,
|
||||
DiskFreeBytes: req.DiskFreeBytes,
|
||||
DiskTotalBytes: req.DiskTotalBytes,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, fmt.Errorf("ws auth send: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the auth response or context cancellation
|
||||
select {
|
||||
case <-t.authDone:
|
||||
t.authMu.Lock()
|
||||
resp := t.authResp
|
||||
t.authMu.Unlock()
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("ws auth: no response received")
|
||||
}
|
||||
return resp, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(15 * time.Second):
|
||||
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
|
||||
}
|
||||
}
|
||||
|
||||
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
|
||||
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
Disk *struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
} `json:"disk,omitempty"`
|
||||
}{Type: "heartbeat"}
|
||||
|
||||
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
|
||||
msg.Disk = &struct {
|
||||
Free int64 `json:"free"`
|
||||
Total int64 `json:"total"`
|
||||
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
|
||||
return &HeartbeatResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// SendProgress sends a progress update. Control signals arrive async via Events().
|
||||
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
|
||||
msg := struct {
|
||||
Type string `json:"type"`
|
||||
TaskID string `json:"taskId"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Progress int `json:"progress,omitempty"`
|
||||
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
|
||||
TotalBytes int64 `json:"totalBytes,omitempty"`
|
||||
SpeedBps int64 `json:"speedBps,omitempty"`
|
||||
ETA int `json:"eta,omitempty"`
|
||||
ResolvedMethod string `json:"resolvedMethod,omitempty"`
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
FilePath string `json:"filePath,omitempty"`
|
||||
StreamURL string `json:"streamUrl,omitempty"`
|
||||
StreamReady bool `json:"streamReady,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
}{
|
||||
Type: "progress",
|
||||
TaskID: update.TaskID,
|
||||
Status: update.Status,
|
||||
Progress: update.Progress,
|
||||
DownloadedBytes: update.DownloadedBytes,
|
||||
TotalBytes: update.TotalBytes,
|
||||
SpeedBps: update.SpeedBps,
|
||||
ETA: update.ETA,
|
||||
ResolvedMethod: update.ResolvedMethod,
|
||||
FileName: update.FileName,
|
||||
FilePath: update.FilePath,
|
||||
StreamURL: update.StreamURL,
|
||||
StreamReady: update.StreamReady,
|
||||
ErrorMessage: update.ErrorMessage,
|
||||
}
|
||||
|
||||
if err := t.send(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// In WS mode, control signals come via Events(), not in the progress response.
|
||||
return &StatusResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
|
||||
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
|
||||
return &TasksResponse{}, nil
|
||||
}
|
||||
|
||||
// Deregister is handled by WebSocket close (DO detects disconnection).
|
||||
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
|
||||
return t.Close()
|
||||
}
|
||||
|
||||
// ── Internal ─────────────────────────────────────────────────────────────────
|
||||
|
||||
func (t *WSTransport) send(msg any) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("ws: not connected")
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return t.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (t *WSTransport) readLoop(conn *websocket.Conn) {
|
||||
// Cloudflare idle timeout is 100s. We send pings every 30s and expect
|
||||
// either a pong or a server message within 45s. If neither arrives,
|
||||
// the read deadline fires and we detect the zombie connection.
|
||||
const (
|
||||
pongWait = 45 * time.Second
|
||||
pingPeriod = 30 * time.Second
|
||||
)
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Ping ticker goroutine — stops when readLoop returns.
|
||||
pingDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.mu.Lock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
err := t.conn.WriteMessage(websocket.PingMessage, nil)
|
||||
_ = t.conn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
case <-pingDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(pingDone)
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if !t.closed.Load() {
|
||||
log.Printf("[ws] read error: %v", err)
|
||||
// Signal disconnection to the daemon
|
||||
select {
|
||||
case t.events <- ServerEvent{Type: "disconnected"}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Any message (text or pong) proves the connection is alive.
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(msg, &envelope); err != nil {
|
||||
log.Printf("[ws] invalid message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case "registered":
|
||||
var resp wsRegisteredMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
t.authMu.Lock()
|
||||
t.authResp = &RegisterResponse{
|
||||
Success: true,
|
||||
User: resp.User,
|
||||
Features: resp.Features,
|
||||
}
|
||||
t.authMu.Unlock()
|
||||
// Signal that auth is complete (sync.Once prevents double-close panic)
|
||||
t.authDoneOnce.Do(func() { close(t.authDone) })
|
||||
}
|
||||
|
||||
case "tasks":
|
||||
var resp wsTasksMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "tasks",
|
||||
Tasks: &TasksResponse{
|
||||
Tasks: resp.Tasks,
|
||||
StreamRequests: resp.StreamRequests,
|
||||
},
|
||||
}:
|
||||
default:
|
||||
log.Printf("[ws] events channel full, dropping tasks message")
|
||||
}
|
||||
}
|
||||
|
||||
case "upgrade":
|
||||
var resp wsUpgradeMessage
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "upgrade",
|
||||
Upgrade: &UpgradeSignal{Version: resp.Version},
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "control":
|
||||
var resp ControlAction
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
select {
|
||||
case t.events <- ServerEvent{
|
||||
Type: "control",
|
||||
Control: &resp,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
case "error":
|
||||
var resp struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if json.Unmarshal(msg, &resp) == nil {
|
||||
log.Printf("[ws] server error: %s", resp.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── WS message types ─────────────────────────────────────────────────────────
|
||||
|
||||
type wsAuthMessage struct {
|
||||
Type string `json:"type"`
|
||||
APIKey string `json:"apiKey"`
|
||||
AgentID string `json:"agentId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
}
|
||||
|
||||
type wsRegisteredMessage struct {
|
||||
Type string `json:"type"`
|
||||
User UserInfo `json:"user"`
|
||||
Features FeatureFlags `json:"features"`
|
||||
}
|
||||
|
||||
type wsTasksMessage struct {
|
||||
Type string `json:"type"`
|
||||
Tasks []Task `json:"tasks"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
}
|
||||
|
||||
type wsUpgradeMessage struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
|
@ -50,20 +50,6 @@ type UsenetServerInfo struct {
|
|||
SSL bool `json:"ssl"`
|
||||
}
|
||||
|
||||
// HeartbeatRequest is sent every 30s to keep the agent alive.
|
||||
type HeartbeatRequest struct {
|
||||
AgentID string `json:"agentId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
StreamPort int `json:"streamPort,omitempty"`
|
||||
LanIP string `json:"lanIp,omitempty"`
|
||||
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
||||
}
|
||||
|
||||
// Task represents a download task claimed from the server.
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
|
|
@ -88,12 +74,6 @@ type Task struct {
|
|||
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
|
||||
}
|
||||
|
||||
// TasksResponse wraps the array of tasks returned by the server.
|
||||
type TasksResponse struct {
|
||||
Tasks []Task `json:"tasks"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
}
|
||||
|
||||
// StreamRequest is a request to stream a completed download from disk.
|
||||
type StreamRequest struct {
|
||||
TaskID string `json:"taskId"`
|
||||
|
|
@ -139,14 +119,6 @@ type BatchStatusResponse struct {
|
|||
Watching bool `json:"watching,omitempty"`
|
||||
}
|
||||
|
||||
// HeartbeatResponse is returned by the server on heartbeat.
|
||||
type HeartbeatResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||
Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI
|
||||
Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI
|
||||
}
|
||||
|
||||
// UpgradeSignal tells the agent to upgrade to a specific version.
|
||||
type UpgradeSignal struct {
|
||||
Version string `json:"version"`
|
||||
|
|
@ -176,7 +148,6 @@ type AgentInfo struct {
|
|||
User UserInfo
|
||||
Features FeatureFlags
|
||||
StartedAt time.Time
|
||||
LastPollAt time.Time
|
||||
ActiveTasks int
|
||||
}
|
||||
|
||||
|
|
@ -334,6 +305,45 @@ type LibrarySyncResponse struct {
|
|||
Removed int `json:"removed"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sync types (unified CLI ↔ Server communication)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SyncRequest is sent by the CLI periodically to synchronize state with the server.
|
||||
// Contains the CLI's full execution state — the server responds with pending actions.
|
||||
type SyncRequest struct {
|
||||
AgentID string `json:"agentId"`
|
||||
Version string `json:"version,omitempty"`
|
||||
OS string `json:"os,omitempty"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
DownloadDir string `json:"downloadDir,omitempty"`
|
||||
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
|
||||
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
|
||||
StreamPort int `json:"streamPort,omitempty"`
|
||||
LanIP string `json:"lanIp,omitempty"`
|
||||
TailscaleIP string `json:"tailscaleIp,omitempty"`
|
||||
FreeSlots int `json:"freeSlots"`
|
||||
Tasks []TaskState `json:"tasks"`
|
||||
}
|
||||
|
||||
// ControlAction represents a server-side control signal for a task.
|
||||
type ControlAction struct {
|
||||
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
|
||||
TaskID string `json:"taskId"`
|
||||
DeleteFiles bool `json:"deleteFiles,omitempty"`
|
||||
}
|
||||
|
||||
// SyncResponse is returned by the server with all pending actions for the CLI.
|
||||
type SyncResponse struct {
|
||||
NewTasks []Task `json:"newTasks,omitempty"`
|
||||
Controls []ControlAction `json:"controls,omitempty"`
|
||||
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
|
||||
Watching bool `json:"watching"`
|
||||
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
|
||||
Scan bool `json:"scan,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Watch progress types (used by stream tracking)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue