Added tests and readme
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFrameRoundtrip(t *testing.T) {
|
||||
cases := [][]byte{
|
||||
nil,
|
||||
{},
|
||||
[]byte("hello"),
|
||||
bytes.Repeat([]byte("x"), 1<<14),
|
||||
}
|
||||
for _, payload := range cases {
|
||||
var buf bytes.Buffer
|
||||
if err := writeFrame(&buf, payload); err != nil {
|
||||
t.Fatalf("write %d bytes: %v", len(payload), err)
|
||||
}
|
||||
out, err := readFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("read %d bytes: %v", len(payload), err)
|
||||
}
|
||||
if !bytes.Equal(out, payload) {
|
||||
t.Errorf("roundtrip lost data for %d bytes", len(payload))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameRejectsOversize(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
if err := writeFrame(&buf, bytes.Repeat([]byte{0}, MaxFrameSize+1)); err == nil {
|
||||
t.Error("oversized write was accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameRejectsOversizeOnRead(t *testing.T) {
|
||||
// hand-crafted header announcing a size beyond the cap
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF}) // ~4GiB
|
||||
if _, err := readFrame(&buf); err == nil {
|
||||
t.Error("oversized read was accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameReportsShortRead(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
// header says 10 bytes, body only 3
|
||||
buf.Write([]byte{0, 0, 0, 10})
|
||||
buf.WriteString("abc")
|
||||
if _, err := readFrame(&buf); err == nil {
|
||||
t.Error("short body did not error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleFramesInOneStream(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
for _, s := range []string{"first", "second", "third"} {
|
||||
if err := writeFrame(&buf, []byte(s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
for _, want := range []string{"first", "second", "third"} {
|
||||
got, err := readFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
if _, err := readFrame(&buf); err != io.EOF {
|
||||
t.Errorf("expected EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package transport
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -322,8 +321,3 @@ func peerNodeIDFromConnState(cs tls.ConnectionState) string {
|
||||
}
|
||||
return cs.PeerCertificates[0].Subject.CommonName
|
||||
}
|
||||
|
||||
// fingerprintOf is a small local mirror to keep this file independent
|
||||
// of the crypto package's import path at link time; we recompute the
|
||||
// SPKI hash here. Defined in tofu.go.
|
||||
var _ = (*x509.Certificate)(nil)
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jasper/quptime/internal/crypto"
|
||||
"github.com/jasper/quptime/internal/trust"
|
||||
)
|
||||
|
||||
// testNode bundles everything one side of the handshake needs.
|
||||
type testNode struct {
|
||||
id string
|
||||
dir string
|
||||
assets *TLSAssets
|
||||
fp string
|
||||
}
|
||||
|
||||
// makeNode builds keys + cert + an empty trust store rooted at dir.
|
||||
// After every disk-touching trust operation the caller must ensure
|
||||
// QUPTIME_DIR points back at this node's dir.
|
||||
func makeNode(t *testing.T, dir, id string) *testNode {
|
||||
t.Helper()
|
||||
t.Setenv("QUPTIME_DIR", dir)
|
||||
priv, err := crypto.GenerateKeyPair(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certPEM, err := crypto.LoadCertPEM()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fp, err := crypto.FingerprintFromCertPEM(certPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
store, err := trust.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &testNode{
|
||||
id: id,
|
||||
dir: dir,
|
||||
assets: &TLSAssets{Cert: certPEM, Key: priv, Trust: store},
|
||||
fp: fp,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *testNode) trust(t *testing.T, other *testNode, addr string) {
|
||||
t.Helper()
|
||||
t.Setenv("QUPTIME_DIR", n.dir)
|
||||
if err := n.assets.Trust.Add(trust.Entry{
|
||||
NodeID: other.id, Address: addr, Fingerprint: other.fp,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCRoundtrip(t *testing.T) {
|
||||
a := makeNode(t, t.TempDir(), "node-a")
|
||||
b := makeNode(t, t.TempDir(), "node-b")
|
||||
|
||||
// pre-pick a free port; brief race window is acceptable for tests
|
||||
tmpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addr := tmpLn.Addr().String()
|
||||
tmpLn.Close()
|
||||
|
||||
a.trust(t, b, addr)
|
||||
b.trust(t, a, addr)
|
||||
|
||||
srv := NewServer(a.assets)
|
||||
srv.Handle("Echo", func(_ context.Context, peer string, payload json.RawMessage) (any, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(payload, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peer != b.id {
|
||||
return nil, errors.New("unexpected peer id: " + peer)
|
||||
}
|
||||
return s + " ack", nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- srv.Serve(ctx, addr) }()
|
||||
defer srv.Stop()
|
||||
|
||||
if !waitForDial(addr, 2*time.Second) {
|
||||
t.Fatal("server did not start listening in time")
|
||||
}
|
||||
|
||||
cli := NewClient(b.assets)
|
||||
defer cli.Close()
|
||||
|
||||
callCtx, callCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer callCancel()
|
||||
var got string
|
||||
if err := cli.Call(callCtx, a.id, addr, "Echo", "hello", &got); err != nil {
|
||||
t.Fatalf("Call: %v", err)
|
||||
}
|
||||
if got != "hello ack" {
|
||||
t.Errorf("got %q want %q", got, "hello ack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCUnknownMethod(t *testing.T) {
|
||||
a := makeNode(t, t.TempDir(), "node-a")
|
||||
b := makeNode(t, t.TempDir(), "node-b")
|
||||
|
||||
tmpLn, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
addr := tmpLn.Addr().String()
|
||||
tmpLn.Close()
|
||||
|
||||
a.trust(t, b, addr)
|
||||
b.trust(t, a, addr)
|
||||
|
||||
srv := NewServer(a.assets)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go srv.Serve(ctx, addr)
|
||||
defer srv.Stop()
|
||||
if !waitForDial(addr, 2*time.Second) {
|
||||
t.Fatal("server not up")
|
||||
}
|
||||
|
||||
cli := NewClient(b.assets)
|
||||
defer cli.Close()
|
||||
err := cli.Call(ctx, a.id, addr, "DoesNotExist", nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCRejectsUntrustedPeer(t *testing.T) {
|
||||
a := makeNode(t, t.TempDir(), "node-a")
|
||||
b := makeNode(t, t.TempDir(), "node-b")
|
||||
|
||||
tmpLn, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
addr := tmpLn.Addr().String()
|
||||
tmpLn.Close()
|
||||
|
||||
// Deliberately omit b.trust(...) on the server side: b is unknown to a.
|
||||
t.Setenv("QUPTIME_DIR", b.dir)
|
||||
_ = b.assets.Trust.Add(trust.Entry{NodeID: a.id, Address: addr, Fingerprint: a.fp})
|
||||
|
||||
srv := NewServer(a.assets)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go srv.Serve(ctx, addr)
|
||||
defer srv.Stop()
|
||||
if !waitForDial(addr, 2*time.Second) {
|
||||
t.Fatal("server not up")
|
||||
}
|
||||
|
||||
cli := NewClient(b.assets)
|
||||
defer cli.Close()
|
||||
|
||||
callCtx, callCancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer callCancel()
|
||||
if err := cli.Call(callCtx, a.id, addr, "Ping", nil, nil); err == nil {
|
||||
t.Error("untrusted client was admitted")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForDial polls a TCP listener until it accepts a plain TCP
|
||||
// connection, signalling that Serve has begun listening.
|
||||
func waitForDial(addr string, max time.Duration) bool {
|
||||
deadline := time.Now().Add(max)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return true
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user