Initial structure
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxFrameSize caps the size of an individual on-wire message. 16 MiB
|
||||
// is comfortably above any plausible cluster.yaml or status payload
|
||||
// and rejects malicious giants up front.
|
||||
const MaxFrameSize = 16 * 1024 * 1024
|
||||
|
||||
// writeFrame emits a single length-prefixed message: 4-byte big-endian
|
||||
// length followed by the body.
|
||||
func writeFrame(w io.Writer, body []byte) error {
|
||||
if len(body) > MaxFrameSize {
|
||||
return errors.New("frame too large")
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body)))
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
// readFrame reads the next length-prefixed message. Returns io.EOF
|
||||
// cleanly when the connection closes on a frame boundary.
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxFrameSize {
|
||||
return nil, errors.New("incoming frame exceeds MaxFrameSize")
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
// Package transport carries inter-node RPC over mTLS. It owns three
|
||||
// concerns and nothing else:
|
||||
//
|
||||
// 1. Building tls.Config values that pin peer certs against the local
|
||||
// trust store (server and client side).
|
||||
// 2. Length-prefixed JSON framing on top of the TLS connection.
|
||||
// 3. A tiny method-dispatch RPC: callers register handlers by method
|
||||
// name; remote peers invoke them via Client.Call.
|
||||
//
|
||||
// Higher-level concerns (heartbeats, quorum, replication, check
|
||||
// shipping) live in their own packages and use this one purely as a
|
||||
// pipe. That keeps the wire format easy to reason about and the
|
||||
// surrounding packages testable without a real network.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/jasper/quptime/internal/config"
|
||||
)
|
||||
|
||||
// Method names. Defined here so every package agrees on the wire-level
// identifier without importing each other. Each constant pairs with a
// Request/Response struct of the same base name below.
const (
	MethodPing            = "Ping"
	MethodWhoAmI          = "WhoAmI"
	MethodJoin            = "Join"
	MethodHeartbeat       = "Heartbeat"
	MethodGetClusterCfg   = "GetClusterCfg"
	MethodApplyClusterCfg = "ApplyClusterCfg"
	MethodProposeMutation = "ProposeMutation"
	MethodReportResult    = "ReportResult"
	MethodStatus          = "Status"
)
|
||||
|
||||
// PingRequest is an empty liveness probe. PingResponse carries the
// responder's wall clock so the caller can sanity-check drift.
type PingRequest struct{}

// PingResponse is returned by MethodPing.
type PingResponse struct {
	NodeID string    `json:"node_id"`
	Now    time.Time `json:"now"` // responder's wall clock at reply time
}

// WhoAmIRequest asks the remote node to identify itself. Used during
// the TOFU handshake before the caller commits a trust entry.
type WhoAmIRequest struct{}

// WhoAmIResponse carries the node's identity. The fingerprint is
// recomputed by the caller from the TLS cert and compared against the
// claim here as a defense-in-depth check.
type WhoAmIResponse struct {
	NodeID      string `json:"node_id"`
	Advertise   string `json:"advertise"`   // address other nodes should dial
	Fingerprint string `json:"fingerprint"` // self-reported SPKI fingerprint
	CertPEM     string `json:"cert_pem"`    // PEM leaf cert, for trust-store entry
}
|
||||
|
||||
// JoinRequest is sent by a node that has just learned the remote's
// fingerprint out of band and wants the remote to record this node in
// its own trust store too (so the relationship is symmetric).
type JoinRequest struct {
	NodeID      string `json:"node_id"`
	Advertise   string `json:"advertise"`
	Fingerprint string `json:"fingerprint"`
	CertPEM     string `json:"cert_pem"`
}

// JoinResponse echoes a non-empty Error string when the remote refuses
// the join (e.g. operator declined the prompt or fingerprint mismatch).
type JoinResponse struct {
	Accepted bool   `json:"accepted"`
	Error    string `json:"error,omitempty"`
}
|
||||
|
||||
// HeartbeatRequest is the periodic liveness ping sent over the
// inter-node channel. It also carries the sender's view of who the
// master is, so disagreements surface quickly.
type HeartbeatRequest struct {
	FromNodeID string `json:"from_node_id"`
	Term       uint64 `json:"term"`           // sender's election term
	MasterID   string `json:"master_id"`      // sender's view of the master
	Version    uint64 `json:"config_version"` // sender's cluster-config version
}

// HeartbeatResponse is returned by MethodHeartbeat and mirrors the
// responder's own view of term/master/version for comparison.
type HeartbeatResponse struct {
	NodeID   string `json:"node_id"`
	Term     uint64 `json:"term"`
	MasterID string `json:"master_id"`
	Version  uint64 `json:"config_version"`
}
|
||||
|
||||
// GetClusterCfgRequest fetches the responder's view of cluster.yaml.
// Used by stale followers to pull the canonical config from master.
type GetClusterCfgRequest struct{}

// GetClusterCfgResponse contains a cluster.yaml snapshot.
type GetClusterCfgResponse struct {
	Config *config.ClusterConfig `json:"config"`
}

// ApplyClusterCfgRequest is the master pushing a new replicated config
// to a follower. The follower applies only if Version is strictly
// greater than its local Version.
type ApplyClusterCfgRequest struct {
	Config *config.ClusterConfig `json:"config"`
}

// ApplyClusterCfgResponse acknowledges with whether the follower
// stored the new config. Version reports the follower's version after
// the attempt, applied or not.
type ApplyClusterCfgResponse struct {
	Applied bool   `json:"applied"`
	Version uint64 `json:"current_version"`
}
|
||||
|
||||
// MutationKind enumerates the cluster-config edit operations that
// followers forward to the master. The string values are wire-level
// identifiers and must stay stable across versions.
type MutationKind string

const (
	MutationAddCheck    MutationKind = "add_check"
	MutationRemoveCheck MutationKind = "remove_check"
	MutationAddAlert    MutationKind = "add_alert"
	MutationRemoveAlert MutationKind = "remove_alert"
	MutationAddPeer     MutationKind = "add_peer"
	MutationRemovePeer  MutationKind = "remove_peer"
)
|
||||
|
||||
// ProposeMutationRequest is a follower-to-master message. The payload
// is the JSON-encoded body of the new entity (a Check, an Alert, or a
// PeerInfo) for the "add" variants, or the target ID/NodeID string for
// removals.
type ProposeMutationRequest struct {
	FromNodeID string          `json:"from_node_id"`
	Kind       MutationKind    `json:"kind"`
	Payload    json.RawMessage `json:"payload"` // decoded by the master per Kind
}

// ProposeMutationResponse is the master's reply to ProposeMutation.
// On success NewVersion is the config version after the edit; on
// failure Error is non-empty.
type ProposeMutationResponse struct {
	NewVersion uint64 `json:"new_version"`
	Error      string `json:"error,omitempty"`
}
|
||||
|
||||
// ReportResultRequest is a follower-to-master message reporting the
// outcome of a single local probe.
type ReportResultRequest struct {
	FromNodeID string    `json:"from_node_id"`
	CheckID    string    `json:"check_id"`
	OK         bool      `json:"ok"`
	Detail     string    `json:"detail,omitempty"` // human-readable failure detail
	LatencyMS  int64     `json:"latency_ms"`
	At         time.Time `json:"at"` // when the probe ran
}

// ReportResultResponse acknowledges a result. Empty body for now.
type ReportResultResponse struct{}
|
||||
|
||||
// StatusRequest asks a peer for its operational state.
type StatusRequest struct{}

// StatusResponse is what `qu status` aggregates and displays.
type StatusResponse struct {
	NodeID     string          `json:"node_id"`
	Term       uint64          `json:"term"`
	MasterID   string          `json:"master_id"`
	Version    uint64          `json:"config_version"`
	Peers      []PeerLiveness  `json:"peers"`
	Checks     []CheckSnapshot `json:"checks"`
	HasQuorum  bool            `json:"has_quorum"`
	QuorumSize int             `json:"quorum_size"`
}

// PeerLiveness summarises one peer for status output.
type PeerLiveness struct {
	NodeID    string    `json:"node_id"`
	Advertise string    `json:"advertise"`
	Live      bool      `json:"live"`
	LastSeen  time.Time `json:"last_seen"`
}

// CheckSnapshot is the aggregate state of one configured check.
type CheckSnapshot struct {
	CheckID string `json:"check_id"`
	Name    string `json:"name"`
	State   string `json:"state"` // "up", "down", "unknown"
	OKCount int    `json:"ok_count"`
	Total   int    `json:"total"`
	Detail  string `json:"detail,omitempty"`
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HandlerFunc is registered by callers for a specific method name. The
// raw JSON request body and the peer's verified node ID are provided.
// The returned value (if any) is JSON-marshalled into the response.
// A non-nil error is sent back to the peer as the response's Error
// string rather than terminating the connection.
type HandlerFunc func(ctx context.Context, peerNodeID string, payload json.RawMessage) (any, error)
|
||||
|
||||
// Server is a registry of method handlers plus an accept loop. It
// owns no business logic; callers register methods and Serve dispatches.
type Server struct {
	assets   *TLSAssets             // local cert/key + trust store for mTLS
	handlers map[string]HandlerFunc // method name -> handler; populated before Serve

	mu    sync.Mutex            // guards ln and conns
	ln    net.Listener          // set by Serve, closed by Stop / ctx cancel
	conns map[net.Conn]struct{} // live connections, so Stop can close them
}
|
||||
|
||||
// NewServer constructs a Server with no handlers registered.
|
||||
func NewServer(assets *TLSAssets) *Server {
|
||||
return &Server{
|
||||
assets: assets,
|
||||
handlers: map[string]HandlerFunc{},
|
||||
conns: map[net.Conn]struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// Handle registers fn for the given method name. Replaces any prior
// handler for the same method.
//
// NOTE(review): the handlers map is not guarded by s.mu and Serve
// reads it without a lock — registering after Serve has started is a
// data race. Confirm all Handle calls happen before Serve.
func (s *Server) Handle(method string, fn HandlerFunc) {
	s.handlers[method] = fn
}
|
||||
|
||||
// Serve binds the listener at addr and dispatches incoming RPCs until
// Stop is called or the listener errors out. Returns nil when the
// listener was closed deliberately (ctx cancel or Stop), otherwise the
// accept error.
func (s *Server) Serve(ctx context.Context, addr string) error {
	tlsCfg, err := s.assets.ServerConfig()
	if err != nil {
		return err
	}
	ln, err := tls.Listen("tcp", addr, tlsCfg)
	if err != nil {
		return fmt.Errorf("listen %s: %w", addr, err)
	}
	// Publish the listener under the lock so Stop (from another
	// goroutine) can close it.
	s.mu.Lock()
	s.ln = ln
	s.mu.Unlock()

	// Context cancellation unblocks the Accept loop by closing the
	// listener; Accept then returns net.ErrClosed.
	go func() {
		<-ctx.Done()
		_ = ln.Close()
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			// net.ErrClosed means a deliberate shutdown, not a failure.
			if errors.Is(err, net.ErrClosed) {
				return nil
			}
			return err
		}
		// One goroutine per connection; handleConn tracks/untracks it.
		go s.handleConn(ctx, conn)
	}
}
|
||||
|
||||
// Stop closes the listener and all in-flight connections. Safe to call
|
||||
// from any goroutine.
|
||||
func (s *Server) Stop() {
|
||||
s.mu.Lock()
|
||||
if s.ln != nil {
|
||||
_ = s.ln.Close()
|
||||
}
|
||||
for c := range s.conns {
|
||||
_ = c.Close()
|
||||
}
|
||||
s.conns = map[net.Conn]struct{}{}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) trackConn(c net.Conn) { s.mu.Lock(); s.conns[c] = struct{}{}; s.mu.Unlock() }
|
||||
func (s *Server) untrackConn(c net.Conn) { s.mu.Lock(); delete(s.conns, c); s.mu.Unlock() }
|
||||
|
||||
// handleConn runs the per-connection request loop: complete the TLS
// handshake, extract the verified peer identity, then read frames and
// dispatch them to registered handlers until the connection drops.
// Handler errors are reported to the peer; only I/O and request-decode
// failures terminate the loop.
func (s *Server) handleConn(ctx context.Context, raw net.Conn) {
	s.trackConn(raw)
	defer func() {
		s.untrackConn(raw)
		_ = raw.Close()
	}()

	// tls.Listen always yields *tls.Conn; the type check is a guard.
	tlsConn, ok := raw.(*tls.Conn)
	if !ok {
		return
	}
	// Force the handshake now so ConnectionState below has the
	// (pinning-verified) peer certificate.
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return
	}
	peerID := peerNodeIDFromConnState(tlsConn.ConnectionState())

	for {
		body, err := readFrame(tlsConn)
		if err != nil {
			// io.EOF here is the peer hanging up cleanly.
			return
		}
		var req requestEnvelope
		if err := json.Unmarshal(body, &req); err != nil {
			// Can't recover the request ID from a garbled frame, so
			// report with ID 0 and drop the connection: the stream may
			// be desynchronised.
			_ = writeError(tlsConn, 0, "decode request: "+err.Error())
			return
		}

		fn, exists := s.handlers[req.Method]
		if !exists {
			// Unknown method is a per-request error; keep the
			// connection for subsequent requests.
			_ = writeError(tlsConn, req.ID, "unknown method: "+req.Method)
			continue
		}

		result, err := fn(ctx, peerID, req.Params)
		if err != nil {
			_ = writeError(tlsConn, req.ID, err.Error())
			continue
		}
		if err := writeResult(tlsConn, req.ID, result); err != nil {
			return
		}
	}
}
|
||||
|
||||
// Client opens and pools one mTLS connection per peer node ID. Each
// connection serialises outstanding calls under a mutex; concurrent
// calls to different peers proceed in parallel.
type Client struct {
	assets *TLSAssets // local cert/key + trust store for pinning

	mu    sync.Mutex
	conns map[string]*clientConn // by peer node ID; guarded by mu

	nextID atomic.Uint64 // monotonically increasing request IDs (first ID is 1)
}
|
||||
|
||||
// NewClient constructs an empty connection pool.
|
||||
func NewClient(assets *TLSAssets) *Client {
|
||||
return &Client{assets: assets, conns: map[string]*clientConn{}}
|
||||
}
|
||||
|
||||
// clientConn is one pooled connection. Its mutex serialises the
// write-request/read-response exchange so frames from concurrent
// calls never interleave on the wire.
type clientConn struct {
	mu   sync.Mutex
	conn *tls.Conn
}
|
||||
|
||||
// Close drops every pooled connection. Safe to call multiple times.
|
||||
func (c *Client) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for id, cc := range c.conns {
|
||||
if cc.conn != nil {
|
||||
_ = cc.conn.Close()
|
||||
}
|
||||
delete(c.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Call invokes method on the peer at addr (identified by nodeID for
|
||||
// fingerprint pinning), marshalling params to JSON and unmarshalling
|
||||
// the result into out. out may be nil if the caller doesn't care.
|
||||
func (c *Client) Call(ctx context.Context, nodeID, addr, method string, params any, out any) error {
|
||||
cc, err := c.getConn(ctx, nodeID, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.callOn(ctx, cc, method, params, out); err != nil {
|
||||
// drop the connection on error so the next call reconnects fresh
|
||||
c.dropConn(nodeID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) callOn(ctx context.Context, cc *clientConn, method string, params any, out any) error {
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal params: %w", err)
|
||||
}
|
||||
id := c.nextID.Add(1)
|
||||
env := requestEnvelope{ID: id, Method: method, Params: paramsJSON}
|
||||
body, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
_ = cc.conn.SetDeadline(dl)
|
||||
defer func() { _ = cc.conn.SetDeadline(time.Time{}) }()
|
||||
}
|
||||
|
||||
if err := writeFrame(cc.conn, body); err != nil {
|
||||
return err
|
||||
}
|
||||
respBody, err := readFrame(cc.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var resp responseEnvelope
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if resp.Error != "" {
|
||||
return fmt.Errorf("remote: %s", resp.Error)
|
||||
}
|
||||
if out != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, out); err != nil {
|
||||
return fmt.Errorf("decode result: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getConn returns the pooled connection for nodeID, dialing a new one
// if none exists. The dial happens outside the pool lock so a slow
// peer doesn't block calls to other peers; a second check under the
// lock resolves the race where two goroutines dialed concurrently —
// the loser's connection is closed and the winner's reused.
func (c *Client) getConn(ctx context.Context, nodeID, addr string) (*clientConn, error) {
	c.mu.Lock()
	cc, ok := c.conns[nodeID]
	c.mu.Unlock()
	if ok && cc.conn != nil {
		return cc, nil
	}

	// ClientConfig pins the handshake to nodeID's trust-store entry.
	tlsCfg, err := c.assets.ClientConfig(nodeID)
	if err != nil {
		return nil, err
	}
	d := tls.Dialer{Config: tlsCfg}
	raw, err := d.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("dial %s: %w", addr, err)
	}
	tc, ok := raw.(*tls.Conn)
	if !ok {
		_ = raw.Close()
		return nil, errors.New("dial returned non-tls conn")
	}
	cc = &clientConn{conn: tc}
	c.mu.Lock()
	if existing, ok := c.conns[nodeID]; ok && existing.conn != nil {
		// concurrent dial — drop ours, reuse existing
		_ = tc.Close()
		c.mu.Unlock()
		return existing, nil
	}
	c.conns[nodeID] = cc
	c.mu.Unlock()
	return cc, nil
}
|
||||
|
||||
func (c *Client) dropConn(nodeID string) {
|
||||
c.mu.Lock()
|
||||
if cc, ok := c.conns[nodeID]; ok {
|
||||
if cc.conn != nil {
|
||||
_ = cc.conn.Close()
|
||||
}
|
||||
delete(c.conns, nodeID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// requestEnvelope is the wire shape of an RPC request frame.
type requestEnvelope struct {
	ID     uint64          `json:"id"` // echoed back in the response
	Method string          `json:"method"`
	Params json.RawMessage `json:"params"` // method-specific request body
}
|
||||
|
||||
// responseEnvelope is the wire shape of an RPC response frame. Exactly
// one of Result / Error is populated; ID matches the request (0 when
// the request itself could not be decoded).
type responseEnvelope struct {
	ID     uint64          `json:"id"`
	Result json.RawMessage `json:"result,omitempty"`
	Error  string          `json:"error,omitempty"`
}
|
||||
|
||||
func writeResult(w io.Writer, id uint64, result any) error {
|
||||
var raw json.RawMessage
|
||||
if result != nil {
|
||||
b, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return writeError(w, id, "marshal result: "+err.Error())
|
||||
}
|
||||
raw = b
|
||||
}
|
||||
body, err := json.Marshal(responseEnvelope{ID: id, Result: raw})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeFrame(w, body)
|
||||
}
|
||||
|
||||
func writeError(w io.Writer, id uint64, msg string) error {
|
||||
body, err := json.Marshal(responseEnvelope{ID: id, Error: msg})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeFrame(w, body)
|
||||
}
|
||||
|
||||
// peerNodeIDFromConnState extracts the peer's NodeID from the cert's
|
||||
// CommonName field. The init flow sets CN to the local NodeID.
|
||||
func peerNodeIDFromConnState(cs tls.ConnectionState) string {
|
||||
if len(cs.PeerCertificates) == 0 {
|
||||
return ""
|
||||
}
|
||||
return cs.PeerCertificates[0].Subject.CommonName
|
||||
}
|
||||
|
||||
// fingerprintOf is a small local mirror to keep this file independent
// of the crypto package's import path at link time; we recompute the
// SPKI hash here. Defined in tofu.go.
//
// NOTE(review): this blank declaration exists only to keep the
// crypto/x509 import alive in this file; the actual helper lives in
// tofu.go.
var _ = (*x509.Certificate)(nil)
|
||||
@@ -0,0 +1,120 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jasper/quptime/internal/trust"
|
||||
)
|
||||
|
||||
// MinTLS is the minimum protocol version both sides require. TLS 1.3
// is enforced for every listener and dialer built from TLSAssets.
const MinTLS = tls.VersionTLS13
|
||||
|
||||
// TLSAssets bundles the on-disk material needed to spin up either a
// listener or a dialer. Build it once at daemon start and pass to
// ServerConfig / ClientConfig.
type TLSAssets struct {
	Cert  []byte          // PEM-encoded leaf cert
	Key   *rsa.PrivateKey // private key matching Cert
	Trust *trust.Store    // pinned peer fingerprints
}
|
||||
|
||||
// tlsCert wraps the local PEM cert + RSA key into a tls.Certificate.
|
||||
func (a *TLSAssets) tlsCert() (tls.Certificate, error) {
|
||||
block, _ := pem.Decode(a.Cert)
|
||||
if block == nil {
|
||||
return tls.Certificate{}, errors.New("cert PEM has no block")
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("parse leaf: %w", err)
|
||||
}
|
||||
return tls.Certificate{
|
||||
Certificate: [][]byte{block.Bytes},
|
||||
PrivateKey: a.Key,
|
||||
Leaf: leaf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ServerConfig produces a tls.Config suitable for an inter-node
// listener. Peers must present a certificate, and that certificate's
// fingerprint must already be present in the trust store.
func (a *TLSAssets) ServerConfig() (*tls.Config, error) {
	cert, err := a.tlsCert()
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   MinTLS,
		// RequireAnyClientCert forces the peer to send a cert without
		// CA-chain validation; the pinning callback below is what
		// actually authenticates it.
		ClientAuth:            tls.RequireAnyClientCert,
		InsecureSkipVerify:    true, // we do our own pinning via VerifyPeerCertificate
		VerifyPeerCertificate: a.Trust.VerifyPeerCert,
	}, nil
}
|
||||
|
||||
// ClientConfig produces a tls.Config suitable for dialing a peer.
// expectedNodeID is optional: if non-empty, the handshake also
// verifies that the cert's fingerprint matches the trust entry for
// that node ID.
func (a *TLSAssets) ClientConfig(expectedNodeID string) (*tls.Config, error) {
	cert, err := a.tlsCert()
	if err != nil {
		return nil, err
	}
	// Default: accept any cert present anywhere in the trust store;
	// tighten to a single node's entry when the caller named one.
	verify := a.Trust.VerifyPeerCert
	if expectedNodeID != "" {
		verify = a.makeStrictVerifier(expectedNodeID)
	}
	return &tls.Config{
		Certificates:          []tls.Certificate{cert},
		MinVersion:            MinTLS,
		InsecureSkipVerify:    true, // we do our own pinning via VerifyPeerCertificate
		VerifyPeerCertificate: verify,
	}, nil
}
|
||||
|
||||
// InsecureBootstrapConfig is the client-side TLS config used only by
// the TOFU prefetch (FetchPeerCert). It accepts any peer cert because
// the caller has not yet established trust; the certificate is
// surfaced to the operator for manual approval before being added to
// the store. Never use this anywhere else.
func (a *TLSAssets) InsecureBootstrapConfig() (*tls.Config, error) {
	cert, err := a.tlsCert()
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   MinTLS,
		// No VerifyPeerCertificate: the whole point is to observe an
		// unknown peer's cert before deciding whether to pin it.
		InsecureSkipVerify: true,
	}, nil
}
|
||||
|
||||
// makeStrictVerifier returns a VerifyPeerCertificate callback that
// pins the connection to the trust entry of a specific node ID: the
// presented leaf's SPKI fingerprint must equal the stored fingerprint
// for expectedNodeID, otherwise the handshake fails.
func (a *TLSAssets) makeStrictVerifier(expectedNodeID string) func([][]byte, [][]*x509.Certificate) error {
	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.New("peer presented no certificate")
		}
		// rawCerts[0] is the leaf; only the leaf's SPKI is pinned.
		cert, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return fmt.Errorf("parse peer cert: %w", err)
		}
		entry, ok := a.Trust.Get(expectedNodeID)
		if !ok {
			return fmt.Errorf("no trust entry for node %s", expectedNodeID)
		}
		got := fingerprintOf(cert)
		if got != entry.Fingerprint {
			return fmt.Errorf("fingerprint mismatch for %s: got %s want %s",
				expectedNodeID, got, entry.Fingerprint)
		}
		return nil
	}
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fingerprintOf computes the SHA-256 SPKI fingerprint of a parsed
|
||||
// certificate using the same encoding as the crypto package
|
||||
// (sha256:hex). Duplicated here to keep the transport package
|
||||
// dependency-light at the call site.
|
||||
func fingerprintOf(cert *x509.Certificate) string {
|
||||
sum := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
|
||||
return "sha256:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// PeerCertSample is the result of a TOFU probe: the operator inspects
// the fingerprint and decides whether to trust it.
type PeerCertSample struct {
	Cert        *x509.Certificate // parsed leaf as presented by the peer
	CertPEM     []byte            // same leaf, PEM-encoded for storage
	Fingerprint string            // sha256:hex SPKI fingerprint of Cert
}
|
||||
|
||||
// FetchPeerCert opens an mTLS connection to addr with no trust
// pinning, captures the peer's certificate, and closes the connection.
// The caller must show the fingerprint to the operator before adding
// it to the trust store.
//
// This is the *only* place the trust store is bypassed. After the
// TOFU exchange, the regular ClientConfig path applies for all future
// traffic to that peer.
func FetchPeerCert(ctx context.Context, assets *TLSAssets, addr string) (*PeerCertSample, error) {
	cfg, err := assets.InsecureBootstrapConfig()
	if err != nil {
		return nil, err
	}
	// Bound the probe so an unreachable address can't hang the
	// interactive join flow.
	dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	d := tls.Dialer{Config: cfg, NetDialer: &net.Dialer{}}
	// tls.Dialer completes the handshake during DialContext, so the
	// connection state below already carries the peer's certificates.
	raw, err := d.DialContext(dialCtx, "tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("dial %s: %w", addr, err)
	}
	defer raw.Close()

	tc, ok := raw.(*tls.Conn)
	if !ok {
		return nil, errors.New("dial returned non-tls conn")
	}
	state := tc.ConnectionState()
	if len(state.PeerCertificates) == 0 {
		return nil, errors.New("peer presented no certificate")
	}
	leaf := state.PeerCertificates[0]
	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leaf.Raw})
	return &PeerCertSample{
		Cert:        leaf,
		CertPEM:     pemBytes,
		Fingerprint: fingerprintOf(leaf),
	}, nil
}
|
||||
Reference in New Issue
Block a user