WebSocket Interceptors

WebSocket interceptors allow you to process and modify WebSocket data as it flows through your server using a middleware pattern. This guide shows how to implement custom interceptors for WebSocket communication.

Creating Custom Interceptors

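To create a custom WebSocket interceptor, subclass Interceptor and implement both of its hooks, intercept_inbound and intercept_outbound. Each hook receives the message plus next_interceptor, an async callable that runs the rest of the chain and returns its result: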
from polyflux.core.server.interception.base import Interceptor
from polyflux.websocket import WebSocketInbound, WebSocketOutbound

class CustomWebSocketInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Skeleton interceptor showing where pre- and post-processing hooks go.

    Both hooks receive the message and ``next_interceptor``, an awaitable
    callable that runs the remainder of the chain and returns its result.
    Whatever a hook returns becomes the result seen by the interceptor
    before it in the chain.
    """

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Handle WebSocket data arriving from a client before handlers see it."""
        # Operate on a (shallow) copy so the caller's object stays untouched.
        message = inbound.copy()

        # --- pre-processing goes here: logging, authentication, transforms ---

        # Hand the message to the remaining interceptors.
        outcome = await next_interceptor(message)

        # --- post-processing goes here: runs after the downstream chain ---
        return outcome

    async def intercept_outbound(self, outbound: WebSocketOutbound, next_interceptor) -> WebSocketOutbound:
        """Handle WebSocket data on its way out, before it is sent to clients."""
        # Operate on a (shallow) copy so the caller's object stays untouched.
        message = outbound.copy()

        # --- pre-processing goes here: logging, transforms, compression ---

        # Hand the message to the remaining interceptors.
        outcome = await next_interceptor(message)

        # --- post-processing goes here: runs after the downstream chain ---
        return outcome

Flow Control Patterns

Continue Chain (Standard)

async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
    """Tag the message as authenticated, then let the rest of the chain run.

    The flag is set on a copy, so the object passed in is never mutated.
    """
    tagged = inbound.copy()
    tagged["authenticated"] = True

    # Awaiting next_interceptor continues the chain; its result is ours.
    downstream_result = await next_interceptor(tagged)
    return downstream_result

Short-Circuit Chain

async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
    """Reject forbidden messages without running the rest of the chain.

    Returning without awaiting ``next_interceptor`` short-circuits the chain:
    no later interceptor or handler sees the message. The error reply carries
    ``message_type`` and ``data`` (the Required WebSocketInbound fields) plus
    a ``status`` code, matching the error replies of the authentication and
    rate-limiting examples.
    """
    if inbound.get("message_type") == "forbidden":
        # Short-circuit: answer directly instead of calling next_interceptor.
        return {
            "message_type": "error",
            "data": "Forbidden message type",
            "status": 403,
        }

    # Anything else continues down the chain unchanged.
    return await next_interceptor(inbound)

Conditional Processing

async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
    """Apply message-type or target-specific handling before continuing.

    - ``auth`` messages are passed through ``self.authenticate`` and its
      result continues down the chain.
    - Messages whose ``target`` is ``"admin"`` are rejected with a 403 error
      reply unless ``self.is_admin_authorized`` approves them.
    - Everything else continues unchanged.

    ``authenticate`` (async) and ``is_admin_authorized`` are placeholders for
    helpers you define on your own interceptor class; they are not defined in
    this example.
    """
    message_type = inbound.get("message_type", "")
    # NOTE(review): "target" is not among the documented WebSocketInbound
    # fields; this assumes your inbound messages carry it — confirm.
    target = inbound.get("target", "")

    if message_type == "auth":
        # Special handling for auth messages.
        authenticated_data = await self.authenticate(inbound)
        return await next_interceptor(authenticated_data)

    if target == "admin" and not self.is_admin_authorized(inbound):
        # Short-circuit with an error reply that keeps the Required
        # message_type/data fields.
        return {
            "message_type": "error",
            "data": "Admin access denied",
            "status": 403,
        }

    # Authorized admin messages and regular messages both continue as-is.
    return await next_interceptor(inbound)

Adding Interceptors to Your Server

Once you’ve created your custom interceptor, add it to your WebSocket server:
from polyflux.websocket import WebSocketServer

# Build an instance of the interceptor defined above ...
interceptor = CustomWebSocketInterceptor()

# ... then create the WebSocket server and register the interceptor with it.
server = WebSocketServer()
server.add_interceptor(interceptor)

Example: Logging Interceptor

Here’s a practical example: a logging interceptor that records each message before passing it down the chain and, for inbound messages, how long the rest of the chain took:
import logging
import time
from polyflux.core.server.interception.base import Interceptor
from polyflux.websocket import WebSocketInbound, WebSocketOutbound

logger = logging.getLogger(__name__)

class LoggingInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Log every message passing through the interceptor chain.

    Inbound messages are logged with their type and payload size. The time
    spent in the rest of the chain is logged once it returns. Outbound
    messages are logged with their type and target.
    """

    @staticmethod
    def _payload_size(payload) -> int:
        """Return the length of a message payload.

        ``str`` payloads are measured in characters and ``bytes``/``bytearray``
        payloads in bytes. Any other value falls back to the length of its
        ``str()`` form.
        """
        # Measuring str(payload) for bytes would count the repr instead:
        # str(b"abc") is "b'abc'" (6 characters), not 3 bytes.
        if isinstance(payload, (str, bytes, bytearray)):
            return len(payload)
        return len(str(payload))

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Log the inbound message, run the chain, then log how long it took."""
        message_type = inbound.get("message_type", "unknown")
        data_size = self._payload_size(inbound.get("data", ""))
        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump when it is adjusted, producing bogus
        # (even negative) durations.
        start_time = time.perf_counter()

        # %-style arguments let logging skip formatting when INFO is disabled.
        logger.info("Processing inbound: type=%s, size=%d", message_type, data_size)

        # Continue the chain.
        result = await next_interceptor(inbound)

        duration = time.perf_counter() - start_time
        logger.info("Completed inbound processing in %.3fs", duration)
        return result

    async def intercept_outbound(self, outbound: WebSocketOutbound, next_interceptor) -> WebSocketOutbound:
        """Log the outbound message before and after the chain runs."""
        message_type = outbound.get("message_type", "unknown")
        target = outbound.get("target", "unknown")

        logger.info("Processing outbound: type=%s, target=%s", message_type, target)

        # Continue the chain.
        result = await next_interceptor(outbound)

        logger.info("Completed outbound processing")
        return result

Advanced Examples

Authentication Interceptor with Short-Circuiting

class AuthInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Require a known ``auth_token`` on inbound messages (pings excepted).

    Authenticated messages are copied, tagged with the token's ``user_id``
    and passed down the chain. All others are answered with a 401 error reply
    and the chain is short-circuited. Outbound messages pass through untouched.

    Args:
        valid_tokens: Mapping of accepted auth token -> user id. When omitted,
            the two example tokens below are used.
    """

    def __init__(self, valid_tokens=None):
        if valid_tokens is None:
            # Placeholder credentials for this example only.
            valid_tokens = {"token123": "user1", "token456": "user2"}
        self.valid_tokens = valid_tokens

    def _is_valid_token(self, auth_token) -> bool:
        """Return True if ``auth_token`` is non-empty and present in ``valid_tokens``."""
        if not auth_token:
            return False
        try:
            return auth_token in self.valid_tokens
        except TypeError:
            # The token comes from the client. An unhashable value (e.g. a
            # list) would otherwise crash the lookup instead of yielding a 401.
            return False

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Authenticate the message, tagging it with ``user_id`` or rejecting it."""
        # Skip auth for certain message types.
        if inbound.get("message_type") == "ping":
            return await next_interceptor(inbound)

        auth_token = inbound.get("auth_token")
        if not self._is_valid_token(auth_token):
            # Short-circuit: the rest of the chain never sees this message.
            return {
                "message_type": "error",
                "data": "Authentication required",
                "status": 401
            }

        # Tag a copy so the caller's message is not mutated.
        processed = inbound.copy()
        processed["user_id"] = self.valid_tokens[auth_token]

        return await next_interceptor(processed)

    async def intercept_outbound(self, outbound: WebSocketOutbound, next_interceptor) -> WebSocketOutbound:
        """Pass outbound messages through unchanged."""
        return await next_interceptor(outbound)

Rate Limiting Interceptor

import time
from collections import defaultdict

class RateLimitInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Sliding-window rate limiter keyed by the inbound ``user_id``.

    Each user may send at most ``max_requests`` messages within any
    ``window_seconds`` span. Further messages are answered with a 429 error
    reply and are not passed down the chain. Messages with no ``user_id`` key
    all share the ``"anonymous"`` bucket. Outbound messages pass through.

    Args:
        max_requests: Maximum number of messages allowed per window.
        window_seconds: Length of the sliding window, in seconds.
    """

    def __init__(self, max_requests=10, window_seconds=60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # user_id -> time.monotonic() timestamps of that user's accepted
        # requests, oldest first (rejected requests are not recorded).
        self.request_counts = defaultdict(list)
        # When idle buckets were last purged (see _evict_idle_users).
        self._last_sweep = time.monotonic()

    def _evict_idle_users(self, now):
        """Drop buckets with no timestamp left inside the window.

        This runs at most once per window. Without it, every user_id ever
        seen would keep a dict entry forever.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        # Timestamps are appended in non-decreasing order, so the last one is
        # the newest. If it is outside the window, the whole bucket is stale.
        idle = [
            uid
            for uid, stamps in self.request_counts.items()
            if not stamps or now - stamps[-1] >= self.window_seconds
        ]
        for uid in idle:
            del self.request_counts[uid]

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Record the request, or reply with a 429 error if the user is over the limit."""
        user_id = inbound.get("user_id", "anonymous")
        # time.monotonic() never jumps when the system clock is adjusted.
        # time.time() does, which could reset or extend a user's window.
        current_time = time.monotonic()
        self._evict_idle_users(current_time)

        # Keep only the timestamps still inside the window.
        recent = [
            req_time for req_time in self.request_counts[user_id]
            if current_time - req_time < self.window_seconds
        ]
        self.request_counts[user_id] = recent

        # Check rate limit.
        if len(recent) >= self.max_requests:
            # Rate limited - short-circuit.
            return {
                "message_type": "error", 
                "data": f"Rate limit exceeded: {self.max_requests} requests per {self.window_seconds}s",
                "status": 429
            }

        # There is no await between the check above and this append, so other
        # tasks on the same event loop cannot interleave here.
        recent.append(current_time)

        return await next_interceptor(inbound)

    async def intercept_outbound(self, outbound: WebSocketOutbound, next_interceptor) -> WebSocketOutbound:
        """Pass outbound messages through unchanged."""
        return await next_interceptor(outbound)

Data Structures

WebSocketInbound

Incoming message structure:
  • message_type: Type of the message (e.g., “text”, “binary”, “close”, “ping”, “pong”) - Required
  • data: The actual message content (bytes or str) - Required
  • close_code: Close code for close messages (optional)
  • close_reason: Human-readable close reason (optional)
  • is_final: Whether this is the final frame in a fragmented message (optional)

WebSocketOutbound

Outgoing message structure:
  • target: WebSocket URL (ws:// or wss://) - Required (inherited from Outbound)
  • message_type: Type of the message to send (e.g., “text”, “binary”, “close”, “ping”, “pong”) - Required
  • data: The message content to send (bytes or str) - Required
  • headers: HTTP headers for WebSocket handshake (optional)
  • subprotocols: List of WebSocket subprotocols (optional)
  • close_code: Close code for close messages (optional)
  • close_reason: Human-readable close reason (optional)

Best Practices

  1. Always copy data: Use .copy() before modifying inbound/outbound data so the caller's object is not changed. Note that .copy() is shallow; copy nested values too if you modify them
  2. Call next_interceptor: Always call await next_interceptor(data) unless you want to short-circuit
  3. Handle errors gracefully: Wrap risky interceptor logic in try/except blocks, catching specific exceptions rather than silently swallowing all errors
  4. Keep processing lightweight: Interceptors run for every message, so avoid heavy operations
  5. Use conditional processing: Check message_type, target, or data content to handle different scenarios
  6. Short-circuit when needed: Return directly without calling next_interceptor() to stop the chain
  7. Before/after logic: Add logic before next_interceptor() for pre-processing, after for post-processing
  8. Chain order matters: Interceptors execute in registration order for inbound, reverse order for outbound

Context Sharing

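Interceptors can share state through the InterceptionContext, which get_context() retrieves from a ContextVar. Because get_context() may not return a context, check the result before using it. The second class below is a variant of the earlier LoggingInterceptor and uses the same module-level logger: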
from datetime import datetime
from polyflux.core.server.interception.utils import get_context

class UserContextInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Publish per-message identity details into the shared interception context."""

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Store user_id, connection_id and a timestamp in context metadata, then continue."""
        ctx = get_context()
        if ctx:
            # Keys missing from the message are stored as None (inbound.get default).
            entries = (
                ('user_id', inbound.get('user_id')),
                ('connection_id', inbound.get('connection_id')),
                ('request_time', datetime.now()),
            )
            for key, value in entries:
                ctx.metadata[key] = value

        # The message itself is forwarded unchanged.
        return await next_interceptor(inbound)

class LoggingInterceptor(Interceptor[WebSocketInbound, WebSocketOutbound]):
    """Log each inbound message with the user and chain id from the shared context."""

    async def intercept_inbound(self, inbound: WebSocketInbound, next_interceptor) -> WebSocketInbound:
        """Log who the message belongs to, then continue the chain unchanged."""
        context = get_context()
        if context:
            # An upstream interceptor may store user_id as None (e.g.
            # UserContextInterceptor writes inbound.get('user_id')), and
            # .get(key, default) would return that None, so treat None like
            # a missing key.
            user_id = context.metadata.get('user_id')
            if user_id is None:
                user_id = 'anonymous'
            chain_id = context.chain_id
        else:
            user_id = chain_id = 'unknown'

        # %-style arguments let logging skip formatting when INFO is disabled.
        logger.info("Processing message for user: %s, chain: %s", user_id, chain_id)

        return await next_interceptor(inbound)