Sunday, September 21, 2025

Unified Metamaterial-IO server+client in one Python script.

 

 """    Metamaterial SLS/Laser Sintering Unified Python Script - so our nuclear reactor assemblies and homebrew chemists can create safer hotboxes and technicians can endure to enjoy new material science assays. ♥ A.   """

 #!/usr/bin/env python3
"""
Unified Metamaterial-IO server+client in one Python script.

Implements the io_dist_full.c wire protocol (big-endian, len-prefixed)
with the same op-codes:
  - 0: interpolate triples (fx, fy, fz) -> [Y][Z] over shared interpolation x
  - 1: differentiate triples (fx, fy, fz) -> [dY/dx][dZ/dx] on original per-curve x
  - 2: interpolate Y-only (fx, fy) -> [Y] over shared interpolation x
  - 4: integrate triples (fx, fy, fz) -> cumulative trapezoidal [Y_int][Z_int] on original x

Design goals:
  - Fully self-contained: one file, no external services required.
  - Numpy-accelerated math paths for good performance.
  - Protocol compatible with the C reference: big-endian floats, u8 op, u32 sizes, len-prefixed outputs.
  - Threaded server with graceful shutdown on Ctrl+C.

Usage:
  - Start server:
      python metamaterial_io.py --server 0.0.0.0 5000
  - Run client demo (op=2 interpolation of a single curve):
      python metamaterial_io.py --client 127.0.0.1 5000
"""

import argparse
import math
import os
import socket
import struct
import sys
import threading
import time
from typing import List, Optional, Tuple

try:
    import numpy as np
except ImportError:
    print("This script requires numpy (pip install numpy).", file=sys.stderr)
    sys.exit(1)

# ========== Wire helpers (big-endian) ==========

def be_u8_recv(sock: socket.socket) -> int:
    """Receive a single unsigned byte from ``sock``.

    Raises:
        ConnectionError: the peer closed the stream before a byte arrived.
    """
    data = sock.recv(1)
    if not data:
        raise ConnectionError("read u8 failed")
    (value,) = data
    return value

def be_u8_send(sock: socket.socket, v: int) -> None:
    """Send the low 8 bits of ``v`` as one unsigned byte."""
    sock.sendall(struct.pack("!B", v & 0xFF))

def be_u32_recv(sock: socket.socket) -> int:
    """Receive a big-endian unsigned 32-bit integer (exactly 4 bytes)."""
    return int.from_bytes(recv_all(sock, 4), byteorder="big", signed=False)

def be_u32_send(sock: socket.socket, v: int) -> None:
    """Send the low 32 bits of ``v`` as a big-endian unsigned integer."""
    wire = (v & 0xFFFFFFFF).to_bytes(4, byteorder="big")
    sock.sendall(wire)

def recv_all(sock: socket.socket, n: int) -> bytes:
    """Receive exactly ``n`` bytes from ``sock``, looping over short reads.

    Returns ``b""`` for ``n <= 0`` without touching the socket.

    Raises:
        ConnectionError: the peer closed the stream before ``n`` bytes
            arrived. (A socket timeout surfaces as the TimeoutError raised by
            the receive call itself, not as this error.)
    """
    buf = bytearray(max(n, 0))
    view = memoryview(buf)
    filled = 0
    while filled < n:
        got = sock.recv_into(view[filled:])
        if got == 0:
            raise ConnectionError("recv timeout/closed")
        filled += got
    return bytes(buf)

def be_f32_array_recv(sock: socket.socket, count: int) -> np.ndarray:
    """Receive ``count`` big-endian float32 values as a native float32 array.

    A zero count reads nothing from the socket and returns an empty array.
    """
    if not count:
        return np.empty((0,), dtype=np.float32)
    big_endian = np.frombuffer(recv_all(sock, count * 4), dtype=">f4")
    # astype copies: the result is writable and no longer backed by the
    # immutable bytes object returned by recv_all.
    return big_endian.astype(np.float32)

def len_prefixed_bytes_send(sock: socket.socket, payload: bytes) -> None:
    """Send ``payload`` preceded by its byte length as a big-endian u32.

    An empty payload sends only the zero-length header.
    """
    size = len(payload)
    be_u32_send(sock, size)
    if size > 0:
        sock.sendall(payload)

def len_prefixed_error(sock: socket.socket, msg: str) -> None:
    """Send ``msg`` as a UTF-8 text frame with a u32 byte-length prefix.

    Characters UTF-8 cannot encode (e.g. lone surrogates) are sent as ``?``
    instead of raising.
    """
    len_prefixed_bytes_send(sock, msg.encode("utf-8", errors="replace"))

def len_prefixed_f32_array_send(sock: socket.socket, arr: np.ndarray) -> None:
    """Send ``arr`` as big-endian float32 bytes behind a u32 byte-length prefix.

    ``arr`` is coerced to float32 first. ``None`` sends an empty frame
    (a zero length header and no body).
    """
    if arr is not None:
        wire = np.asarray(arr, dtype=np.float32).astype(">f4", copy=False)
        len_prefixed_bytes_send(sock, wire.tobytes(order="C"))
        return
    be_u32_send(sock, 0)

# ========== Math kernels (numpy) ==========

def _sort_by_x(x: np.ndarray, *ys: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]:
    idx = np.argsort(x, kind="stable")
    xs = x[idx]
    youts = []
    for y in ys:
        youts.append(y[idx])
    return xs, youts

def _interp_shared(xs: np.ndarray, ys: np.ndarray, xq: np.ndarray) -> np.ndarray:
    # linear interpolation with edge handling (hold edges)
    # assumes xs strictly increasing; if duplicates exist, consolidate by stable unique
    xsu, uniq_idx = np.unique(xs, return_index=True)
    if xsu.shape[0] < 2:
        # not enough unique points
        return np.full_like(xq, ys[uniq_idx[0]] if xsu.shape[0] == 1 else 0.0, dtype=np.float32)
    ysu = ys[uniq_idx]
    yq = np.interp(xq, xsu, ysu).astype(np.float32)
    return yq

def _differentiate_curve(xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
    n = xs.shape[0]
    if n < 2:
        return np.zeros((n,), dtype=np.float32)
    dy = np.empty((n,), dtype=np.float32)
    # forward/backward for edges, central for interior
    dy[0] = (ys[1] - ys[0]) / (xs[1] - xs[0])
    dy[-1] = (ys[-1] - ys[-2]) / (xs[-1] - xs[-2])
    if n > 2:
        # central differences with nonuniform spacing: (y[i+1]-y[i-1])/(x[i+1]-x[i-1])
        dy[1:-1] = (ys[2:] - ys[:-2]) / (xs[2:] - xs[:-2])
    return dy

def _integrate_trap(xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
    n = xs.shape[0]
    out = np.zeros((n,), dtype=np.float32)
    if n < 2:
        return out
    dx = xs[1:] - xs[:-1]
    # cumulative trapezoid
    acc = np.cumsum(0.5 * dx * (ys[1:] + ys[:-1]), dtype=np.float64).astype(np.float32)
    out[1:] = acc
    return out

# ========== Request handling (server) ==========

class OpInfo:
    """Static description of one protocol op-code.

    Attributes:
        code: the u8 op value as sent on the wire.
        needs_triple: True when each curve carries an fz array after fy.
        needs_interp: True when a shared x_interp array (u32 M followed by
            M floats) is sent after the curves.
    """

    def __init__(self, code: int, needs_triple: bool, needs_interp: bool):
        self.code, self.needs_triple, self.needs_interp = code, needs_triple, needs_interp


# Supported op-codes; handle_connection rejects any other op value.
OP_TABLE = {
    code: OpInfo(code, needs_triple=triple, needs_interp=interp)
    for code, triple, interp in (
        (0, True, True),    # interpolate triples
        (1, True, False),   # differentiate triples
        (2, False, True),   # interpolate Y-only
        (4, True, False),   # integrate triples
    )
}

def handle_connection(
    conn: socket.socket, addr: Tuple[str, int], max_points: int = 10_000_000
) -> None:
    """Serve exactly one request on ``conn``, then shut it down and close it.

    Request (big-endian): ``u8 op``, ``u32 N``, then N curves, each
    ``u32 nfx, f32[nfx] fx, u32 nfy, f32[nfy] fy``, plus ``u32 nfz, f32[nfz] fz``
    for triple ops. Interpolating ops (0, 2) then send ``u32 M, f32[M] x_interp``.

    Reply: one ``u32`` byte length followed by either big-endian float32
    results or UTF-8 text beginning with ``"Error:"``. Results are laid out
    as [Y of all curves][Z of all curves] (Z only for triple ops).

    Every length prefix is validated *before* its payload is read, so a
    hostile prefix cannot make the server buffer gigabytes first.

    Args:
        conn: accepted client socket; always shut down and closed on return.
        addr: peer address (unused; part of the thread-target signature).
        max_points: largest accepted per-curve sample count. Longer curves
            are rejected with "Error: curve too long" (defaults to the same
            bound already applied to M).
    """
    conn.settimeout(60.0)
    try:
        op = be_u8_recv(conn)
        info = OP_TABLE.get(op)
        if info is None:
            len_prefixed_error(conn, f"Error: Unsupported op {op}")
            return
        n_curves = be_u32_recv(conn)
        if n_curves == 0 or n_curves > 1_000_000:
            len_prefixed_error(conn, "Error: Invalid N")
            return

        curves, err = _read_curves(conn, n_curves, info.needs_triple, max_points)
        if err is not None:
            len_prefixed_error(conn, err)
            return

        xinterp: Optional[np.ndarray] = None
        if info.needs_interp:
            m = be_u32_recv(conn)
            # Check M before reading: M floats are buffered in memory.
            if m == 0 or m > 10_000_000:
                len_prefixed_error(conn, "Error: Invalid M")
                return
            xinterp = be_f32_array_recv(conn, m)

        # Y and Z share one x-permutation per curve, so sort each curve once.
        sorted_curves = [_sorted_curve(fx, fy, fz) for fx, fy, fz in curves]

        if xinterp is not None:
            # One row of M values per curve: [Y rows...][Z rows...].
            parts = [_interp_shared(xs, ys, xinterp) for xs, ys, _ in sorted_curves]
            if info.needs_triple:
                parts += [_interp_shared(xs, zs, xinterp) for xs, _, zs in sorted_curves]
        else:
            # Ops 1 and 4 return one value per input sample. NOTE(review):
            # samples come back in x-sorted order, not the order the client
            # sent them, so they line up only if the client sends sorted x.
            # Confirm this matches the C reference.
            kernel = {1: _differentiate_curve, 4: _integrate_trap}[op]
            parts = [kernel(xs, ys) for xs, ys, _ in sorted_curves]
            if info.needs_triple:
                parts += [kernel(xs, zs) for xs, _, zs in sorted_curves]

        len_prefixed_f32_array_send(conn, np.concatenate(parts))

    except Exception as e:
        # Best-effort error reply. The peer may already be gone, so a failing
        # send is ignored: there is no one left to report it to.
        try:
            len_prefixed_error(conn, f"Error: {e}")
        except OSError:
            pass
    finally:
        try:
            conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # peer may already have closed / reset the connection
        conn.close()


def _read_curves(
    conn: socket.socket, n_curves: int, needs_triple: bool, max_points: int
) -> Tuple[List[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]], Optional[str]]:
    """Read ``n_curves`` curves, validating each length prefix before its data.

    Returns ``(curves, None)`` on success, where each curve is ``(fx, fy, fz)``
    and ``fz`` is None for Y-only ops. On a bad length it returns
    ``([], error_message)``.
    """
    bad_len = "Error: length mismatch or <3"
    curves: List[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]] = []
    for _ in range(n_curves):
        nfx = be_u32_recv(conn)
        if nfx > max_points:
            return [], "Error: curve too long"
        if nfx < 3:
            return [], bad_len
        fx = be_f32_array_recv(conn, nfx)
        if be_u32_recv(conn) != nfx:
            return [], bad_len
        fy = be_f32_array_recv(conn, nfx)
        fz: Optional[np.ndarray] = None
        if needs_triple:
            if be_u32_recv(conn) != nfx:
                return [], bad_len
            fz = be_f32_array_recv(conn, nfx)
        curves.append((fx, fy, fz))
    return curves, None


def _sorted_curve(
    fx: np.ndarray, fy: np.ndarray, fz: Optional[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Stable-sort one curve by x, carrying Y (and Z when present) along."""
    if fz is None:
        xs, (ys,) = _sort_by_x(fx, fy)
        return xs, ys, None
    xs, (ys, zs) = _sort_by_x(fx, fy, fz)
    return xs, ys, zs

# ========== Server bootstrap ==========

def run_server(host: str, port: int, accept_threads: int = 2) -> None:
    """Serve requests on ``host:port`` until interrupted with Ctrl+C.

    Each accepted connection is handed to ``handle_connection`` on its own
    daemon thread. On KeyboardInterrupt the listening socket is closed and
    each accept thread is joined with a 1 s timeout. In-flight handler
    threads are daemons and are not waited for.

    Args:
        host: IPv4 address to bind (the socket is AF_INET), e.g. "0.0.0.0".
        port: TCP port to bind. With 0 the OS picks one, and the chosen port
            is printed.
        accept_threads: number of threads calling accept(). Values below 1
            are treated as 1 so the server can always accept connections.

    Raises:
        OSError: if the socket cannot be bound or listened on. The socket
            is closed before the error propagates.
    """
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind((host, port))
        srv.listen(512)
    except BaseException:
        srv.close()  # don't leak the descriptor when e.g. the port is in use
        raise
    # A periodic accept() timeout lets each loop re-check stop_evt. Closing the
    # socket from the main thread is not guaranteed to wake a thread blocked in
    # accept() (commonly it does not on Linux).
    srv.settimeout(0.5)
    bound_port = srv.getsockname()[1]
    print(f"[server] listening on {host}:{bound_port}", flush=True)

    stop_evt = threading.Event()

    def accept_loop() -> None:
        while not stop_evt.is_set():
            try:
                conn, addr = srv.accept()
            except socket.timeout:
                continue
            except OSError:
                if stop_evt.is_set():
                    break  # listening socket was closed during shutdown
                # Transient accept() failures (ECONNABORTED, EMFILE, ...) must
                # not kill this loop; back off briefly and keep serving.
                time.sleep(0.05)
                continue
            try:
                conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                threading.Thread(
                    target=handle_connection, args=(conn, addr), daemon=True
                ).start()
            except (OSError, RuntimeError):
                # Peer already reset, or no thread could be started: drop only
                # this client and release its socket.
                conn.close()

    threads = [
        threading.Thread(target=accept_loop, daemon=True)
        for _ in range(max(1, accept_threads))
    ]
    for t in threads:
        t.start()

    try:
        while True:
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("\n[server] shutting down...", flush=True)
    finally:
        stop_evt.set()
        try:
            srv.close()
        except OSError:
            pass
        for t in threads:
            t.join(timeout=1.0)
        print("[server] stopped.", flush=True)

# ========== Integrated client demo (op=2) ==========

def client_demo(host: str, port: int) -> int:
    """Run one op=2 (Y-only interpolation) round trip and print the result.

    Sends a single curve fx=[1..5], fy=[10, 12, 15, 19, 25] and asks for Y at
    x=[1.5, 3.5]. On success prints ``Y_interp: [...]`` and returns 0.
    Otherwise it writes the server's reply text to stderr and returns 1.
    """
    fx = np.array([1, 2, 3, 4, 5], dtype=np.float32)
    fy = np.array([10, 12, 15, 19, 25], dtype=np.float32)
    xi = np.array([1.5, 3.5], dtype=np.float32)

    # `with` closes the socket even if connect/send/receive raises.
    with socket.create_connection((host, port), timeout=10.0) as s:
        s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        be_u8_send(s, 2)   # op=2: interpolate Y-only
        be_u32_send(s, 1)  # N=1 curve
        # curve 0: fx, then fy; then the shared x_interp array (M values)
        for arr in (fx, fy, xi):
            be_u32_send(s, arr.shape[0])
            s.sendall(arr.astype(">f4").tobytes(order="C"))

        # receive len-prefixed payload
        nbytes = be_u32_recv(s)
        payload = recv_all(s, nbytes) if nbytes > 0 else b""

    # Error replies are UTF-8 text whose length can be a multiple of 4
    # ("Error: Invalid N" is 16 bytes), so `nbytes % 4` cannot separate them
    # from data. A good reply is exactly N*M = len(xi) big-endian floats.
    if nbytes != 4 * xi.shape[0] or payload.startswith(b"Error:"):
        sys.stderr.write(payload.decode("utf-8", errors="replace") + "\n")
        return 1

    out = np.frombuffer(payload, dtype=">f4").astype(np.float32)
    print("Y_interp:", out.tolist())
    return 0

# ========== CLI ==========

def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Usage::

        --server HOST PORT [--accept K]   run the server
        --client HOST PORT                run the op=2 client demo

    Args:
        argv: arguments to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 2 on invalid arguments. In client mode it exits with
    the status returned by ``client_demo``.
    """
    ap = argparse.ArgumentParser(description="Unified Metamaterial-IO server+client (Python)")
    # The modes are two-value options in a required mutually exclusive group.
    # Sub-commands named "--server"/"--client" can never be selected, because
    # argparse treats any token starting with "-" as an option flag. The
    # documented "--server 0.0.0.0 5000" form therefore always failed.
    mode = ap.add_mutually_exclusive_group(required=True)
    mode.add_argument("--server", nargs=2, metavar=("HOST", "PORT"),
                      help="Run server bound to HOST PORT, e.g. 0.0.0.0 5000")
    mode.add_argument("--client", nargs=2, metavar=("HOST", "PORT"),
                      help="Run client demo (op=2) against HOST PORT")
    ap.add_argument("--accept", type=int, default=2,
                    help="Accept threads for --server (default 2)")

    args = ap.parse_args(argv)

    host, port_text = args.server if args.server is not None else args.client
    try:
        port = int(port_text)
    except ValueError:
        ap.error(f"invalid port: {port_text!r}")
    if not 0 <= port <= 65535:
        ap.error(f"port out of range: {port}")

    if args.server is not None:
        if args.accept < 1:
            ap.error("--accept must be >= 1")
        run_server(host, port, args.accept)
    else:
        sys.exit(client_demo(host, port))

# Run the CLI only when executed as a script; importing the module starts nothing.
if __name__ == "__main__":
    main()

No comments: