feat(agent): add WebSocket transport with HTTP fallback
Add Transport interface abstraction supporting WebSocket (via CF Durable Objects) and HTTP (direct to origin) with automatic failover.

- Transport interface: Register, SendHeartbeat, SendProgress, Events()
- HTTPTransport: thin adapter over existing Client
- WSTransport: gorilla/websocket with auth handshake, readLoop, reconnect
- HybridTransport: tries WS first, falls back to HTTP, reconnects in bg
- Daemon refactored to always use Transport (no dual-path forks)
- ProgressReporter accepts StatusReporter interface
- deriveWSURL skips localhost/dev (returns "" → HTTP-only)
- API key passed in WS query param for connection auth
- Fixed: reconnectOnce race (mutex+bool), authDone double-close (sync.Once)
- Fixed: forwardWSEvents goroutine leak (select with stop signal)
- 20 transport tests + 2 E2E tests (full lifecycle, hybrid failover)
This commit is contained in:
parent
5e80911501
commit
5f337eebd7
10 changed files with 1646 additions and 64 deletions
295
internal/agent/transport_e2e_test.go
Normal file
295
internal/agent/transport_e2e_test.go
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestE2EFullLifecycle tests the full lifecycle:
|
||||
// connect → auth → receive tasks → send progress → receive control → disconnect → reconnect
|
||||
func TestE2EFullLifecycle(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var receivedMessages []map[string]interface{}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var parsed map[string]interface{}
|
||||
json.Unmarshal(msg, &parsed)
|
||||
|
||||
mu.Lock()
|
||||
receivedMessages = append(receivedMessages, parsed)
|
||||
mu.Unlock()
|
||||
|
||||
msgType, _ := parsed["type"].(string)
|
||||
switch msgType {
|
||||
case "auth":
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "E2E User", Plan: "pro", IsPro: true},
|
||||
Features: FeatureFlags{Torrent: true, Debrid: true},
|
||||
})
|
||||
|
||||
case "heartbeat":
|
||||
// No response in WS mode
|
||||
|
||||
case "progress":
|
||||
// Simulate server-side cancel after progress
|
||||
if progress, ok := parsed["progress"].(float64); ok && progress >= 50 {
|
||||
conn.WriteJSON(map[string]string{
|
||||
"type": "control",
|
||||
"action": "cancel",
|
||||
"taskId": parsed["taskId"].(string),
|
||||
})
|
||||
}
|
||||
|
||||
case "upgrade-result":
|
||||
// Acknowledged
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
tr := NewWSTransport(wsURL, "e2e-key", "e2e-agent", "test/1.0")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Connect
|
||||
if err := tr.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
// 2. Auth
|
||||
resp, err := tr.Register(ctx, RegisterRequest{
|
||||
AgentID: "e2e-agent",
|
||||
Name: "E2E Test Agent",
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
if resp.User.Name != "E2E User" {
|
||||
t.Errorf("expected E2E User, got %s", resp.User.Name)
|
||||
}
|
||||
if !resp.Features.Debrid {
|
||||
t.Error("expected debrid feature")
|
||||
}
|
||||
|
||||
// 3. Send heartbeat
|
||||
_, err = tr.SendHeartbeat(ctx, HeartbeatRequest{
|
||||
AgentID: "e2e-agent",
|
||||
DiskFreeBytes: 1000000000,
|
||||
DiskTotalBytes: 5000000000,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat: %v", err)
|
||||
}
|
||||
|
||||
// 4. Send progress (50% → should trigger cancel control)
|
||||
_, err = tr.SendProgress(ctx, StatusUpdate{
|
||||
TaskID: "task-e2e-1",
|
||||
Status: "downloading",
|
||||
Progress: 50,
|
||||
DownloadedBytes: 500,
|
||||
TotalBytes: 1000,
|
||||
SpeedBps: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendProgress: %v", err)
|
||||
}
|
||||
|
||||
// 5. Wait for control event (cancel)
|
||||
select {
|
||||
case event := <-tr.Events():
|
||||
if event.Type != "control" {
|
||||
t.Errorf("expected control event, got %s", event.Type)
|
||||
}
|
||||
if event.Control.Action != "cancel" {
|
||||
t.Errorf("expected cancel, got %s", event.Control.Action)
|
||||
}
|
||||
if event.Control.TaskID != "task-e2e-1" {
|
||||
t.Errorf("expected task-e2e-1, got %s", event.Control.TaskID)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for cancel control")
|
||||
}
|
||||
|
||||
// 6. Send upgrade result
|
||||
err = tr.ReportUpgradeResult(ctx, UpgradeResult{
|
||||
AgentID: "e2e-agent",
|
||||
Success: true,
|
||||
Version: "2.0.0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ReportUpgradeResult: %v", err)
|
||||
}
|
||||
|
||||
// Verify server received all messages
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if len(receivedMessages) < 4 {
|
||||
t.Fatalf("expected at least 4 messages, got %d", len(receivedMessages))
|
||||
}
|
||||
|
||||
types := make([]string, len(receivedMessages))
|
||||
for i, m := range receivedMessages {
|
||||
types[i], _ = m["type"].(string)
|
||||
}
|
||||
|
||||
expected := []string{"auth", "heartbeat", "progress", "upgrade-result"}
|
||||
for _, exp := range expected {
|
||||
found := false
|
||||
for _, got := range types {
|
||||
if got == exp {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("missing message type %q in %v", exp, types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2EHybridFailover tests the full failover scenario:
|
||||
// WS connect → download → WS disconnect → switch to HTTP → continue working
|
||||
func TestE2EHybridFailover(t *testing.T) {
|
||||
connectionCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
connectionCount++
|
||||
connNum := connectionCount
|
||||
mu.Unlock()
|
||||
|
||||
// Read auth
|
||||
conn.ReadMessage()
|
||||
conn.WriteJSON(wsRegisteredMessage{
|
||||
Type: "registered",
|
||||
User: UserInfo{Name: "Failover User"},
|
||||
})
|
||||
|
||||
if connNum == 1 {
|
||||
// First connection: push tasks then disconnect after 200ms
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsTasksMessage{
|
||||
Type: "tasks",
|
||||
Tasks: []Task{{ID: "t1", InfoHash: "abc", Title: "Failover Movie"}},
|
||||
})
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
conn.Close()
|
||||
} else {
|
||||
// Second connection (after reconnect): push upgrade
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn.WriteJSON(wsUpgradeMessage{Type: "upgrade", Version: "3.0.0"})
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
conn.Close()
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
wsT := NewWSTransport(wsURL, "key", "a1", "ua")
|
||||
|
||||
// HTTP mock for fallback
|
||||
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple heartbeat response
|
||||
json.NewEncoder(w).Encode(HeartbeatResponse{Success: true})
|
||||
}))
|
||||
defer httpSrv.Close()
|
||||
|
||||
httpT := NewHTTPTransport(httpSrv.URL, "key", "ua")
|
||||
h := NewHybridTransport(wsT, httpT)
|
||||
|
||||
ctx := context.Background()
|
||||
err := h.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect: %v", err)
|
||||
}
|
||||
defer h.Close()
|
||||
|
||||
// Should start in WS mode
|
||||
if h.Mode() != "ws" {
|
||||
t.Fatalf("expected ws mode, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Register via WS
|
||||
_, err = h.Register(ctx, RegisterRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
// Receive tasks via WS
|
||||
var tasksReceived bool
|
||||
var disconnected bool
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case event := <-h.Events():
|
||||
switch event.Type {
|
||||
case "tasks":
|
||||
tasksReceived = true
|
||||
if len(event.Tasks.Tasks) != 1 || event.Tasks.Tasks[0].Title != "Failover Movie" {
|
||||
t.Errorf("unexpected tasks: %+v", event.Tasks)
|
||||
}
|
||||
case "disconnected":
|
||||
disconnected = true
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
break
|
||||
}
|
||||
if disconnected {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !tasksReceived {
|
||||
t.Error("did not receive tasks before disconnect")
|
||||
}
|
||||
if !disconnected {
|
||||
t.Error("did not receive disconnect event")
|
||||
}
|
||||
|
||||
// Should now be in HTTP mode
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if h.Mode() != "http" {
|
||||
t.Errorf("expected http mode after disconnect, got %s", h.Mode())
|
||||
}
|
||||
|
||||
// Heartbeat should work via HTTP fallback
|
||||
hbResp, err := h.SendHeartbeat(ctx, HeartbeatRequest{AgentID: "a1"})
|
||||
if err != nil {
|
||||
t.Fatalf("SendHeartbeat via HTTP fallback: %v", err)
|
||||
}
|
||||
if !hbResp.Success {
|
||||
t.Error("expected heartbeat success")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue