feat(sync): replace WS+DO transport with unified HTTP sync

Replace the WebSocket + Cloudflare Durable Object architecture with a
single POST /sync endpoint. The CLI now operates autonomously with local
state (tasks.json) and syncs bidirectionally via adaptive-interval HTTP
polling (3s watching, 60s idle).

- Remove transport_ws, transport_hybrid, transport_http (~2,600 lines)
- Add SyncClient with adaptive interval loop
- Add LocalState for CLI-side task persistence
- Add TaskStateFromUpdate() helper (DRY)
- Extract finalize() to deduplicate processTask/processTaskRetry
- Consolidate shortID() into agent.ShortID (was in 3 packages)
- Wire GetActiveCount so `unarr status` shows active tasks
- Remove poll_interval, heartbeat_interval, ws_url from config
- Simplify ProgressReporter (sync replaces direct HTTP reporting)
This commit is contained in:
Deivid Soto 2026-04-08 18:50:59 +02:00
parent 2398707cc1
commit 5d4a67c7a2
26 changed files with 1320 additions and 3400 deletions

View file

@ -40,27 +40,6 @@ func (c *Client) Register(ctx context.Context, req RegisterRequest) (*RegisterRe
return &resp, nil
}
// Heartbeat sends a periodic keep-alive signal and returns server directives.
func (c *Client) Heartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
var resp HeartbeatResponse
if err := c.doPost(ctx, "/api/internal/agent/heartbeat", req, &resp); err != nil {
return nil, fmt.Errorf("heartbeat: %w", err)
}
return &resp, nil
}
// ClaimTasks polls for pending download tasks and claims them atomically.
// Also returns any stream requests for completed downloads.
func (c *Client) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
url := fmt.Sprintf("/api/internal/agent/tasks?agentId=%s", agentID)
var resp TasksResponse
if err := c.doGet(ctx, url, &resp); err != nil {
return nil, fmt.Errorf("claim tasks: %w", err)
}
return &resp, nil
}
// ReportStatus reports download progress or completion for a task.
// Deregister notifies the server that the agent is shutting down.
func (c *Client) Deregister(ctx context.Context, agentID string) error {
req := struct {
@ -91,6 +70,16 @@ func (c *Client) BatchReportStatus(ctx context.Context, updates []StatusUpdate)
return &resp, nil
}
// Sync sends the CLI's full state and receives all pending server actions.
// This is the single endpoint for bidirectional state synchronization.
func (c *Client) Sync(ctx context.Context, req SyncRequest) (*SyncResponse, error) {
var resp SyncResponse
if err := c.doPost(ctx, "/api/internal/agent/sync", req, &resp); err != nil {
return nil, fmt.Errorf("sync: %w", err)
}
return &resp, nil
}
// ---------------------------------------------------------------------------
// Usenet endpoints
// ---------------------------------------------------------------------------

View file

@ -72,70 +72,6 @@ func TestRegister(t *testing.T) {
}
}
func TestHeartbeat(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/internal/agent/heartbeat" {
t.Errorf("path = %s, want /api/internal/agent/heartbeat", r.URL.Path)
}
var req HeartbeatRequest
json.NewDecoder(r.Body).Decode(&req)
if req.AgentID != "agent-123" {
t.Errorf("agentId = %q, want agent-123", req.AgentID)
}
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-123"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if !resp.Success {
t.Error("expected success=true")
}
}
func TestClaimTasks(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("method = %s, want GET", r.Method)
}
if r.URL.Query().Get("agentId") != "agent-123" {
t.Errorf("agentId param = %q, want agent-123", r.URL.Query().Get("agentId"))
}
json.NewEncoder(w).Encode(TasksResponse{
Tasks: []Task{
{
ID: "task-uuid-1",
InfoHash: "abc123def456abc123def456abc123def456abc1",
Title: "The Matrix (1999)",
PreferredMethod: "auto",
},
},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.ClaimTasks(context.Background(), "agent-123")
if err != nil {
t.Fatalf("ClaimTasks failed: %v", err)
}
if len(resp.Tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(resp.Tasks))
}
if resp.Tasks[0].ID != "task-uuid-1" {
t.Errorf("task.ID = %q, want task-uuid-1", resp.Tasks[0].ID)
}
if resp.Tasks[0].InfoHash != "abc123def456abc123def456abc123def456abc1" {
t.Errorf("task.InfoHash = %q", resp.Tasks[0].InfoHash)
}
if resp.Tasks[0].PreferredMethod != "auto" {
t.Errorf("task.PreferredMethod = %q, want auto", resp.Tasks[0].PreferredMethod)
}
}
func TestReportStatus(t *testing.T) {
var received StatusUpdate
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -173,22 +109,6 @@ func TestReportStatus(t *testing.T) {
}
}
func TestClaimTasksEmpty(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(TasksResponse{Tasks: []Task{}})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.ClaimTasks(context.Background(), "agent-123")
if err != nil {
t.Fatalf("ClaimTasks failed: %v", err)
}
if len(resp.Tasks) != 0 {
t.Errorf("expected empty tasks, got %d", len(resp.Tasks))
}
}
func TestAPIError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -279,50 +199,12 @@ func TestUserAgent(t *testing.T) {
if r.Header.Get("User-Agent") != "unarr/0.2.0" {
t.Errorf("User-Agent = %q, want unarr/0.2.0", r.Header.Get("User-Agent"))
}
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
json.NewEncoder(w).Encode(RegisterResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr/0.2.0")
c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "x"})
}
func TestHeartbeatWithUpgradeSignal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(HeartbeatResponse{
Success: true,
Upgrade: &UpgradeSignal{Version: "2.0.0"},
})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if resp.Upgrade == nil {
t.Fatal("expected upgrade signal, got nil")
}
if resp.Upgrade.Version != "2.0.0" {
t.Errorf("upgrade version = %q, want 2.0.0", resp.Upgrade.Version)
}
}
func TestHeartbeatWithoutUpgradeSignal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer srv.Close()
c := NewClient(srv.URL, "test-key", "unarr-test")
resp, err := c.Heartbeat(context.Background(), HeartbeatRequest{AgentID: "agent-1"})
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if resp.Upgrade != nil {
t.Errorf("expected no upgrade signal, got %+v", resp.Upgrade)
}
c.Register(context.Background(), RegisterRequest{AgentID: "x"})
}
func TestDeregister(t *testing.T) {

View file

@ -14,75 +14,62 @@ import (
// DaemonConfig holds daemon runtime settings.
type DaemonConfig struct {
AgentID string
AgentName string
Version string
DownloadDir string
PollInterval time.Duration
HeartbeatInterval time.Duration
StreamPort int // port for the HTTP stream server (reported in heartbeat)
LanIP string // LAN IP (reported in heartbeat for stream URL resolution)
TailscaleIP string // Tailscale IP (reported in heartbeat for stream URL resolution)
AgentID string
AgentName string
Version string
DownloadDir string
StreamPort int // port for the HTTP stream server
LanIP string // LAN IP (reported in sync for stream URL resolution)
TailscaleIP string // Tailscale IP (reported in sync for stream URL resolution)
}
// Daemon manages the main loop: register, heartbeat, poll tasks.
// Daemon manages agent registration and the sync loop.
type Daemon struct {
cfg DaemonConfig
transport Transport
cfg DaemonConfig
client *Client
sync *SyncClient
state *LocalState
// Callbacks
// Callbacks — set by cmd/daemon.go before calling Run.
OnTasksClaimed func(tasks []Task)
OnStreamRequested func(req StreamRequest)
OnControlAction func(action, taskID string)
OnControlAction func(action, taskID string, deleteFiles bool)
GetActiveCount func() int // returns number of active downloads (wired from manager)
// State
User UserInfo
Features FeatureFlags
Info AgentInfo
State DaemonState
heartbeatFailures int
lastNotifiedVersion string
// Callbacks for state tracking (set by cmd/daemon.go)
GetActiveCount func() int
GetCleanableBytes func() int64
// Watching tracks whether a user is viewing download progress in the web UI.
// When false, the progress reporter skips detailed updates (only sends final states).
// Accessed from heartbeat goroutine, flush goroutine, and WatchingFunc closure — must be atomic.
Watching atomic.Bool
// Exposed tickers for hot-reload
PollTicker *time.Ticker
HeartbeatTicker *time.Ticker
// pollNow triggers an immediate poll (e.g. on resume)
pollNow chan struct{}
// ScanNow triggers an immediate library scan (from heartbeat or WebSocket control event)
// ScanNow triggers an immediate library scan.
ScanNow chan struct{}
}
// NewDaemon creates a daemon with the given transport.
// Use NewHTTPTransport for HTTP-only, or NewHybridTransport for WS+HTTP.
func NewDaemon(cfg DaemonConfig, transport Transport) *Daemon {
if cfg.PollInterval == 0 {
cfg.PollInterval = 30 * time.Second
}
if cfg.HeartbeatInterval == 0 {
cfg.HeartbeatInterval = 30 * time.Second
}
// NewDaemon creates a daemon with an HTTP client for sync-based communication.
func NewDaemon(cfg DaemonConfig, client *Client) *Daemon {
state := NewLocalState()
return &Daemon{
cfg: cfg,
transport: transport,
pollNow: make(chan struct{}, 1),
ScanNow: make(chan struct{}, 1),
cfg: cfg,
client: client,
state: state,
sync: NewSyncClient(client, cfg, state),
ScanNow: make(chan struct{}, 1),
}
}
// Transport returns the configured transport.
func (d *Daemon) Transport() Transport { return d.transport }
// SyncClient returns the sync client for external wiring.
func (d *Daemon) SyncClient() *SyncClient { return d.sync }
// UpdateStreamPort updates the stream port reported in sync requests.
func (d *Daemon) UpdateStreamPort(port int) {
d.cfg.StreamPort = port
d.sync.cfg.StreamPort = port
}
// Register registers the agent and fetches user info + features.
// Retries with exponential backoff on transient errors (429, 5xx, network).
@ -109,11 +96,10 @@ func (d *Daemon) Register(ctx context.Context) error {
var resp *RegisterResponse
var err error
for attempt := range maxRetries {
resp, err = d.transport.Register(ctx, req)
resp, err = d.client.Register(ctx, req)
if err == nil {
break
}
// Only retry on transient errors (429, 5xx, network failures)
if !isTransientError(err) {
return fmt.Errorf("register: %w", err)
}
@ -154,14 +140,9 @@ func (d *Daemon) Register(ctx context.Context) error {
return nil
}
// Run connects the transport, registers the agent, and starts the main loop.
// Blocks until ctx is cancelled. Callers must NOT call transport.Connect before Run.
// Run registers the agent and starts the sync loop.
// Blocks until ctx is cancelled.
func (d *Daemon) Run(ctx context.Context) error {
// Connect transport (establishes WebSocket if available, falls back to HTTP)
if err := d.transport.Connect(ctx); err != nil {
return fmt.Errorf("connect transport: %w", err)
}
// Register
if err := d.Register(ctx); err != nil {
return err
@ -169,163 +150,61 @@ func (d *Daemon) Run(ctx context.Context) error {
log.Printf("Agent registered: %s (%s) [%s]", d.User.Name, d.User.Email, d.User.Plan)
log.Printf("Features: torrent=%v debrid=%v usenet=%v", d.Features.Torrent, d.Features.Debrid, d.Features.Usenet)
log.Printf("Polling every %s, heartbeat every %s", d.cfg.PollInterval, d.cfg.HeartbeatInterval)
d.HeartbeatTicker = time.NewTicker(d.cfg.HeartbeatInterval)
defer d.HeartbeatTicker.Stop()
d.PollTicker = time.NewTicker(d.cfg.PollInterval)
defer d.PollTicker.Stop()
heartbeatTicker := d.HeartbeatTicker
pollTicker := d.PollTicker
// Initial poll immediately
d.poll(ctx)
eventsCh := d.transport.Events()
for {
select {
case <-ctx.Done():
log.Println("Daemon shutting down...")
d.deregister()
return nil
case event := <-eventsCh:
d.handleEvent(event)
case <-heartbeatTicker.C:
d.heartbeat(ctx)
case <-pollTicker.C:
// Only poll in HTTP mode — WS mode receives tasks via Events
if d.transport.Mode() == "http" {
d.poll(ctx)
}
case <-d.pollNow:
d.poll(ctx)
// Wire sync callbacks
d.sync.OnNewTasks = func(tasks []Task) {
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(tasks)
}
}
}
func (d *Daemon) heartbeat(ctx context.Context) {
req := HeartbeatRequest{
AgentID: d.cfg.AgentID,
Name: d.cfg.AgentName,
Version: d.cfg.Version,
OS: runtime.GOOS,
DownloadDir: d.cfg.DownloadDir,
StreamPort: d.cfg.StreamPort,
LanIP: d.cfg.LanIP,
TailscaleIP: d.cfg.TailscaleIP,
}
if free, total, err := DiskInfo(d.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
resp, err := d.transport.SendHeartbeat(ctx, req)
if err != nil {
d.heartbeatFailures++
if d.heartbeatFailures >= 5 && d.heartbeatFailures%5 == 0 {
log.Printf("CRITICAL: %d consecutive heartbeat failures — server may be unreachable", d.heartbeatFailures)
} else {
log.Printf("Heartbeat failed: %v", err)
d.sync.OnControl = func(action, taskID string, deleteFiles bool) {
if d.OnControlAction != nil {
d.OnControlAction(action, taskID, deleteFiles)
}
return
}
if d.heartbeatFailures > 0 {
log.Printf("Heartbeat recovered after %d failures", d.heartbeatFailures)
d.heartbeatFailures = 0
d.sync.OnStreamRequest = func(req StreamRequest) {
if d.OnStreamRequested != nil {
d.OnStreamRequested(req)
}
}
// Update watching flag and state file
d.Watching.Store(resp.Watching)
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
d.sync.OnUpgrade = func(version string) {
if version != d.lastNotifiedVersion {
d.lastNotifiedVersion = version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", version)
}
}
WriteState(&d.State)
// Trigger library scan if requested
if resp.Scan {
d.sync.OnScan = func() {
log.Printf("Library scan requested by server")
select {
case d.ScanNow <- struct{}{}:
default: // scan already pending
default:
}
}
// Log once per version when server suggests an upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && resp.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = resp.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", resp.Upgrade.Version)
d.sync.OnWatchingChange = func(watching bool) {
d.Watching.Store(watching)
}
}
// handleEvent processes a server-initiated event from the WebSocket transport.
func (d *Daemon) handleEvent(event ServerEvent) {
switch event.Type {
case "tasks":
if event.Tasks != nil && len(event.Tasks.Tasks) > 0 {
log.Printf("Received %d task(s) via WebSocket", len(event.Tasks.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(event.Tasks.Tasks)
}
d.sync.OnSyncSuccess = func() {
d.State.LastHeartbeat = time.Now()
if d.GetActiveCount != nil {
d.State.ActiveTasks = d.GetActiveCount()
}
if event.Tasks != nil && d.OnStreamRequested != nil {
for _, sr := range event.Tasks.StreamRequests {
d.OnStreamRequested(sr)
}
}
case "upgrade":
if event.Upgrade != nil && event.Upgrade.Version != "" && event.Upgrade.Version != d.lastNotifiedVersion {
d.lastNotifiedVersion = event.Upgrade.Version
log.Printf("New version available: %s (run `unarr self-update` to upgrade)", event.Upgrade.Version)
}
case "control":
if event.Control != nil {
log.Printf("Control action via WebSocket: %s task %s", event.Control.Action, event.Control.TaskID)
if event.Control.Action == "scan" {
select {
case d.ScanNow <- struct{}{}:
default:
}
}
if d.OnControlAction != nil {
d.OnControlAction(event.Control.Action, event.Control.TaskID)
}
}
case "disconnected":
log.Println("WebSocket disconnected, switching to HTTP polling")
WriteState(&d.State)
}
// Start sync loop (blocks)
return d.sync.Run(ctx)
}
// UpdateStreamPort updates the stream port reported in heartbeats.
// Called after the persistent stream server binds (actual port may differ from configured).
func (d *Daemon) UpdateStreamPort(port int) {
d.cfg.StreamPort = port
// TriggerSync requests an immediate sync cycle.
func (d *Daemon) TriggerSync() {
d.sync.TriggerSync()
}
// TriggerPoll requests an immediate task poll cycle.
// Used when a resume event is received to pick up re-pending tasks faster.
func (d *Daemon) TriggerPoll() {
select {
case d.pollNow <- struct{}{}:
default: // already pending
}
}
func (d *Daemon) deregister() {
// Deregister notifies the server of graceful shutdown.
func (d *Daemon) Deregister() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := d.transport.Deregister(ctx, d.cfg.AgentID)
if err != nil {
if err := d.client.Deregister(ctx, d.cfg.AgentID); err != nil {
log.Printf("Deregister failed: %v", err)
} else {
log.Println("Agent deregistered")
@ -338,12 +217,10 @@ func isTransientError(err error) bool {
if err == nil {
return false
}
// Structured check: HTTPError carries the status code directly
var httpErr *HTTPError
if errors.As(err, &httpErr) {
return httpErr.StatusCode == 429 || httpErr.StatusCode >= 500
}
// Fallback: network-level errors (no HTTP response received)
lower := strings.ToLower(err.Error())
for _, keyword := range []string{"connection refused", "no such host", "timeout", "request failed"} {
if strings.Contains(lower, keyword) {
@ -352,27 +229,3 @@ func isTransientError(err error) bool {
}
return false
}
func (d *Daemon) poll(ctx context.Context) {
resp, err := d.transport.ClaimTasks(ctx, d.cfg.AgentID)
if err != nil {
log.Printf("Poll failed: %v", err)
return
}
d.Info.LastPollAt = time.Now()
if len(resp.Tasks) > 0 {
log.Printf("Claimed %d task(s)", len(resp.Tasks))
if d.OnTasksClaimed != nil {
d.OnTasksClaimed(resp.Tasks)
}
}
// Handle stream requests for completed downloads
if d.OnStreamRequested != nil {
for _, sr := range resp.StreamRequests {
d.OnStreamRequested(sr)
}
}
}

195
internal/agent/sync.go Normal file
View file

@ -0,0 +1,195 @@
package agent
import (
"context"
"log"
"runtime"
"sync/atomic"
"time"
)
const (
// SyncIntervalWatching is the sync interval when someone is viewing the web UI.
SyncIntervalWatching = 3 * time.Second
// SyncIntervalIdle is the sync interval when nobody is watching.
SyncIntervalIdle = 60 * time.Second
)
// SyncClient handles bidirectional state synchronization between the CLI and server.
// It sends the CLI's full execution state and receives all pending server actions
// in a single HTTP round-trip, at an adaptive interval.
type SyncClient struct {
client *Client
cfg DaemonConfig
state *LocalState
// Callbacks — set by the daemon before calling Run.
OnNewTasks func(tasks []Task)
OnControl func(action, taskID string, deleteFiles bool)
OnStreamRequest func(req StreamRequest)
OnUpgrade func(version string)
OnScan func()
OnWatchingChange func(watching bool)
OnSyncSuccess func() // called after each successful sync (e.g. to update state file)
GetFreeSlots func() int
GetTaskStates func() []TaskState // returns current state of all active + recently finished tasks
// SyncNow triggers an immediate sync (e.g., on task completion).
SyncNow chan struct{}
watching atomic.Bool
interval atomic.Int64 // stored as nanoseconds
}
// NewSyncClient creates a sync client.
func NewSyncClient(client *Client, cfg DaemonConfig, state *LocalState) *SyncClient {
sc := &SyncClient{
client: client,
cfg: cfg,
state: state,
SyncNow: make(chan struct{}, 1),
}
sc.interval.Store(int64(SyncIntervalIdle))
return sc
}
// Watching returns whether someone is viewing the web UI.
func (sc *SyncClient) Watching() bool {
return sc.watching.Load()
}
// TriggerSync requests an immediate sync cycle.
func (sc *SyncClient) TriggerSync() {
select {
case sc.SyncNow <- struct{}{}:
default:
}
}
// Run starts the adaptive sync loop. Blocks until ctx is cancelled.
func (sc *SyncClient) Run(ctx context.Context) error {
// Initial sync immediately
sc.doSync(ctx)
ticker := time.NewTicker(sc.currentInterval())
defer ticker.Stop()
for {
select {
case <-ctx.Done():
// Final sync to report latest state
finalCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sc.doSync(finalCtx)
return nil
case <-ticker.C:
sc.doSync(ctx)
ticker.Reset(sc.currentInterval())
case <-sc.SyncNow:
sc.doSync(ctx)
ticker.Reset(sc.currentInterval())
}
}
}
func (sc *SyncClient) currentInterval() time.Duration {
return time.Duration(sc.interval.Load())
}
func (sc *SyncClient) doSync(ctx context.Context) {
req := sc.buildRequest()
resp, err := sc.client.Sync(ctx, req)
if err != nil {
if ctx.Err() == nil {
log.Printf("sync failed: %v", err)
}
return
}
sc.processResponse(resp)
sc.adjustInterval(resp.Watching)
if sc.OnSyncSuccess != nil {
sc.OnSyncSuccess()
}
}
func (sc *SyncClient) buildRequest() SyncRequest {
req := SyncRequest{
AgentID: sc.cfg.AgentID,
Name: sc.cfg.AgentName,
Version: sc.cfg.Version,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
DownloadDir: sc.cfg.DownloadDir,
StreamPort: sc.cfg.StreamPort,
LanIP: sc.cfg.LanIP,
TailscaleIP: sc.cfg.TailscaleIP,
}
if sc.GetTaskStates != nil {
req.Tasks = sc.GetTaskStates()
} else {
req.Tasks = sc.state.Snapshot()
}
if free, total, err := DiskInfo(sc.cfg.DownloadDir); err == nil {
req.DiskFreeBytes = free
req.DiskTotalBytes = total
}
if sc.GetFreeSlots != nil {
req.FreeSlots = sc.GetFreeSlots()
}
return req
}
func (sc *SyncClient) processResponse(resp *SyncResponse) {
// New tasks
if len(resp.NewTasks) > 0 && sc.OnNewTasks != nil {
log.Printf("sync: received %d new task(s)", len(resp.NewTasks))
sc.OnNewTasks(resp.NewTasks)
}
// Control signals
for _, ctrl := range resp.Controls {
log.Printf("sync: control %s on task %s", ctrl.Action, ShortID(ctrl.TaskID))
if sc.OnControl != nil {
sc.OnControl(ctrl.Action, ctrl.TaskID, ctrl.DeleteFiles)
}
}
// Stream requests
for _, sr := range resp.StreamRequests {
if sc.OnStreamRequest != nil {
sc.OnStreamRequest(sr)
}
}
// Upgrade
if resp.Upgrade != nil && resp.Upgrade.Version != "" && sc.OnUpgrade != nil {
sc.OnUpgrade(resp.Upgrade.Version)
}
// Scan
if resp.Scan && sc.OnScan != nil {
sc.OnScan()
}
}
func (sc *SyncClient) adjustInterval(watching bool) {
prev := sc.watching.Load()
sc.watching.Store(watching)
var newInterval time.Duration
if watching {
newInterval = SyncIntervalWatching
} else {
newInterval = SyncIntervalIdle
}
if sc.interval.Swap(int64(newInterval)) != int64(newInterval) {
log.Printf("sync: interval=%s (watching=%v)", newInterval, watching)
}
if prev != watching && sc.OnWatchingChange != nil {
sc.OnWatchingChange(watching)
}
}

362
internal/agent/sync_test.go Normal file
View file

@ -0,0 +1,362 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
func newTestSyncClient(url string) (*SyncClient, *Client) {
client := NewClient(url, "test-key", "test-agent/1.0")
cfg := DaemonConfig{
AgentID: "test-agent",
AgentName: "Test",
Version: "1.0.0",
DownloadDir: "/tmp/downloads",
}
state := NewLocalState()
sc := NewSyncClient(client, cfg, state)
return sc, client
}
func TestSyncClient_NewDefaults(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
if sc.Watching() {
t.Error("should not be watching initially")
}
if sc.currentInterval() != SyncIntervalIdle {
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
}
}
func TestSyncClient_AdjustInterval_Watching(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
sc.adjustInterval(true)
if sc.currentInterval() != SyncIntervalWatching {
t.Errorf("expected watching interval %v, got %v", SyncIntervalWatching, sc.currentInterval())
}
if !sc.Watching() {
t.Error("expected watching=true")
}
}
func TestSyncClient_AdjustInterval_NotWatching(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
// First set watching, then unset
sc.adjustInterval(true)
sc.adjustInterval(false)
if sc.currentInterval() != SyncIntervalIdle {
t.Errorf("expected idle interval %v, got %v", SyncIntervalIdle, sc.currentInterval())
}
if sc.Watching() {
t.Error("expected watching=false")
}
}
func TestSyncClient_AdjustInterval_CallsOnWatchingChange(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var changes []bool
sc.OnWatchingChange = func(w bool) { changes = append(changes, w) }
sc.adjustInterval(true)
sc.adjustInterval(true) // no change
sc.adjustInterval(false) // change
if len(changes) != 2 {
t.Fatalf("expected 2 changes, got %d: %v", len(changes), changes)
}
if !changes[0] {
t.Error("first change should be true")
}
if changes[1] {
t.Error("second change should be false")
}
}
func TestSyncClient_TriggerSync_NonBlocking(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
// Fill the channel
sc.TriggerSync()
// Should not block
sc.TriggerSync()
sc.TriggerSync()
// Drain
select {
case <-sc.SyncNow:
default:
t.Error("expected a sync trigger in channel")
}
}
func TestSyncClient_ProcessResponse_NewTasks(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var received []Task
sc.OnNewTasks = func(tasks []Task) { received = tasks }
sc.processResponse(&SyncResponse{
NewTasks: []Task{
{ID: "t1", Title: "Movie 1", InfoHash: "abc"},
{ID: "t2", Title: "Movie 2", InfoHash: "def"},
},
})
if len(received) != 2 {
t.Fatalf("expected 2 tasks, got %d", len(received))
}
if received[0].Title != "Movie 1" {
t.Errorf("expected Movie 1, got %s", received[0].Title)
}
}
func TestSyncClient_ProcessResponse_NoTasks(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnNewTasks = func(tasks []Task) { called = true }
sc.processResponse(&SyncResponse{NewTasks: nil})
if called {
t.Error("OnNewTasks should not be called with empty tasks")
}
}
func TestSyncClient_ProcessResponse_Controls(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var actions []string
var taskIDs []string
sc.OnControl = func(action, taskID string, deleteFiles bool) {
actions = append(actions, action)
taskIDs = append(taskIDs, taskID)
}
sc.processResponse(&SyncResponse{
Controls: []ControlAction{
{Action: "cancel", TaskID: "task-1234-5678"},
{Action: "pause", TaskID: "task-abcd-efgh"},
},
})
if len(actions) != 2 {
t.Fatalf("expected 2 controls, got %d", len(actions))
}
if actions[0] != "cancel" {
t.Errorf("expected cancel, got %s", actions[0])
}
if actions[1] != "pause" {
t.Errorf("expected pause, got %s", actions[1])
}
}
func TestSyncClient_ProcessResponse_Upgrade(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var version string
sc.OnUpgrade = func(v string) { version = v }
sc.processResponse(&SyncResponse{
Upgrade: &UpgradeSignal{Version: "2.0.0"},
})
if version != "2.0.0" {
t.Errorf("expected 2.0.0, got %s", version)
}
}
func TestSyncClient_ProcessResponse_UpgradeEmpty(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnUpgrade = func(v string) { called = true }
sc.processResponse(&SyncResponse{
Upgrade: &UpgradeSignal{Version: ""},
})
if called {
t.Error("OnUpgrade should not be called with empty version")
}
}
func TestSyncClient_ProcessResponse_Scan(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var called bool
sc.OnScan = func() { called = true }
sc.processResponse(&SyncResponse{Scan: true})
if !called {
t.Error("OnScan should have been called")
}
}
func TestSyncClient_ProcessResponse_StreamRequests(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
var received []StreamRequest
sc.OnStreamRequest = func(sr StreamRequest) { received = append(received, sr) }
sc.processResponse(&SyncResponse{
StreamRequests: []StreamRequest{
{TaskID: "t1", FilePath: "/tmp/movie.mkv"},
},
})
if len(received) != 1 {
t.Fatalf("expected 1 stream request, got %d", len(received))
}
if received[0].FilePath != "/tmp/movie.mkv" {
t.Errorf("expected /tmp/movie.mkv, got %s", received[0].FilePath)
}
}
func TestSyncClient_BuildRequest_WithGetTaskStates(t *testing.T) {
sc, _ := newTestSyncClient("http://localhost")
sc.GetTaskStates = func() []TaskState {
return []TaskState{
{TaskID: "t1", Status: "downloading", Progress: 50},
}
}
sc.GetFreeSlots = func() int { return 2 }
req := sc.buildRequest()
if req.AgentID != "test-agent" {
t.Errorf("expected test-agent, got %s", req.AgentID)
}
if len(req.Tasks) != 1 {
t.Fatalf("expected 1 task, got %d", len(req.Tasks))
}
if req.Tasks[0].Progress != 50 {
t.Errorf("expected progress 50, got %d", req.Tasks[0].Progress)
}
if req.FreeSlots != 2 {
t.Errorf("expected 2 free slots, got %d", req.FreeSlots)
}
}
func TestSyncClient_BuildRequest_FallbackToState(t *testing.T) {
client := NewClient("http://localhost", "key", "ua")
state := NewLocalState()
state.Update(TaskState{TaskID: "t1", Status: "completed", Progress: 100})
sc := NewSyncClient(client, DaemonConfig{AgentID: "a1", Version: "1.0"}, state)
// GetTaskStates is nil — should fall back to state.Snapshot()
req := sc.buildRequest()
if len(req.Tasks) != 1 {
t.Fatalf("expected 1 task from state fallback, got %d", len(req.Tasks))
}
}
func TestSyncClient_DoSync_Success(t *testing.T) {
var syncCount atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
syncCount.Add(1)
json.NewEncoder(w).Encode(SyncResponse{
Watching: true,
NewTasks: []Task{{ID: "t1", Title: "Test Movie", InfoHash: "abc"}},
})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
var tasksReceived []Task
sc.OnNewTasks = func(tasks []Task) { tasksReceived = tasks }
sc.doSync(context.Background())
if syncCount.Load() != 1 {
t.Errorf("expected 1 sync call, got %d", syncCount.Load())
}
if len(tasksReceived) != 1 {
t.Fatalf("expected 1 task, got %d", len(tasksReceived))
}
if !sc.Watching() {
t.Error("expected watching=true after sync")
}
if sc.currentInterval() != SyncIntervalWatching {
t.Errorf("expected watching interval after sync")
}
}
func TestSyncClient_DoSync_Error(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
// Should not panic on error
sc.doSync(context.Background())
}
func TestSyncClient_Run_CancelStopsLoop(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(SyncResponse{})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := sc.Run(ctx)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
}
func TestSyncClient_Run_ImmediateSyncOnTrigger(t *testing.T) {
var syncCount atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
syncCount.Add(1)
json.NewEncoder(w).Encode(SyncResponse{})
}))
defer srv.Close()
sc, _ := newTestSyncClient(srv.URL)
// Set interval to something long so only triggers cause syncs
sc.interval.Store(int64(10 * time.Second))
ctx, cancel := context.WithCancel(context.Background())
go func() {
// Wait for initial sync, then trigger 2 more
time.Sleep(50 * time.Millisecond)
sc.TriggerSync()
time.Sleep(50 * time.Millisecond)
sc.TriggerSync()
time.Sleep(50 * time.Millisecond)
cancel()
}()
sc.Run(ctx)
// Initial sync (1) + 2 triggers + final sync = 4
count := syncCount.Load()
if count < 3 {
t.Errorf("expected at least 3 syncs (initial + 2 triggers), got %d", count)
}
}

136
internal/agent/taskstate.go Normal file
View file

@ -0,0 +1,136 @@
package agent
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"github.com/torrentclaw/unarr/internal/config"
)
// TaskState represents the execution state of a single download task.
// Written by the Task Engine, read by the Sync goroutine.
type TaskState struct {
TaskID string `json:"taskId"`
Status string `json:"status"` // resolving, downloading, verifying, organizing, completed, failed
Progress int `json:"progress"`
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
TotalBytes int64 `json:"totalBytes,omitempty"`
SpeedBps int64 `json:"speedBps,omitempty"`
ETA int `json:"eta,omitempty"`
ResolvedMethod string `json:"resolvedMethod,omitempty"`
FileName string `json:"fileName,omitempty"`
FilePath string `json:"filePath,omitempty"`
StreamURL string `json:"streamUrl,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
UpdatedAt int64 `json:"updatedAt"`
}
// LocalState holds the CLI's local execution state (tasks.json).
// This is the CLI's source of truth for what it's doing right now.
type LocalState struct {
mu sync.RWMutex
tasks map[string]*TaskState
}
// NewLocalState creates an empty local state.
func NewLocalState() *LocalState {
return &LocalState{
tasks: make(map[string]*TaskState),
}
}
// Update adds or updates a task in local state.
func (s *LocalState) Update(ts TaskState) {
s.mu.Lock()
defer s.mu.Unlock()
ts.UpdatedAt = time.Now().Unix()
copied := ts
s.tasks[ts.TaskID] = &copied
}
// Remove removes a task from local state.
func (s *LocalState) Remove(taskID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tasks, taskID)
}
// Snapshot returns a copy of all current task states.
func (s *LocalState) Snapshot() []TaskState {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]TaskState, 0, len(s.tasks))
for _, ts := range s.tasks {
result = append(result, *ts)
}
return result
}
// TaskStateFromUpdate converts a StatusUpdate into a TaskState.
func TaskStateFromUpdate(u StatusUpdate) TaskState {
return TaskState{
TaskID: u.TaskID,
Status: u.Status,
Progress: u.Progress,
DownloadedBytes: u.DownloadedBytes,
TotalBytes: u.TotalBytes,
SpeedBps: u.SpeedBps,
ETA: u.ETA,
ResolvedMethod: u.ResolvedMethod,
FileName: u.FileName,
FilePath: u.FilePath,
StreamURL: u.StreamURL,
ErrorMessage: u.ErrorMessage,
}
}
// ShortID returns the first 8 characters of an ID, or the full ID if shorter.
func ShortID(id string) string {
if len(id) > 8 {
return id[:8]
}
return id
}
// taskStateFilePathFn is overridable for testing.
var taskStateFilePathFn = func() string {
return filepath.Join(config.DataDir(), "tasks.json")
}
// WriteToDisk persists local state to disk atomically (best-effort).
func (s *LocalState) WriteToDisk() {
tasks := s.Snapshot()
data, err := json.MarshalIndent(tasks, "", " ")
if err != nil {
return
}
path := taskStateFilePathFn()
dir := filepath.Dir(path)
os.MkdirAll(dir, 0o755)
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return
}
os.Rename(tmp, path)
}
// ReadFromDisk loads local state from disk. Returns empty state on error.
func (s *LocalState) ReadFromDisk() {
data, err := os.ReadFile(taskStateFilePathFn())
if err != nil {
return
}
var tasks []TaskState
if json.Unmarshal(data, &tasks) != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.tasks = make(map[string]*TaskState, len(tasks))
for i := range tasks {
s.tasks[tasks[i].TaskID] = &tasks[i]
}
}

View file

@ -0,0 +1,217 @@
package agent
import (
"os"
"path/filepath"
"sync"
"testing"
)
func TestLocalState_UpdateAndSnapshot(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100})
snap := s.Snapshot()
if len(snap) != 2 {
t.Fatalf("expected 2 tasks, got %d", len(snap))
}
byID := make(map[string]TaskState, len(snap))
for _, ts := range snap {
byID[ts.TaskID] = ts
}
if byID["t1"].Progress != 50 {
t.Errorf("expected progress 50, got %d", byID["t1"].Progress)
}
if byID["t2"].Status != "completed" {
t.Errorf("expected completed, got %s", byID["t2"].Status)
}
}
func TestLocalState_UpdateOverwrites(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 30})
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 70})
snap := s.Snapshot()
if len(snap) != 1 {
t.Fatalf("expected 1 task, got %d", len(snap))
}
if snap[0].Progress != 70 {
t.Errorf("expected progress 70, got %d", snap[0].Progress)
}
}
func TestLocalState_Remove(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
s.Update(TaskState{TaskID: "t2", Status: "downloading"})
s.Remove("t1")
snap := s.Snapshot()
if len(snap) != 1 {
t.Fatalf("expected 1 task, got %d", len(snap))
}
if snap[0].TaskID != "t2" {
t.Errorf("expected t2, got %s", snap[0].TaskID)
}
}
func TestLocalState_RemoveNonExistent(t *testing.T) {
s := NewLocalState()
s.Remove("nonexistent") // should not panic
}
func TestLocalState_SnapshotIsACopy(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 50})
snap := s.Snapshot()
snap[0].Progress = 999
snap2 := s.Snapshot()
if snap2[0].Progress != 50 {
t.Errorf("snapshot mutation leaked: got progress %d", snap2[0].Progress)
}
}
func TestLocalState_UpdateSetsTimestamp(t *testing.T) {
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
snap := s.Snapshot()
if snap[0].UpdatedAt == 0 {
t.Error("expected non-zero UpdatedAt")
}
}
func TestLocalState_ConcurrentAccess(t *testing.T) {
s := NewLocalState()
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(n int) {
defer wg.Done()
taskID := "t" + string(rune('0'+n%10))
s.Update(TaskState{TaskID: taskID, Status: "downloading", Progress: n})
s.Snapshot()
if n%3 == 0 {
s.Remove(taskID)
}
}(i)
}
wg.Wait()
// No race condition = test passes
}
func TestLocalState_WriteToDisk_ReadFromDisk(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
// Override the file path for testing
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading", Progress: 45})
s.Update(TaskState{TaskID: "t2", Status: "completed", Progress: 100, FilePath: "/tmp/movie.mkv"})
s.WriteToDisk()
// Verify file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Fatal("tasks.json was not created")
}
// Read into a new LocalState
s2 := NewLocalState()
s2.ReadFromDisk()
snap := s2.Snapshot()
if len(snap) != 2 {
t.Fatalf("expected 2 tasks after read, got %d", len(snap))
}
byID := make(map[string]TaskState, len(snap))
for _, ts := range snap {
byID[ts.TaskID] = ts
}
if byID["t1"].Progress != 45 {
t.Errorf("expected progress 45, got %d", byID["t1"].Progress)
}
if byID["t2"].FilePath != "/tmp/movie.mkv" {
t.Errorf("expected /tmp/movie.mkv, got %s", byID["t2"].FilePath)
}
}
func TestLocalState_ReadFromDisk_CorruptedFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
// Write corrupted JSON
os.WriteFile(path, []byte("{invalid json"), 0o644)
s := NewLocalState()
s.ReadFromDisk() // should not panic
snap := s.Snapshot()
if len(snap) != 0 {
t.Errorf("expected 0 tasks from corrupted file, got %d", len(snap))
}
}
func TestLocalState_ReadFromDisk_FileNotFound(t *testing.T) {
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return "/nonexistent/path/tasks.json" }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.ReadFromDisk() // should not panic
snap := s.Snapshot()
if len(snap) != 0 {
t.Errorf("expected 0 tasks, got %d", len(snap))
}
}
func TestLocalState_AtomicWrite(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tasks.json")
orig := taskStateFilePathFn
taskStateFilePathFn = func() string { return path }
defer func() { taskStateFilePathFn = orig }()
s := NewLocalState()
s.Update(TaskState{TaskID: "t1", Status: "downloading"})
s.WriteToDisk()
// Verify no .tmp file remains
tmpPath := path + ".tmp"
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
t.Error("temp file should not exist after write")
}
}
func TestLocalState_EmptySnapshot(t *testing.T) {
s := NewLocalState()
snap := s.Snapshot()
if snap == nil {
t.Error("snapshot should be non-nil empty slice")
}
if len(snap) != 0 {
t.Errorf("expected 0 tasks, got %d", len(snap))
}
}

View file

@ -1,51 +0,0 @@
package agent
import "context"
// Transport abstracts the communication protocol between the agent and server.
// Both WebSocket (via CF Durable Object) and HTTP (direct to origin) implement this.
type Transport interface {
// Connect establishes the transport connection.
// Called internally by Daemon.Run — callers must NOT call Connect separately.
Connect(ctx context.Context) error
// Close tears down the connection gracefully.
Close() error
// Mode returns the current transport mode ("ws" or "http").
Mode() string
// Register sends agent registration and returns user info + features.
Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error)
// SendHeartbeat sends a periodic keep-alive.
SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error)
// SendProgress reports download progress for a task.
SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error)
// ClaimTasks polls for new tasks (HTTP mode only; WS receives via Events).
ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error)
// Deregister notifies the server of graceful shutdown.
Deregister(ctx context.Context, agentID string) error
// Events returns a channel that emits server-initiated events.
// In HTTP mode this channel is never written to (polling handles it).
// In WS mode, tasks/upgrade/control arrive here.
Events() <-chan ServerEvent
}
// ServerEvent represents a server-initiated message received via WebSocket.
type ServerEvent struct {
Type string // "tasks", "upgrade", "control", "disconnected"
Tasks *TasksResponse // populated when Type == "tasks"
Upgrade *UpgradeSignal // populated when Type == "upgrade"
Control *ControlAction // populated when Type == "control"
}
// ControlAction represents a server push for task control.
type ControlAction struct {
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
TaskID string `json:"taskId"`
}

View file

@ -1,285 +0,0 @@
package agent
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// TestE2EFullLifecycle tests the full lifecycle:
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
func TestE2EFullLifecycle(t *testing.T) {
var mu sync.Mutex
var receivedMessages []map[string]interface{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
var parsed map[string]interface{}
json.Unmarshal(msg, &parsed)
mu.Lock()
receivedMessages = append(receivedMessages, parsed)
mu.Unlock()
msgType, _ := parsed["type"].(string)
switch msgType {
case "auth":
conn.WriteJSON(wsRegisteredMessage{
Type: "registered",
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
Features: FeatureFlags{Torrent: true, Debrid: true},
})
case "heartbeat":
// No response in WS mode
case "progress":
// Simulate server-side cancel after progress
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
conn.WriteJSON(map[string]string{
"type": "control",
"action": "cancel",
"taskId": parsed["taskId"].(string),
})
}
case "upgrade-result":
// Acknowledged
}
}
}))
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
ctx := context.Background()
// 1. Connect
if err := tr.Connect(ctx); err != nil {
t.Fatalf("Connect: %v", err)
}
defer tr.Close()
// 2. Auth
resp, err := tr.Register(ctx, RegisterRequest{
AgentID: "e2e-agent",
Name: "E2E Test Agent",
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
})
if err != nil {
t.Fatalf("Register: %v", err)
}
if resp.User.Name != "E2E User" {
t.Errorf("expected E2E User, got %s", resp.User.Name)
}
if !resp.Features.Debrid {
t.Error("expected debrid feature")
}
// 3. Send heartbeat
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
AgentID: "e2e-agent",
DiskFreeBytes: 1000000000,
DiskTotalBytes: 5000000000,
})
if err != nil {
t.Fatalf("SendHeartbeat: %v", err)
}
// 4. Send progress (50% → should trigger cancel control)
_, err = tr.SendProgress(ctx, StatusUpdate{
TaskID: "task-e2e-1",
Status: "downloading",
Progress: 50,
DownloadedBytes: 500,
TotalBytes: 1000,
SpeedBps: 100,
})
if err != nil {
t.Fatalf("SendProgress: %v", err)
}
// 5. Wait for control event (cancel)
select {
case event := <-tr.Events():
if event.Type != "control" {
t.Errorf("expected control event, got %s", event.Type)
}
if event.Control.Action != "cancel" {
t.Errorf("expected cancel, got %s", event.Control.Action)
}
if event.Control.TaskID != "task-e2e-1" {
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for cancel control")
}
// Verify server received all messages
time.Sleep(100 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if len(receivedMessages) < 3 {
t.Fatalf("expected at least 3 messages, got %d", len(receivedMessages))
}
types := make([]string, len(receivedMessages))
for i, m := range receivedMessages {
types[i], _ = m["type"].(string)
}
expected := []string{"auth", "heartbeat", "progress"}
for _, exp := range expected {
found := false
for _, got := range types {
if got == exp {
found = true
break
}
}
if !found {
t.Errorf("missing message type %q in %v", exp, types)
}
}
}
// TestE2EHybridFailover tests the full failover scenario:
// WS connect → download → WS disconnect → switch to HTTP → continue working
func TestE2EHybridFailover(t *testing.T) {
connectionCount := 0
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
mu.Lock()
connectionCount++
connNum := connectionCount
mu.Unlock()
// Read auth
conn.ReadMessage()
conn.WriteJSON(wsRegisteredMessage{
Type: "registered",
User: UserInfo{Name: "Failover User"},
})
if connNum == 1 {
// First connection: push tasks then disconnect after 200ms
time.Sleep(50 * time.Millisecond)
conn.WriteJSON(wsTasksMessage{
Type: "tasks",
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
})
time.Sleep(150 * time.Millisecond)
conn.Close()
} else {
// Second connection (after reconnect): push upgrade
time.Sleep(50 * time.Millisecond)
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
time.Sleep(500 * time.Millisecond)
conn.Close()
}
}))
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
// HTTP mock for fallback
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simple heartbeat response
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
}))
defer httpSrv.Close()
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
h := NewHybridTransport(wsT, httpT)
ctx := context.Background()
err := h.Connect(ctx)
if err != nil {
t.Fatalf("Connect: %v", err)
}
defer h.Close()
// Should start in WS mode
if h.Mode() != "ws" {
t.Fatalf("expected ws mode, got %s", h.Mode())
}
// Register via WS
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
if err != nil {
t.Fatalf("Register: %v", err)
}
// Receive tasks via WS
var tasksReceived bool
var disconnected bool
for i := 0; i < 3; i++ {
select {
case event := <-h.Events():
switch event.Type {
case "tasks":
tasksReceived = true
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
t.Errorf("unexpected tasks: %+v", event.Tasks)
}
case "disconnected":
disconnected = true
}
case <-time.After(2 * time.Second):
break
}
if disconnected {
break
}
}
if !tasksReceived {
t.Error("did not receive tasks before disconnect")
}
if !disconnected {
t.Error("did not receive disconnect event")
}
// Should now be in HTTP mode
time.Sleep(100 * time.Millisecond)
if h.Mode() != "http" {
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
}
// Heartbeat should work via HTTP fallback
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
if err != nil {
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
}
if !hbResp.Success {
t.Error("expected heartbeat success")
}
}

View file

@ -1,50 +0,0 @@
package agent
import "context"
// HTTPTransport wraps the existing Client to implement Transport.
// This is a thin adapter — no behavioral changes from the current HTTP protocol.
type HTTPTransport struct {
client *Client
events chan ServerEvent
}
// NewHTTPTransport creates a new HTTP-based transport.
func NewHTTPTransport(baseURL, apiKey, userAgent string) *HTTPTransport {
return &HTTPTransport{
client: NewClient(baseURL, apiKey, userAgent),
events: make(chan ServerEvent, 10),
}
}
func (t *HTTPTransport) Connect(_ context.Context) error { return nil }
func (t *HTTPTransport) Close() error { return nil }
func (t *HTTPTransport) Mode() string { return "http" }
func (t *HTTPTransport) Events() <-chan ServerEvent { return t.events }
func (t *HTTPTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
return t.client.Register(ctx, req)
}
func (t *HTTPTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
return t.client.Heartbeat(ctx, req)
}
func (t *HTTPTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
return t.client.ReportStatus(ctx, update)
}
func (t *HTTPTransport) BatchReportStatus(ctx context.Context, updates []StatusUpdate) (*BatchStatusResponse, error) {
return t.client.BatchReportStatus(ctx, updates)
}
func (t *HTTPTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
return t.client.ClaimTasks(ctx, agentID)
}
func (t *HTTPTransport) Deregister(ctx context.Context, agentID string) error {
return t.client.Deregister(ctx, agentID)
}
// Client returns the underlying HTTP client for direct use if needed.
func (t *HTTPTransport) Client() *Client { return t.client }

View file

@ -1,214 +0,0 @@
package agent
import (
"context"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridTransport tries WebSocket first, falls back to HTTP if WS fails.
// Automatically reconnects WS in the background.
type HybridTransport struct {
ws *WSTransport
http *HTTPTransport
mode atomic.Value // "ws" or "http"
events chan ServerEvent
reconnectMu sync.Mutex
reconnectRunning bool
reconnectStop chan struct{}
closed atomic.Bool
}
// NewHybridTransport creates a transport that prefers WS with HTTP fallback.
func NewHybridTransport(ws *WSTransport, http *HTTPTransport) *HybridTransport {
h := &HybridTransport{
ws: ws,
http: http,
events: make(chan ServerEvent, 50),
reconnectStop: make(chan struct{}),
}
h.mode.Store("http") // start in HTTP, upgrade to WS on Connect
return h
}
func (h *HybridTransport) Mode() string { return h.mode.Load().(string) }
func (h *HybridTransport) Events() <-chan ServerEvent { return h.events }
// Connect tries WS first. If it fails, falls back to HTTP and starts reconnection loop.
func (h *HybridTransport) Connect(ctx context.Context) error {
// Try WebSocket first
if err := h.ws.Connect(ctx); err != nil {
log.Printf("[transport] WebSocket connect failed (%v), using HTTP fallback", err)
h.mode.Store("http")
h.startReconnectLoop()
return h.http.Connect(ctx)
}
h.mode.Store("ws")
log.Println("[transport] Connected via WebSocket")
// Forward WS events to unified channel + watch for disconnection
go h.forwardWSEvents()
return nil
}
// Close shuts down both transports and stops reconnection.
func (h *HybridTransport) Close() error {
h.closed.Store(true)
select {
case <-h.reconnectStop:
default:
close(h.reconnectStop)
}
_ = h.ws.Close()
return h.http.Close()
}
// Register delegates to the active transport.
func (h *HybridTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
if h.mode.Load() == "ws" {
return h.ws.Register(ctx, req)
}
return h.http.Register(ctx, req)
}
// SendHeartbeat delegates to the active transport.
func (h *HybridTransport) SendHeartbeat(ctx context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
if h.mode.Load() == "ws" {
resp, err := h.ws.SendHeartbeat(ctx, req)
if err != nil {
// WS write failed — switch to HTTP
h.switchToHTTP()
return h.http.SendHeartbeat(ctx, req)
}
return resp, nil
}
return h.http.SendHeartbeat(ctx, req)
}
// SendProgress delegates to the active transport.
func (h *HybridTransport) SendProgress(ctx context.Context, update StatusUpdate) (*StatusResponse, error) {
if h.mode.Load() == "ws" {
resp, err := h.ws.SendProgress(ctx, update)
if err != nil {
h.switchToHTTP()
return h.http.SendProgress(ctx, update)
}
return resp, nil
}
return h.http.SendProgress(ctx, update)
}
// ClaimTasks delegates to the active transport.
func (h *HybridTransport) ClaimTasks(ctx context.Context, agentID string) (*TasksResponse, error) {
if h.mode.Load() == "ws" {
return h.ws.ClaimTasks(ctx, agentID) // no-op in WS mode
}
return h.http.ClaimTasks(ctx, agentID)
}
// Deregister delegates to the active transport.
func (h *HybridTransport) Deregister(ctx context.Context, agentID string) error {
if h.mode.Load() == "ws" {
return h.ws.Deregister(ctx, agentID)
}
return h.http.Deregister(ctx, agentID)
}
// ── Internal ─────────────────────────────────────────────────────────────────
func (h *HybridTransport) switchToHTTP() {
if h.mode.Load() == "http" {
return
}
log.Println("[transport] Switching to HTTP fallback")
h.mode.Store("http")
_ = h.ws.Close()
h.startReconnectLoop()
}
func (h *HybridTransport) forwardWSEvents() {
for {
select {
case <-h.reconnectStop:
return
case event, ok := <-h.ws.Events():
if !ok {
return // channel closed
}
if event.Type == "disconnected" {
h.switchToHTTP()
select {
case h.events <- event:
default:
}
return
}
select {
case h.events <- event:
default:
log.Printf("[transport] events channel full, dropping %s event", event.Type)
}
}
}
}
func (h *HybridTransport) startReconnectLoop() {
h.reconnectMu.Lock()
defer h.reconnectMu.Unlock()
if h.reconnectRunning {
return
}
h.reconnectRunning = true
go h.reconnectLoop()
}
func (h *HybridTransport) reconnectLoop() {
backoff := 5 * time.Second
maxBackoff := 60 * time.Second
for {
select {
case <-h.reconnectStop:
return
case <-time.After(backoff):
}
if h.closed.Load() {
return
}
// Already on WS? (someone else reconnected)
if h.mode.Load() == "ws" {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
err := h.ws.Connect(ctx)
cancel()
if err != nil {
log.Printf("[transport] WS reconnect failed: %v (retry in %v)", err, backoff)
backoff = min(backoff*2, maxBackoff)
continue
}
// WS reconnected — switch back
log.Println("[transport] WebSocket reconnected")
h.mode.Store("ws")
// Reset reconnect flag so loop can start again if WS drops
h.reconnectMu.Lock()
h.reconnectRunning = false
h.reconnectMu.Unlock()
// Forward events from new WS connection
go h.forwardWSEvents()
return
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,395 +0,0 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
// WSTransport communicates with the server via WebSocket through a Cloudflare Durable Object.
type WSTransport struct {
wsURL string // wss://unarr.torrentclaw.com/ws/{agentId}
apiKey string
agentID string
userAgent string
conn *websocket.Conn
mu sync.Mutex
events chan ServerEvent
closed atomic.Bool
// Cached auth response from the DO
authResp *RegisterResponse
authMu sync.Mutex
authDone chan struct{}
authDoneOnce sync.Once
}
// NewWSTransport creates a WebSocket-based transport.
func NewWSTransport(wsURL, apiKey, agentID, userAgent string) *WSTransport {
return &WSTransport{
wsURL: wsURL,
apiKey: apiKey,
agentID: agentID,
userAgent: userAgent,
events: make(chan ServerEvent, 50),
authDone: make(chan struct{}),
}
}
func (t *WSTransport) Mode() string { return "ws" }
func (t *WSTransport) Events() <-chan ServerEvent { return t.events }
// Connect dials the WebSocket server and starts the read loop.
func (t *WSTransport) Connect(ctx context.Context) error {
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
header := http.Header{}
header.Set("User-Agent", t.userAgent)
// Append API key as query param for auth on WS upgrade
wsURLWithKey := t.wsURL
if t.apiKey != "" {
sep := "?"
if strings.Contains(wsURLWithKey, "?") {
sep = "&"
}
wsURLWithKey += sep + "key=" + t.apiKey
}
conn, wsResp, err := dialer.DialContext(ctx, wsURLWithKey, header)
if wsResp != nil && wsResp.Body != nil {
defer wsResp.Body.Close()
}
if err != nil {
return fmt.Errorf("ws dial: %w", err)
}
t.mu.Lock()
t.conn = conn
t.closed.Store(false)
t.authDone = make(chan struct{})
t.authDoneOnce = sync.Once{}
t.mu.Unlock()
go t.readLoop(conn)
return nil
}
// Close sends a close frame and shuts down the connection.
func (t *WSTransport) Close() error {
t.closed.Store(true)
t.mu.Lock()
defer t.mu.Unlock()
if t.conn != nil {
_ = t.conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
)
err := t.conn.Close()
t.conn = nil
return err
}
return nil
}
// Register sends auth message and waits for the registered response.
func (t *WSTransport) Register(ctx context.Context, req RegisterRequest) (*RegisterResponse, error) {
msg := wsAuthMessage{
Type: "auth",
APIKey: t.apiKey,
AgentID: req.AgentID,
Name: req.Name,
OS: req.OS,
Arch: req.Arch,
Version: req.Version,
DownloadDir: req.DownloadDir,
DiskFreeBytes: req.DiskFreeBytes,
DiskTotalBytes: req.DiskTotalBytes,
}
if err := t.send(msg); err != nil {
return nil, fmt.Errorf("ws auth send: %w", err)
}
// Wait for the auth response or context cancellation
select {
case <-t.authDone:
t.authMu.Lock()
resp := t.authResp
t.authMu.Unlock()
if resp == nil {
return nil, fmt.Errorf("ws auth: no response received")
}
return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(15 * time.Second):
return nil, fmt.Errorf("ws auth: timeout waiting for registered response")
}
}
// SendHeartbeat sends a heartbeat message. No blocking response in WS mode.
func (t *WSTransport) SendHeartbeat(_ context.Context, req HeartbeatRequest) (*HeartbeatResponse, error) {
msg := struct {
Type string `json:"type"`
Disk *struct {
Free int64 `json:"free"`
Total int64 `json:"total"`
} `json:"disk,omitempty"`
}{Type: "heartbeat"}
if req.DiskFreeBytes > 0 || req.DiskTotalBytes > 0 {
msg.Disk = &struct {
Free int64 `json:"free"`
Total int64 `json:"total"`
}{Free: req.DiskFreeBytes, Total: req.DiskTotalBytes}
}
if err := t.send(msg); err != nil {
return nil, err
}
// WS mode: heartbeat is fire-and-forget. Upgrade signals arrive via Events().
return &HeartbeatResponse{Success: true}, nil
}
// SendProgress sends a progress update. Control signals arrive async via Events().
func (t *WSTransport) SendProgress(_ context.Context, update StatusUpdate) (*StatusResponse, error) {
msg := struct {
Type string `json:"type"`
TaskID string `json:"taskId"`
Status string `json:"status,omitempty"`
Progress int `json:"progress,omitempty"`
DownloadedBytes int64 `json:"downloadedBytes,omitempty"`
TotalBytes int64 `json:"totalBytes,omitempty"`
SpeedBps int64 `json:"speedBps,omitempty"`
ETA int `json:"eta,omitempty"`
ResolvedMethod string `json:"resolvedMethod,omitempty"`
FileName string `json:"fileName,omitempty"`
FilePath string `json:"filePath,omitempty"`
StreamURL string `json:"streamUrl,omitempty"`
StreamReady bool `json:"streamReady,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
}{
Type: "progress",
TaskID: update.TaskID,
Status: update.Status,
Progress: update.Progress,
DownloadedBytes: update.DownloadedBytes,
TotalBytes: update.TotalBytes,
SpeedBps: update.SpeedBps,
ETA: update.ETA,
ResolvedMethod: update.ResolvedMethod,
FileName: update.FileName,
FilePath: update.FilePath,
StreamURL: update.StreamURL,
StreamReady: update.StreamReady,
ErrorMessage: update.ErrorMessage,
}
if err := t.send(msg); err != nil {
return nil, err
}
// In WS mode, control signals come via Events(), not in the progress response.
return &StatusResponse{Success: true}, nil
}
// ClaimTasks is a no-op in WS mode — tasks arrive via Events().
func (t *WSTransport) ClaimTasks(_ context.Context, _ string) (*TasksResponse, error) {
return &TasksResponse{}, nil
}
// Deregister is handled by WebSocket close (DO detects disconnection).
func (t *WSTransport) Deregister(_ context.Context, _ string) error {
return t.Close()
}
// ── Internal ─────────────────────────────────────────────────────────────────
func (t *WSTransport) send(msg any) error {
t.mu.Lock()
defer t.mu.Unlock()
if t.conn == nil {
return fmt.Errorf("ws: not connected")
}
data, err := json.Marshal(msg)
if err != nil {
return err
}
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return t.conn.WriteMessage(websocket.TextMessage, data)
}
func (t *WSTransport) readLoop(conn *websocket.Conn) {
// Cloudflare idle timeout is 100s. We send pings every 30s and expect
// either a pong or a server message within 45s. If neither arrives,
// the read deadline fires and we detect the zombie connection.
const (
pongWait = 45 * time.Second
pingPeriod = 30 * time.Second
)
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
// Ping ticker goroutine — stops when readLoop returns.
pingDone := make(chan struct{})
go func() {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.mu.Lock()
if t.conn != nil {
_ = t.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err := t.conn.WriteMessage(websocket.PingMessage, nil)
_ = t.conn.SetWriteDeadline(time.Time{})
if err != nil {
t.mu.Unlock()
return
}
}
t.mu.Unlock()
case <-pingDone:
return
}
}
}()
defer close(pingDone)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
if !t.closed.Load() {
log.Printf("[ws] read error: %v", err)
// Signal disconnection to the daemon
select {
case t.events <- ServerEvent{Type: "disconnected"}:
default:
}
}
return
}
// Any message (text or pong) proves the connection is alive.
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
var envelope struct {
Type string `json:"type"`
}
if err := json.Unmarshal(msg, &envelope); err != nil {
log.Printf("[ws] invalid message: %v", err)
continue
}
switch envelope.Type {
case "registered":
var resp wsRegisteredMessage
if json.Unmarshal(msg, &resp) == nil {
t.authMu.Lock()
t.authResp = &RegisterResponse{
Success: true,
User: resp.User,
Features: resp.Features,
}
t.authMu.Unlock()
// Signal that auth is complete (sync.Once prevents double-close panic)
t.authDoneOnce.Do(func() { close(t.authDone) })
}
case "tasks":
var resp wsTasksMessage
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "tasks",
Tasks: &TasksResponse{
Tasks: resp.Tasks,
StreamRequests: resp.StreamRequests,
},
}:
default:
log.Printf("[ws] events channel full, dropping tasks message")
}
}
case "upgrade":
var resp wsUpgradeMessage
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "upgrade",
Upgrade: &UpgradeSignal{Version: resp.Version},
}:
default:
}
}
case "control":
var resp ControlAction
if json.Unmarshal(msg, &resp) == nil {
select {
case t.events <- ServerEvent{
Type: "control",
Control: &resp,
}:
default:
}
}
case "error":
var resp struct {
Message string `json:"message"`
}
if json.Unmarshal(msg, &resp) == nil {
log.Printf("[ws] server error: %s", resp.Message)
}
}
}
}
// ── WS message types ─────────────────────────────────────────────────────────
type wsAuthMessage struct {
Type string `json:"type"`
APIKey string `json:"apiKey"`
AgentID string `json:"agentId"`
Name string `json:"name,omitempty"`
OS string `json:"os,omitempty"`
Arch string `json:"arch,omitempty"`
Version string `json:"version,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
}
type wsRegisteredMessage struct {
Type string `json:"type"`
User UserInfo `json:"user"`
Features FeatureFlags `json:"features"`
}
type wsTasksMessage struct {
Type string `json:"type"`
Tasks []Task `json:"tasks"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
}
type wsUpgradeMessage struct {
Type string `json:"type"`
Version string `json:"version"`
}

View file

@ -50,20 +50,6 @@ type UsenetServerInfo struct {
SSL bool `json:"ssl"`
}
// HeartbeatRequest is sent every 30s to keep the agent alive.
type HeartbeatRequest struct {
AgentID string `json:"agentId"`
Name string `json:"name,omitempty"`
OS string `json:"os,omitempty"`
Version string `json:"version,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
StreamPort int `json:"streamPort,omitempty"`
LanIP string `json:"lanIp,omitempty"`
TailscaleIP string `json:"tailscaleIp,omitempty"`
}
// Task represents a download task claimed from the server.
type Task struct {
ID string `json:"id"`
@ -88,12 +74,6 @@ type Task struct {
CollectionName string `json:"collectionName,omitempty"` // Collection name (e.g., "Harry Potter Collection")
}
// TasksResponse wraps the array of tasks returned by the server.
type TasksResponse struct {
Tasks []Task `json:"tasks"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
}
// StreamRequest is a request to stream a completed download from disk.
type StreamRequest struct {
TaskID string `json:"taskId"`
@ -139,14 +119,6 @@ type BatchStatusResponse struct {
Watching bool `json:"watching,omitempty"`
}
// HeartbeatResponse is returned by the server on heartbeat.
type HeartbeatResponse struct {
Success bool `json:"success"`
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
Watching bool `json:"watching,omitempty"` // true when a user is viewing download progress in the web UI
Scan bool `json:"scan,omitempty"` // true when user triggered a library scan from the web UI
}
// UpgradeSignal tells the agent to upgrade to a specific version.
type UpgradeSignal struct {
Version string `json:"version"`
@ -176,7 +148,6 @@ type AgentInfo struct {
User UserInfo
Features FeatureFlags
StartedAt time.Time
LastPollAt time.Time
ActiveTasks int
}
@ -334,6 +305,45 @@ type LibrarySyncResponse struct {
Removed int `json:"removed"`
}
// ---------------------------------------------------------------------------
// Sync types (unified CLI ↔ Server communication)
// ---------------------------------------------------------------------------
// SyncRequest is sent by the CLI periodically to synchronize state with the server.
// Contains the CLI's full execution state — the server responds with pending actions.
type SyncRequest struct {
AgentID string `json:"agentId"`
Version string `json:"version,omitempty"`
OS string `json:"os,omitempty"`
Arch string `json:"arch,omitempty"`
Name string `json:"name,omitempty"`
DownloadDir string `json:"downloadDir,omitempty"`
DiskFreeBytes int64 `json:"diskFreeBytes,omitempty"`
DiskTotalBytes int64 `json:"diskTotalBytes,omitempty"`
StreamPort int `json:"streamPort,omitempty"`
LanIP string `json:"lanIp,omitempty"`
TailscaleIP string `json:"tailscaleIp,omitempty"`
FreeSlots int `json:"freeSlots"`
Tasks []TaskState `json:"tasks"`
}
// ControlAction represents a server-side control signal for a task.
type ControlAction struct {
Action string `json:"action"` // "pause", "resume", "cancel", "stream"
TaskID string `json:"taskId"`
DeleteFiles bool `json:"deleteFiles,omitempty"`
}
// SyncResponse is returned by the server with all pending actions for the CLI.
type SyncResponse struct {
NewTasks []Task `json:"newTasks,omitempty"`
Controls []ControlAction `json:"controls,omitempty"`
StreamRequests []StreamRequest `json:"streamRequests,omitempty"`
Watching bool `json:"watching"`
Upgrade *UpgradeSignal `json:"upgrade,omitempty"`
Scan bool `json:"scan,omitempty"`
}
// ---------------------------------------------------------------------------
// Watch progress types (used by stream tracking)
// ---------------------------------------------------------------------------