Initial structure
This commit is contained in:
@@ -0,0 +1,329 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HandlerFunc is registered by callers for a specific method name. The
|
||||
// raw JSON request body and the peer's verified node ID are provided.
|
||||
// The returned value (if any) is JSON-marshalled into the response.
|
||||
type HandlerFunc func(ctx context.Context, peerNodeID string, payload json.RawMessage) (any, error)
|
||||
|
||||
// Server is a registry of method handlers plus an accept loop. It
|
||||
// owns no business logic; callers register methods and Serve dispatches.
|
||||
type Server struct {
|
||||
assets *TLSAssets
|
||||
handlers map[string]HandlerFunc
|
||||
|
||||
mu sync.Mutex
|
||||
ln net.Listener
|
||||
conns map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
// NewServer constructs a Server with no handlers registered.
|
||||
func NewServer(assets *TLSAssets) *Server {
|
||||
return &Server{
|
||||
assets: assets,
|
||||
handlers: map[string]HandlerFunc{},
|
||||
conns: map[net.Conn]struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// Handle registers fn for the given method name. Replaces any prior
|
||||
// handler for the same method.
|
||||
func (s *Server) Handle(method string, fn HandlerFunc) {
|
||||
s.handlers[method] = fn
|
||||
}
|
||||
|
||||
// Serve binds the listener at addr and dispatches incoming RPCs until
|
||||
// Stop is called or the listener errors out.
|
||||
func (s *Server) Serve(ctx context.Context, addr string) error {
|
||||
tlsCfg, err := s.assets.ServerConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln, err := tls.Listen("tcp", addr, tlsCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", addr, err)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.ln = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
go s.handleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes the listener and all in-flight connections. Safe to call
|
||||
// from any goroutine.
|
||||
func (s *Server) Stop() {
|
||||
s.mu.Lock()
|
||||
if s.ln != nil {
|
||||
_ = s.ln.Close()
|
||||
}
|
||||
for c := range s.conns {
|
||||
_ = c.Close()
|
||||
}
|
||||
s.conns = map[net.Conn]struct{}{}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) trackConn(c net.Conn) { s.mu.Lock(); s.conns[c] = struct{}{}; s.mu.Unlock() }
|
||||
func (s *Server) untrackConn(c net.Conn) { s.mu.Lock(); delete(s.conns, c); s.mu.Unlock() }
|
||||
|
||||
func (s *Server) handleConn(ctx context.Context, raw net.Conn) {
|
||||
s.trackConn(raw)
|
||||
defer func() {
|
||||
s.untrackConn(raw)
|
||||
_ = raw.Close()
|
||||
}()
|
||||
|
||||
tlsConn, ok := raw.(*tls.Conn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
peerID := peerNodeIDFromConnState(tlsConn.ConnectionState())
|
||||
|
||||
for {
|
||||
body, err := readFrame(tlsConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var req requestEnvelope
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = writeError(tlsConn, 0, "decode request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fn, exists := s.handlers[req.Method]
|
||||
if !exists {
|
||||
_ = writeError(tlsConn, req.ID, "unknown method: "+req.Method)
|
||||
continue
|
||||
}
|
||||
|
||||
result, err := fn(ctx, peerID, req.Params)
|
||||
if err != nil {
|
||||
_ = writeError(tlsConn, req.ID, err.Error())
|
||||
continue
|
||||
}
|
||||
if err := writeResult(tlsConn, req.ID, result); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Client opens and pools one mTLS connection per peer node ID. Each
|
||||
// connection serialises outstanding calls under a mutex; concurrent
|
||||
// calls to different peers proceed in parallel.
|
||||
type Client struct {
|
||||
assets *TLSAssets
|
||||
|
||||
mu sync.Mutex
|
||||
conns map[string]*clientConn // by peer node ID
|
||||
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
// NewClient constructs an empty connection pool.
|
||||
func NewClient(assets *TLSAssets) *Client {
|
||||
return &Client{assets: assets, conns: map[string]*clientConn{}}
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
mu sync.Mutex
|
||||
conn *tls.Conn
|
||||
}
|
||||
|
||||
// Close drops every pooled connection. Safe to call multiple times.
|
||||
func (c *Client) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for id, cc := range c.conns {
|
||||
if cc.conn != nil {
|
||||
_ = cc.conn.Close()
|
||||
}
|
||||
delete(c.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Call invokes method on the peer at addr (identified by nodeID for
|
||||
// fingerprint pinning), marshalling params to JSON and unmarshalling
|
||||
// the result into out. out may be nil if the caller doesn't care.
|
||||
func (c *Client) Call(ctx context.Context, nodeID, addr, method string, params any, out any) error {
|
||||
cc, err := c.getConn(ctx, nodeID, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.callOn(ctx, cc, method, params, out); err != nil {
|
||||
// drop the connection on error so the next call reconnects fresh
|
||||
c.dropConn(nodeID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) callOn(ctx context.Context, cc *clientConn, method string, params any, out any) error {
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal params: %w", err)
|
||||
}
|
||||
id := c.nextID.Add(1)
|
||||
env := requestEnvelope{ID: id, Method: method, Params: paramsJSON}
|
||||
body, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
_ = cc.conn.SetDeadline(dl)
|
||||
defer func() { _ = cc.conn.SetDeadline(time.Time{}) }()
|
||||
}
|
||||
|
||||
if err := writeFrame(cc.conn, body); err != nil {
|
||||
return err
|
||||
}
|
||||
respBody, err := readFrame(cc.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var resp responseEnvelope
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if resp.Error != "" {
|
||||
return fmt.Errorf("remote: %s", resp.Error)
|
||||
}
|
||||
if out != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, out); err != nil {
|
||||
return fmt.Errorf("decode result: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getConn(ctx context.Context, nodeID, addr string) (*clientConn, error) {
|
||||
c.mu.Lock()
|
||||
cc, ok := c.conns[nodeID]
|
||||
c.mu.Unlock()
|
||||
if ok && cc.conn != nil {
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
tlsCfg, err := c.assets.ClientConfig(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := tls.Dialer{Config: tlsCfg}
|
||||
raw, err := d.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
tc, ok := raw.(*tls.Conn)
|
||||
if !ok {
|
||||
_ = raw.Close()
|
||||
return nil, errors.New("dial returned non-tls conn")
|
||||
}
|
||||
cc = &clientConn{conn: tc}
|
||||
c.mu.Lock()
|
||||
if existing, ok := c.conns[nodeID]; ok && existing.conn != nil {
|
||||
// concurrent dial — drop ours, reuse existing
|
||||
_ = tc.Close()
|
||||
c.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
c.conns[nodeID] = cc
|
||||
c.mu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (c *Client) dropConn(nodeID string) {
|
||||
c.mu.Lock()
|
||||
if cc, ok := c.conns[nodeID]; ok {
|
||||
if cc.conn != nil {
|
||||
_ = cc.conn.Close()
|
||||
}
|
||||
delete(c.conns, nodeID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// requestEnvelope is the wire shape of an RPC request frame.
|
||||
type requestEnvelope struct {
|
||||
ID uint64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params"`
|
||||
}
|
||||
|
||||
// responseEnvelope is the wire shape of an RPC response frame.
|
||||
type responseEnvelope struct {
|
||||
ID uint64 `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func writeResult(w io.Writer, id uint64, result any) error {
|
||||
var raw json.RawMessage
|
||||
if result != nil {
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return writeError(w, id, "marshal result: "+err.Error())
|
||||
}
|
||||
raw = b
|
||||
}
|
||||
body, err := json.Marshal(responseEnvelope{ID: id, Result: raw})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeFrame(w, body)
|
||||
}
|
||||
|
||||
func writeError(w io.Writer, id uint64, msg string) error {
|
||||
body, err := json.Marshal(responseEnvelope{ID: id, Error: msg})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeFrame(w, body)
|
||||
}
|
||||
|
||||
// peerNodeIDFromConnState extracts the peer's NodeID from the cert's
|
||||
// CommonName field. The init flow sets CN to the local NodeID.
|
||||
func peerNodeIDFromConnState(cs tls.ConnectionState) string {
|
||||
if len(cs.PeerCertificates) == 0 {
|
||||
return ""
|
||||
}
|
||||
return cs.PeerCertificates[0].Subject.CommonName
|
||||
}
|
||||
|
||||
// fingerprintOf is a small local mirror to keep this file independent
|
||||
// of the crypto package's import path at link time; we recompute the
|
||||
// SPKI hash here. Defined in tofu.go.
|
||||
var _ = (*x509.Certificate)(nil)
|
||||
Reference in New Issue
Block a user