Files
QUptime/internal/transport/rpc_test.go
T
2026-05-12 06:53:15 +00:00

187 lines
4.4 KiB
Go

package transport
import (
"context"
"encoding/json"
"errors"
"net"
"testing"
"time"
"git.cer.sh/axodouble/quptime/internal/crypto"
"git.cer.sh/axodouble/quptime/internal/trust"
)
// testNode groups the per-node state a transport test needs: the node's
// identity, the directory used as QUPTIME_DIR when its keys and trust
// store were created, the TLS material handed to NewServer/NewClient, and
// the certificate fingerprint that peers pin when they trust this node.
type testNode struct {
	id, dir string     // node ID and the directory used as its QUPTIME_DIR
	assets  *TLSAssets // TLS certificate, private key, and trust store
	fp      string     // fingerprint of this node's certificate
}
// makeNode creates a fresh node identity rooted at dir. It points
// QUPTIME_DIR at dir, generates a key pair for id, loads the resulting
// certificate PEM, derives its fingerprint, and loads the (initially
// empty) trust store. Any failure aborts the test.
//
// QUPTIME_DIR is a process-wide environment variable. After any
// disk-touching trust operation on another node, the caller must point it
// back at this node's dir before touching this node's on-disk state again.
func makeNode(t *testing.T, dir, id string) *testNode {
	t.Helper()
	check := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	t.Setenv("QUPTIME_DIR", dir)

	key, err := crypto.GenerateKeyPair(id)
	check(err)
	cert, err := crypto.LoadCertPEM()
	check(err)
	fingerprint, err := crypto.FingerprintFromCertPEM(cert)
	check(err)
	store, err := trust.Load()
	check(err)

	node := &testNode{id: id, dir: dir, fp: fingerprint}
	node.assets = &TLSAssets{Cert: cert, Key: key, Trust: store}
	return node
}
// trust adds other to n's trust store as a peer reachable at addr, pinned
// by other's certificate fingerprint. QUPTIME_DIR is switched to n.dir
// first, presumably so any on-disk persistence done by Add lands in n's
// directory. Failure aborts the test.
func (n *testNode) trust(t *testing.T, other *testNode, addr string) {
	t.Helper()
	t.Setenv("QUPTIME_DIR", n.dir)

	entry := trust.Entry{
		NodeID:      other.id,
		Address:     addr,
		Fingerprint: other.fp,
	}
	if err := n.assets.Trust.Add(entry); err != nil {
		t.Fatal(err)
	}
}
// TestRPCRoundtrip checks a successful request/response exchange. Node b
// calls the "Echo" handler served by node a over mutually trusted TLS. The
// handler must see b's node ID as the peer, and the reply must round-trip.
func TestRPCRoundtrip(t *testing.T) {
	a := makeNode(t, t.TempDir(), "node-a")
	b := makeNode(t, t.TempDir(), "node-b")

	// Pre-pick a free port by binding and releasing it. The brief race
	// window before Serve rebinds it is acceptable for tests.
	tmpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := tmpLn.Addr().String()
	tmpLn.Close()

	a.trust(t, b, addr)
	b.trust(t, a, addr)

	srv := NewServer(a.assets)
	srv.Handle("Echo", func(_ context.Context, peer string, payload json.RawMessage) (any, error) {
		var s string
		if err := json.Unmarshal(payload, &s); err != nil {
			return nil, err
		}
		if peer != b.id {
			return nil, errors.New("unexpected peer id: " + peer)
		}
		return s + " ack", nil
	})

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() { done <- srv.Serve(ctx, addr) }()
	// Stop the server and wait for Serve to return, so its goroutine does
	// not outlive the test. Serve's clean-shutdown return value is not
	// visible from here, so it is only logged, not asserted.
	defer func() {
		srv.Stop()
		cancel()
		select {
		case serveErr := <-done:
			if serveErr != nil {
				t.Logf("Serve returned: %v", serveErr)
			}
		case <-time.After(5 * time.Second):
			t.Error("Serve did not return after Stop and context cancellation")
		}
	}()

	if !waitForDial(addr, 2*time.Second) {
		t.Fatal("server did not start listening in time")
	}

	cli := NewClient(b.assets)
	defer cli.Close()

	callCtx, callCancel := context.WithTimeout(ctx, 5*time.Second)
	defer callCancel()

	var got string
	if err := cli.Call(callCtx, a.id, addr, "Echo", "hello", &got); err != nil {
		t.Fatalf("Call: %v", err)
	}
	if got != "hello ack" {
		t.Errorf("got %q want %q", got, "hello ack")
	}
}
// TestRPCUnknownMethod checks that calling a method with no registered
// handler fails with an error reported by the server, rather than
// returning a result or hanging.
func TestRPCUnknownMethod(t *testing.T) {
	a := makeNode(t, t.TempDir(), "node-a")
	b := makeNode(t, t.TempDir(), "node-b")

	// Pre-pick a free port. Check the error: a nil listener would panic.
	tmpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := tmpLn.Addr().String()
	tmpLn.Close()

	a.trust(t, b, addr)
	b.trust(t, a, addr)

	srv := NewServer(a.assets)

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() { done <- srv.Serve(ctx, addr) }()
	// Stop the server and join Serve's goroutine. Its shutdown error is
	// logged only, because its clean-shutdown contract is not visible here.
	defer func() {
		srv.Stop()
		cancel()
		select {
		case serveErr := <-done:
			if serveErr != nil {
				t.Logf("Serve returned: %v", serveErr)
			}
		case <-time.After(5 * time.Second):
			t.Error("Serve did not return after Stop and context cancellation")
		}
	}()

	if !waitForDial(addr, 2*time.Second) {
		t.Fatal("server not up")
	}

	cli := NewClient(b.assets)
	defer cli.Close()

	// Bound the call so a server that never answers fails this test
	// instead of hanging until the global test timeout.
	callCtx, callCancel := context.WithTimeout(ctx, 5*time.Second)
	defer callCancel()

	err = cli.Call(callCtx, a.id, addr, "DoesNotExist", nil, nil)
	if err == nil {
		t.Fatal("expected error for unknown method")
	}
	// A deadline error means the server never replied, which is not the
	// behavior under test. This only detects it if Call wraps ctx errors.
	if errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("call timed out instead of reporting unknown method: %v", err)
	}
}
// TestRPCRejectsUntrustedPeer checks that the server refuses a client
// whose certificate is not in its trust store. The client (b) trusts the
// server (a), but a has no entry for b. b's call must fail, and the
// server-side handler must never run.
func TestRPCRejectsUntrustedPeer(t *testing.T) {
	a := makeNode(t, t.TempDir(), "node-a")
	b := makeNode(t, t.TempDir(), "node-b")

	tmpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := tmpLn.Addr().String()
	tmpLn.Close()

	// Only the client side trusts its peer: a.trust(t, b, addr) is
	// deliberately omitted, so b is unknown to a. The helper fails the
	// test if Add errors. Otherwise a client that silently failed to trust
	// a could make this test pass for the wrong reason.
	b.trust(t, a, addr)

	srv := NewServer(a.assets)
	// Register a handler that would succeed. An error from Call can then
	// only come from rejecting the peer, not from an unknown method. The
	// buffered channel records whether the handler ever ran.
	called := make(chan string, 1)
	srv.Handle("Echo", func(_ context.Context, peer string, _ json.RawMessage) (any, error) {
		select {
		case called <- peer:
		default:
		}
		return "ack", nil
	})

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() { done <- srv.Serve(ctx, addr) }()
	// Stop the server and join Serve's goroutine; its shutdown error is
	// only logged.
	defer func() {
		srv.Stop()
		cancel()
		select {
		case serveErr := <-done:
			if serveErr != nil {
				t.Logf("Serve returned: %v", serveErr)
			}
		case <-time.After(5 * time.Second):
			t.Error("Serve did not return after Stop and context cancellation")
		}
	}()

	if !waitForDial(addr, 2*time.Second) {
		t.Fatal("server not up")
	}

	cli := NewClient(b.assets)
	defer cli.Close()

	callCtx, callCancel := context.WithTimeout(ctx, 2*time.Second)
	defer callCancel()

	var reply string
	if err := cli.Call(callCtx, a.id, addr, "Echo", "hello", &reply); err == nil {
		t.Error("untrusted client was admitted")
	}
	// If the handler ran, it did so before Call returned its response, so
	// this non-blocking check cannot miss an admitted request.
	select {
	case peer := <-called:
		t.Errorf("handler ran for untrusted peer %q", peer)
	default:
	}
}
// waitForDial reports whether a plain TCP connection to addr can be
// established before timeout elapses. It retries every 20ms, with each
// dial attempt capped at 200ms. Tests use it to detect that Serve has
// started listening. A successful probe connection is closed immediately.
func waitForDial(addr string, timeout time.Duration) bool {
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(20 * time.Millisecond) {
		conn, dialErr := net.DialTimeout("tcp", addr, 200*time.Millisecond)
		if dialErr != nil {
			continue
		}
		_ = conn.Close()
		return true
	}
	return false
}