Skip to content

API Reference: SocketApp

socketspec.app.SocketApp

FastAPI-style WebSocket application entry point.

Source code in src\socketspec\app.py
class SocketApp:
    """FastAPI-style WebSocket application entry point.

    Holds the event registry, the connection/session/room managers, the
    middleware list and the lifecycle hooks. ``handle_connect``,
    ``handle_event`` and ``handle_disconnect`` are the per-socket entry
    points; ``_startup_validate`` should run before serving so middleware is
    compiled into the dispatch chain.
    """

    def __init__(
        self,
        *,
        docs: bool = False,
        docs_url: str = "/socket-docs",
        docs_access_token: str | None = None,
        auth: AuthBackend | None = None,
        backend: Literal["memory"] | BackendAdapter = "memory",
        allowed_origins: list[str] | None = None,
        max_payload_size: int = DEFAULT_MAX_PAYLOAD_SIZE,
        rate_limit: RateLimit | None = None,
        session: SessionConfig | None = None,
        rooms: list[Room] | None = None,
        namespace: str = "/",
        debug: bool = False,
        debug_url: str = "/socket-debug",
    ) -> None:
        # ``None`` means "no restriction configured" and maps to ["*"]
        # (presumably a wildcard in OriginValidator — confirm there).
        origins = allowed_origins if allowed_origins is not None else ["*"]
        self._namespace = namespace
        self._registry = EventRegistry()
        self._backend = self._build_backend(backend)
        self._manager = ConnectionManager(self._backend)
        self._session_mgr = SessionManager(session or SessionConfig())
        self._origin_validator = OriginValidator(origins)
        self._rate_limiter = TokenBucket(rate_limit) if rate_limit else None
        self._di_resolver = DependencyResolver()
        self._middlewares: list[MiddlewareFunc] = []
        self._router = EventRouter(self._registry, self._di_resolver)
        self.rooms = RoomManager(self._backend, self._manager)
        self._auth = auth
        self._max_payload_size = max_payload_size
        self._docs = docs
        self._docs_url = docs_url
        self._docs_access_token = docs_access_token
        self._debug = debug
        self._debug_url = debug_url
        # Bounded debug log; _debug_log drops the oldest entry when full.
        self._debug_queue: asyncio.Queue[dict[str, Any]] | None = (
            asyncio.Queue(maxsize=500) if debug else None
        )
        self._lifecycle_hooks: dict[str, list[LifecycleHook]] = {
            "connect": [],
            "disconnect": [],
            "error": [],
            "room_join": [],
            "room_leave": [],
        }
        # Until _startup_validate() compiles the middleware chain, events are
        # dispatched straight to the router without middleware.
        self._compiled_chain: CompiledHandler = self._router.dispatch

        for room in rooms or []:
            self.rooms.register_static(room.name)

        # room_join hooks fire through the room manager's join hook.
        self.rooms.register_join_hook(self._run_room_join_hooks)

    def on(
        self,
        event: EventName,
        *,
        description: str = "",
        tags: list[str] | None = None,
        emits: list[Emits] | None = None,
        broadcasts: list[Broadcasts] | None = None,
        ordered: bool = False,
        executor: bool = False,
        deprecated: bool = False,
    ) -> Callable[[HandlerFunc], HandlerFunc]:
        """Register an event handler.

        Returns a decorator that builds an ``EventDefinition`` for *event* in
        this app's namespace and adds it to the registry. The payload model
        is inferred from the handler's second parameter annotation. The
        decorated function is returned unchanged.
        """

        def decorator(func: HandlerFunc) -> HandlerFunc:
            definition = EventDefinition(
                name=event,
                namespace=self._namespace,
                handler=func,
                payload_model=self._infer_payload_model(func),
                emits=emits or [],
                broadcasts=broadcasts or [],
                description=description,
                tags=tags or [],
                ordered=ordered,
                executor=executor,
                deprecated=deprecated,
            )
            self._registry.register(definition)
            return func

        return decorator

    def middleware(self, func: MiddlewareFunc) -> MiddlewareFunc:
        """Register middleware that runs before each event handler.

        Middleware only takes effect once ``_startup_validate`` compiles it
        into the dispatch chain.
        """
        self._middlewares.append(func)
        return func

    def on_connect(self, func: LifecycleHook) -> LifecycleHook:
        """Register a hook that runs after a connection is established."""
        self._lifecycle_hooks["connect"].append(func)
        return func

    def on_disconnect(self, func: LifecycleHook) -> LifecycleHook:
        """Register a hook that runs after a connection is torn down."""
        self._lifecycle_hooks["disconnect"].append(func)
        return func

    def on_error(self, func: LifecycleHook) -> LifecycleHook:
        """Register a hook that runs when handler errors are emitted."""
        # NOTE(review): the "error" hook list is not handed to EventRouter or
        # any other collaborator inside this class; confirm it is invoked
        # elsewhere, otherwise these hooks never fire.
        self._lifecycle_hooks["error"].append(func)
        return func

    def on_room_join(self, func: LifecycleHook) -> LifecycleHook:
        """Register a hook that runs after a connection joins a room."""
        self._lifecycle_hooks["room_join"].append(func)
        return func

    def on_room_leave(self, func: LifecycleHook) -> LifecycleHook:
        """Register a hook that runs after a connection leaves a room."""
        # NOTE(review): within this class these hooks are only awaited from
        # handle_disconnect, unlike room_join hooks which are registered with
        # the room manager; confirm explicit leaves also trigger them.
        self._lifecycle_hooks["room_leave"].append(func)
        return func

    def room_guard(
        self,
        pattern: str,
    ) -> Callable[[RoomGuardFunc], RoomGuardFunc]:
        """Protect a room name pattern with a permission function."""

        def decorator(func: RoomGuardFunc) -> RoomGuardFunc:
            self.rooms.register_guard(pattern, func)
            return func

        return decorator

    async def handle_connect(
        self,
        raw_socket: RawSocket,
        headers: dict[str, str],
        query_params: dict[str, str],
    ) -> Connection | None:
        """Accept and register a new WebSocket connection.

        Returns:
            The registered ``Connection``, or ``None`` after closing the
            socket when the origin is rejected or authentication fails.
        """
        # Header names are case-insensitive; look up Origin on a lowered copy.
        normalized_headers = {key.lower(): value for key, value in headers.items()}
        origin = normalized_headers.get("origin")
        if not self._origin_validator.is_allowed(origin):
            await raw_socket.close(code=HTTP_FORBIDDEN_CLOSE_CODE)
            return None

        identity = Identity()
        if self._auth is not None:
            # The auth backend receives the headers exactly as supplied.
            auth_result = await self._auth.authenticate(headers, query_params)
            if auth_result is None:
                await self._emit_raw(
                    raw_socket,
                    "AUTH_ERROR",
                    "__connect__",
                    "Authentication failed",
                )
                await raw_socket.close(code=AUTH_FAILURE_CLOSE_CODE)
                return None
            identity = auth_result

        conn = self._build_connection(
            raw_socket,
            identity,
            headers,
            query_params,
        )
        await self._manager.connect(conn)
        await self._session_mgr.start(conn)

        for hook in self._lifecycle_hooks["connect"]:
            await hook(conn)

        logger.info("Connection %s established", conn.id)
        self._debug_log({
            "type": "connect",
            "conn_id": conn.id,
            "user_id": conn.identity.user_id,
            "ts": datetime.now(timezone.utc).isoformat(),
        })
        return conn

    async def handle_event(
        self,
        conn: Connection,
        raw_message: str | bytes,
    ) -> None:
        """Handle one inbound client message.

        The frame must be a JSON object with a string ``"event"`` and an
        optional object ``"payload"``. Oversized, malformed and rate-limited
        frames are answered with an ``__error__`` event instead of being
        dispatched. ``__pong__`` frames only touch the session and signal
        the heartbeat.
        """
        # Measure in encoded bytes so str and bytes frames share one limit.
        size = (
            len(raw_message)
            if isinstance(raw_message, bytes)
            else len(raw_message.encode())
        )
        if size > self._max_payload_size:
            await conn.emit("__error__", {"code": "PAYLOAD_TOO_LARGE"})
            return

        try:
            if isinstance(raw_message, bytes):
                raw_message = raw_message.decode()
            data = json.loads(raw_message)
            event = data["event"]
            # A non-string event (list, dict, number) would otherwise reach the
            # registry lookup and fail there with an unhandled error.
            if not isinstance(event, str):
                raise TypeError("event must be a string")
            payload = data.get("payload", {})
            if not isinstance(payload, dict):
                raise TypeError("payload must be an object")
        # UnicodeDecodeError and JSONDecodeError are both ValueError subclasses;
        # TypeError also covers top-level JSON that is not an object.
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            await conn.emit(
                "__error__",
                {
                    "code": "VALIDATION_ERROR",
                    "message": "Invalid message format",
                },
            )
            return

        if self._rate_limiter is not None:
            allowed = await self._rate_limiter.consume(conn.id)
            if not allowed:
                await conn.emit("__error__", {"code": "RATE_LIMIT_ERROR"})
                return

        await self._session_mgr.touch(conn)

        # __pong__ is a system frame — never route it to the event registry.
        if event == "__pong__":
            self._session_mgr.signal_pong(conn.id)
            return

        self._debug_log({
            "type": "event",
            "conn_id": conn.id,
            "event": event,
            "payload_size": size,
            "ts": datetime.now(timezone.utc).isoformat(),
        })
        await self._compiled_chain(conn, event, payload)

    async def handle_disconnect(
        self,
        conn: Connection,
        reason: str = "client_close",
    ) -> None:
        """Tear down a connection and run lifecycle cleanup.

        The steps run in this order:

        1. Stop the session.
        2. Leave every room, awaiting ``on_room_leave`` hooks per room.
        3. Unregister the connection from the manager.
        4. Await ``on_disconnect`` hooks.
        5. Drop rate-limiter and router state.

        Exceptions raised by user hooks are logged and do not abort the
        teardown. A faulty hook therefore cannot leave the connection
        registered or leak per-connection rate-limit or router state.
        """
        await self._session_mgr.stop(conn.id)

        # Iterate a snapshot: leaving a room presumably mutates conn.rooms.
        for room in list(conn.rooms):
            await self.rooms.leave(conn, room)
            for hook in self._lifecycle_hooks["room_leave"]:
                try:
                    await hook(conn, room)
                except Exception:  # noqa: BLE001
                    logger.exception(
                        "on_room_leave hook failed for connection %s (room %s)",
                        conn.id,
                        room,
                    )

        await self._manager.disconnect(conn)

        for hook in self._lifecycle_hooks["disconnect"]:
            try:
                await hook(conn, reason)
            except Exception:  # noqa: BLE001
                logger.exception(
                    "on_disconnect hook failed for connection %s", conn.id
                )

        if self._rate_limiter is not None:
            await self._rate_limiter.remove(conn.id)

        await self._router.cleanup(conn.id)
        logger.info("Connection %s disconnected: %s", conn.id, reason)
        self._debug_log({
            "type": "disconnect",
            "conn_id": conn.id,
            "reason": reason,
            "ts": datetime.now(timezone.utc).isoformat(),
        })

    def _startup_validate(self) -> None:
        """Run startup validation and compile middleware before serving."""
        self._registry.validate()
        # Replace the bare router dispatch with the middleware-wrapped chain.
        self._compiled_chain = MiddlewareChain(self._middlewares).compile(
            self._router.dispatch
        )

    async def _graceful_shutdown(self) -> None:
        """Release backend resources on application shutdown."""
        await self._backend.close()

    async def _run_room_join_hooks(self, conn: Connection, room: str) -> None:
        """Await every ``on_room_join`` hook with ``(conn, room)`` in order."""
        for hook in self._lifecycle_hooks["room_join"]:
            await hook(conn, room)

    def _debug_log(self, entry: dict[str, Any]) -> None:
        """Append a debug log entry to the queue. No-op when debug=False.

        When the queue is full, the oldest entry is discarded so the most
        recent entries are kept.
        """
        if self._debug_queue is None:
            return
        try:
            self._debug_queue.put_nowait(entry)
        except asyncio.QueueFull:
            try:
                self._debug_queue.get_nowait()  # discard oldest
                self._debug_queue.put_nowait(entry)
            except (asyncio.QueueEmpty, asyncio.QueueFull):
                # Debug logging is best-effort; never let it break a request.
                pass

    def _build_backend(
        self,
        backend: Literal["memory"] | BackendAdapter,
    ) -> BackendAdapter:
        """Resolve the ``backend`` constructor argument to an adapter.

        Raises:
            ValueError: If *backend* is a string other than ``"memory"``.
        """
        if isinstance(backend, str):
            if backend == "memory":
                return MemoryBackend()
            # Fail fast on a typo instead of returning a str that only breaks
            # at the first backend call.
            raise ValueError(
                f"Unknown backend {backend!r}; "
                "expected 'memory' or a BackendAdapter instance"
            )
        return backend

    def _build_connection(
        self,
        raw_socket: RawSocket,
        identity: Identity,
        headers: dict[str, str],
        query_params: dict[str, str],
    ) -> Connection:
        """Create a ``Connection`` with a fresh UUID and UTC timestamps."""
        now = datetime.now(timezone.utc)
        session = SessionInfo(
            started_at=now,
            expires_at=None,
            token_expires_at=identity.token_expires_at,
        )
        return Connection(
            id=str(uuid.uuid4()),
            raw_socket=raw_socket,
            identity=identity,
            session=session,
            connected_at=now,
            last_active=now,
            headers=headers,
            query_params=query_params,
            namespace=self._namespace,
        )

    async def _emit_raw(
        self,
        raw_socket: RawSocket,
        code: str,
        event: EventName,
        message: str,
    ) -> None:
        """Send an ``__error__`` envelope directly on a raw socket.

        Used before a ``Connection`` exists (e.g. during authentication).
        """
        await raw_socket.send_json(
            {
                "event": "__error__",
                "payload": {
                    "code": code,
                    "event": event,
                    "message": message,
                    "request_id": str(uuid.uuid4()),
                    "details": {},
                },
            }
        )

    def _infer_payload_model(self, func: Callable[..., Any]) -> type[BaseModel] | None:
        """Return the Pydantic model annotated on *func*'s second parameter.

        The second parameter (after the connection) is treated as the
        payload. String annotations are resolved against the handler's
        globals plus any bound closure variables. If that fails, resolution
        falls back to ``typing.get_type_hints``.

        Returns:
            The annotated ``BaseModel`` subclass, or ``None``. ``None`` is
            returned when the handler has fewer than two parameters, the
            parameter is unannotated, the annotation cannot be resolved, or
            it is not a ``BaseModel`` subclass.
        """
        params = list(inspect.signature(func).parameters.values())
        if len(params) < 2:
            return None
        param = params[1]
        annotation: Any = param.annotation
        if annotation is inspect.Parameter.empty:
            return None
        if isinstance(annotation, str):
            namespace = dict(func.__globals__)
            if func.__closure__:
                for name, cell in zip(
                    func.__code__.co_freevars,
                    func.__closure__,
                    strict=False,
                ):
                    try:
                        namespace[name] = cell.cell_contents
                    except ValueError:
                        # Empty cell: the enclosing variable is assigned after
                        # the handler is defined, so there is nothing to bind.
                        continue
            try:
                # Annotations come from application source, not client input.
                annotation = eval(annotation, namespace)  # noqa: S307
            except Exception:
                annotation = None
        if annotation is None:
            try:
                hints = get_type_hints(func, globalns=func.__globals__)
                annotation = hints.get(param.name)
            except (NameError, TypeError):
                return None
        if annotation is None:
            return None
        try:
            if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
                return annotation
        except TypeError:
            # On Python 3.10 generic aliases such as list[int] pass
            # inspect.isclass but are rejected by issubclass.
            return None
        return None

handle_connect(raw_socket, headers, query_params) async

Accept and register a new WebSocket connection.

Source code in src\socketspec\app.py
async def handle_connect(
    self,
    raw_socket: RawSocket,
    headers: dict[str, str],
    query_params: dict[str, str],
) -> Connection | None:
    """Run the connect handshake for one incoming socket.

    The handshake checks the origin, then optionally authenticates. It then
    registers the connection, starts its session and awaits every
    ``on_connect`` hook in registration order.

    Returns:
        The new ``Connection``, or ``None`` once the socket has been closed
        because its origin was rejected or authentication failed.
    """
    # Header names are case-insensitive, so read Origin from a lowered copy.
    lowered = dict((name.lower(), val) for name, val in headers.items())
    if not self._origin_validator.is_allowed(lowered.get("origin")):
        await raw_socket.close(code=HTTP_FORBIDDEN_CLOSE_CODE)
        return None

    identity = Identity()
    auth_backend = self._auth
    if auth_backend is not None:
        authenticated = await auth_backend.authenticate(headers, query_params)
        if authenticated is None:
            await self._emit_raw(
                raw_socket, "AUTH_ERROR", "__connect__", "Authentication failed"
            )
            await raw_socket.close(code=AUTH_FAILURE_CLOSE_CODE)
            return None
        identity = authenticated

    conn = self._build_connection(raw_socket, identity, headers, query_params)
    await self._manager.connect(conn)
    await self._session_mgr.start(conn)

    connect_hooks = self._lifecycle_hooks["connect"]
    for hook in connect_hooks:
        await hook(conn)

    logger.info("Connection %s established", conn.id)
    record = {
        "type": "connect",
        "conn_id": conn.id,
        "user_id": conn.identity.user_id,
        "ts": datetime.now(timezone.utc).isoformat(),
    }
    self._debug_log(record)
    return conn

handle_disconnect(conn, reason='client_close') async

Tear down a connection and run lifecycle cleanup.

Source code in src\socketspec\app.py
async def handle_disconnect(
    self,
    conn: Connection,
    reason: str = "client_close",
) -> None:
    """Tear down a connection and run lifecycle cleanup.

    The steps run in this order:

    1. Stop the session.
    2. Leave every room, awaiting ``on_room_leave`` hooks per room.
    3. Unregister the connection from the manager.
    4. Await ``on_disconnect`` hooks.
    5. Drop rate-limiter and router state.

    Exceptions raised by user hooks are logged and do not abort the
    teardown. A faulty hook therefore cannot leave the connection registered
    or leak per-connection rate-limit or router state.
    """
    await self._session_mgr.stop(conn.id)

    # Iterate a snapshot: leaving a room presumably mutates conn.rooms.
    for room in list(conn.rooms):
        await self.rooms.leave(conn, room)
        for hook in self._lifecycle_hooks["room_leave"]:
            try:
                await hook(conn, room)
            except Exception:  # noqa: BLE001
                logger.exception(
                    "on_room_leave hook failed for connection %s (room %s)",
                    conn.id,
                    room,
                )

    await self._manager.disconnect(conn)

    for hook in self._lifecycle_hooks["disconnect"]:
        try:
            await hook(conn, reason)
        except Exception:  # noqa: BLE001
            logger.exception("on_disconnect hook failed for connection %s", conn.id)

    if self._rate_limiter is not None:
        await self._rate_limiter.remove(conn.id)

    await self._router.cleanup(conn.id)
    logger.info("Connection %s disconnected: %s", conn.id, reason)
    self._debug_log({
        "type": "disconnect",
        "conn_id": conn.id,
        "reason": reason,
        "ts": datetime.now(timezone.utc).isoformat(),
    })

handle_event(conn, raw_message) async

Handle one inbound client message.

Source code in src\socketspec\app.py
async def handle_event(
    self,
    conn: Connection,
    raw_message: str | bytes,
) -> None:
    """Handle one inbound client message.

    The frame must be a JSON object with a string ``"event"`` and an
    optional object ``"payload"``. Oversized, malformed and rate-limited
    frames are answered with an ``__error__`` event instead of being
    dispatched. ``__pong__`` frames only touch the session and signal the
    heartbeat.
    """
    # Measure in encoded bytes so str and bytes frames share one limit.
    size = (
        len(raw_message)
        if isinstance(raw_message, bytes)
        else len(raw_message.encode())
    )
    if size > self._max_payload_size:
        await conn.emit("__error__", {"code": "PAYLOAD_TOO_LARGE"})
        return

    try:
        if isinstance(raw_message, bytes):
            raw_message = raw_message.decode()
        data = json.loads(raw_message)
        event = data["event"]
        # A non-string event (list, dict, number) would otherwise reach the
        # registry lookup and fail there with an unhandled error.
        if not isinstance(event, str):
            raise TypeError("event must be a string")
        payload = data.get("payload", {})
        if not isinstance(payload, dict):
            raise TypeError("payload must be an object")
    # UnicodeDecodeError and JSONDecodeError are both ValueError subclasses;
    # TypeError also covers top-level JSON that is not an object.
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        await conn.emit(
            "__error__",
            {
                "code": "VALIDATION_ERROR",
                "message": "Invalid message format",
            },
        )
        return

    if self._rate_limiter is not None:
        allowed = await self._rate_limiter.consume(conn.id)
        if not allowed:
            await conn.emit("__error__", {"code": "RATE_LIMIT_ERROR"})
            return

    await self._session_mgr.touch(conn)

    # __pong__ is a system frame — never route it to the event registry.
    if event == "__pong__":
        self._session_mgr.signal_pong(conn.id)
        return

    self._debug_log({
        "type": "event",
        "conn_id": conn.id,
        "event": event,
        "payload_size": size,
        "ts": datetime.now(timezone.utc).isoformat(),
    })
    await self._compiled_chain(conn, event, payload)

middleware(func)

Register middleware that runs before each event handler.

Source code in src\socketspec\app.py
def middleware(self, func: MiddlewareFunc) -> MiddlewareFunc:
    """Decorator that appends *func* to the app's middleware list.

    *func* is returned unchanged, so the decorated name still refers to the
    original coroutine function.
    """
    registered = self._middlewares
    registered.append(func)
    return func

on(event, *, description='', tags=None, emits=None, broadcasts=None, ordered=False, executor=False, deprecated=False)

Register an event handler.

Source code in src\socketspec\app.py
def on(
    self,
    event: EventName,
    *,
    description: str = "",
    tags: list[str] | None = None,
    emits: list[Emits] | None = None,
    broadcasts: list[Broadcasts] | None = None,
    ordered: bool = False,
    executor: bool = False,
    deprecated: bool = False,
) -> Callable[[HandlerFunc], HandlerFunc]:
    """Decorator factory that binds a handler function to *event*.

    The produced decorator builds an ``EventDefinition`` in this app's
    namespace and adds it to the registry. The payload model is taken from
    the handler's second parameter annotation, and ``None`` list options
    become empty lists. The handler is returned unchanged.
    """

    def register(func: HandlerFunc) -> HandlerFunc:
        payload_model = self._infer_payload_model(func)
        self._registry.register(
            EventDefinition(
                name=event,
                namespace=self._namespace,
                handler=func,
                payload_model=payload_model,
                emits=emits or [],
                broadcasts=broadcasts or [],
                description=description,
                tags=tags or [],
                ordered=ordered,
                executor=executor,
                deprecated=deprecated,
            )
        )
        return func

    return register

on_connect(func)

Register a hook that runs after a connection is established.

Source code in src\socketspec\app.py
def on_connect(self, func: LifecycleHook) -> LifecycleHook:
    """Decorator: await *func(conn)* once a connection has been registered.

    Hooks run in registration order at the end of ``handle_connect``; *func*
    is returned unchanged.
    """
    hooks = self._lifecycle_hooks["connect"]
    hooks.append(func)
    return func

on_disconnect(func)

Register a hook that runs after a connection is torn down.

Source code in src\socketspec\app.py
def on_disconnect(self, func: LifecycleHook) -> LifecycleHook:
    """Decorator: await *func(conn, reason)* during connection teardown.

    Hooks run after the connection has been unregistered from the
    connection manager; *func* is returned unchanged.
    """
    hooks = self._lifecycle_hooks["disconnect"]
    hooks.append(func)
    return func

on_error(func)

Register a hook that runs when handler errors are emitted.

Source code in src\socketspec\app.py
def on_error(self, func: LifecycleHook) -> LifecycleHook:
    """Decorator: register *func* as an error hook.

    The hook is stored under the ``"error"`` lifecycle key; *func* is
    returned unchanged.
    """
    hooks = self._lifecycle_hooks["error"]
    hooks.append(func)
    return func

on_room_join(func)

Register a hook that runs after a connection joins a room.

Source code in src\socketspec\app.py
def on_room_join(self, func: LifecycleHook) -> LifecycleHook:
    """Decorator: await *func(conn, room)* after a connection joins a room.

    *func* is returned unchanged.
    """
    hooks = self._lifecycle_hooks["room_join"]
    hooks.append(func)
    return func

on_room_leave(func)

Register a hook that runs after a connection leaves a room.

Source code in src\socketspec\app.py
def on_room_leave(self, func: LifecycleHook) -> LifecycleHook:
    """Decorator: await *func(conn, room)* after a connection leaves a room.

    *func* is returned unchanged.
    """
    hooks = self._lifecycle_hooks["room_leave"]
    hooks.append(func)
    return func

room_guard(pattern)

Protect a room name pattern with a permission function.

Source code in src\socketspec\app.py
def room_guard(self, pattern: str) -> Callable[[RoomGuardFunc], RoomGuardFunc]:
    """Decorator factory that attaches a permission check to a room pattern.

    The decorated guard is registered with the room manager for *pattern*
    and returned unchanged.
    """

    def attach(func: RoomGuardFunc) -> RoomGuardFunc:
        room_manager = self.rooms
        room_manager.register_guard(pattern, func)
        return func

    return attach

Constructor Parameters

Parameter Type Default Description
docs bool False Enable the interactive docs UI (served at docs_url)
docs_url str "/socket-docs" URL path for the docs UI
docs_access_token str | None None Bearer token required to access the docs UI
auth AuthBackend | None None Authentication backend (JWTAuth, APIKeyAuth, or custom)
backend "memory" | BackendAdapter "memory" Storage backend for connections and rooms
allowed_origins list[str] | None ["*"] Allowed values of the WebSocket Origin header (None is treated as ["*"])
max_payload_size int 65536 Maximum inbound payload size in bytes
rate_limit RateLimit | None None Token bucket rate limit per connection
session SessionConfig | None SessionConfig() Heartbeat and TTL configuration
rooms list[Room] | None None Static rooms to create at startup
namespace str "/" Namespace prefix for all events
debug bool False Record connect, event, and disconnect entries in a bounded (500-entry) debug log
debug_url str "/socket-debug" URL path for the debug endpoint

Event Registration

@socket.on(event, *, ...)

Register an event handler.

@socket.on(
    "send_message",
    description="Send a chat message.",
    tags=["chat"],
    emits=[Emits("message_ack", model=AckModel)],
    broadcasts=[Broadcasts("new_message", room="chat:{room}", model=MsgModel)],
    ordered=False,
    deprecated=False,
)
async def handler(conn: Connection, payload: MyModel) -> None:
    """Example stub: ``payload`` is annotated with a Pydantic model so its schema can be inferred."""
Parameter Type Default Description
event str Event name. Cannot be a reserved system event.
description str "" Shown in the docs UI
tags list[str] [] Group events in the docs UI
emits list[Emits] [] Metadata: events the handler emits back to sender
broadcasts list[Broadcasts] [] Metadata: events the handler broadcasts to rooms
ordered bool False If True, events from this connection are processed sequentially
executor bool False If True, the event definition is flagged for executor-based dispatch (intended for blocking handlers)
deprecated bool False Shows ⚠ DEPRECATED badge in the docs UI

Lifecycle Decorators

@socket.on_connect

Runs after a connection is established. Receives conn: Connection.

@socket.on_disconnect

Runs after a connection closes. Receives conn: Connection, reason: str.

@socket.on_error

Runs when a handler raises an unhandled exception. Receives conn: Connection, error: Exception.

@socket.on_room_join

Runs after a connection joins a room. Receives conn: Connection, room: str.

@socket.on_room_leave

Runs after a connection leaves a room. Receives conn: Connection, room: str.

@socket.room_guard(pattern)

Register a room permission guard for a name pattern.

@socket.room_guard("admin:{section}")
async def guard(conn: Connection, section: str) -> bool:
    """Allow only connections whose identity role is ``"admin"``."""
    role = conn.identity.role
    return role == "admin"

@socket.middleware

Register middleware that wraps every event handler.

@socket.middleware
async def log_events(conn, event, payload, next_handler):
    """Print the connection id and event name, then continue the chain."""
    label = f"[{conn.id}] {event}"
    print(label)
    await next_handler(conn, event, payload)

System Events (Reserved)

These events are used internally. You cannot register handlers for them.

Event Direction Description
__connect__ Server → Client Connection established
__disconnect__ Server → Client Connection closing
__ping__ Server → Client Heartbeat ping
__pong__ Client → Server Heartbeat response
__error__ Server → Client Error envelope
__auth_expiring__ Server → Client JWT near expiry
__session_expiring__ Server → Client Session max-duration reached
__idle_warning__ Server → Client Idle timeout reached
__server_shutdown__ Server → Client Server is shutting down
__refresh_auth__ Client → Server Client requests token refresh

Error Codes

All errors are delivered in the __error__ envelope:

{
  "event": "__error__",
  "payload": {
    "code": "VALIDATION_ERROR",
    "event": "send_message",
    "message": "...",
    "request_id": "uuid",
    "details": {}
  }
}
Code Trigger
AUTH_ERROR Authentication failed at connect time
AUTH_EXPIRED JWT expired mid-session
HANDLER_ERROR Unhandled exception in event handler
IDLE_TIMEOUT Connection was idle too long
PAYLOAD_TOO_LARGE Payload exceeded max_payload_size
PERMISSION_ERROR Room guard returned False
RATE_LIMIT_ERROR Token bucket exhausted
ROOM_NOT_FOUND Broadcast to a non-existent room
SESSION_EXPIRED max_duration reached
UNKNOWN_EVENT No handler registered for this event
VALIDATION_ERROR JSON parse failure or Pydantic validation failure