Source code for padrelay.security.rate_limiting

"""
Rate limiting utilities for PadRelay.
"""
from collections import defaultdict
from datetime import datetime, timedelta
from typing import DefaultDict, Dict, Tuple
from ..core.logging_utils import get_logger
from ..server.constants import DEFAULT_BLOCK_DURATION

# Module-level logger, obtained from the project's get_logger helper
# and passed this module's name (__name__).
logger = get_logger(__name__)

class ConnectionTracker:
    """
    Track active connections and implement rate limiting

    Clients are keyed by their address tuple. For every address the tracker
    stores a record with three keys:

    * ``'last_seen'``     - datetime of the most recent recorded activity
    * ``'authenticated'`` - whether the client currently holds a
      connection slot
    * ``'last_requests'`` - datetimes of the requests recorded inside the
      current rate-limit window

    Clients that exceed the request limit are placed in ``blocked_clients``
    together with the datetime at which their block ends.

    All timestamps are naive ``datetime.now()`` values.
    """

    # Idle timeout in seconds used when the constructor is not given an
    # explicit ``idle_timeout`` (five minutes).
    DEFAULT_IDLE_TIMEOUT = 300

    def __init__(
        self,
        max_connections: int = 1,
        rate_limit_window: int = 60,
        max_requests: int = 100,
        block_duration: int = DEFAULT_BLOCK_DURATION,
        *,
        idle_timeout: int = DEFAULT_IDLE_TIMEOUT,
    ) -> None:
        """
        Initialize connection tracker

        Args:
            max_connections: Maximum number of simultaneous connections
            rate_limit_window: Time window for rate limiting in seconds
            max_requests: Maximum number of requests per window
            block_duration: Number of seconds a client stays blocked after
                exceeding ``max_requests`` within the window
            idle_timeout: Number of seconds after a client's ``last_seen``
                timestamp at which its record is discarded by the periodic
                cleanup (keyword-only, defaults to 300)
        """
        # Accessing an unknown address creates a fresh, unauthenticated
        # record stamped with the current time.
        self.clients: DefaultDict[Tuple[str, int], Dict] = defaultdict(lambda: {
            'last_seen': datetime.now(),
            'authenticated': False,
            'last_requests': [],
        })
        self.max_connections = max_connections
        self.rate_limit_window = rate_limit_window
        self.max_requests = max_requests
        self.active_connections = 0
        self.block_duration = block_duration
        self.idle_timeout = idle_timeout
        # Maps a blocked address to the datetime at which its block ends.
        self.blocked_clients: Dict[Tuple[str, int], datetime] = {}

    def can_connect(self, addr: Tuple[str, int]) -> bool:
        """
        Check if a new connection is allowed

        Args:
            addr: Client address tuple. It is not consulted here; the
                decision depends only on the number of active connections.

        Returns:
            bool: True if connection is allowed, False otherwise
        """
        # Drop stale records first so slots held by expired clients are freed.
        self._cleanup()
        return self.active_connections < self.max_connections

    def is_rate_limited(self, addr: Tuple[str, int]) -> bool:
        """
        Check if client is being rate limited or blocked.

        Unless the client is already blocked, calling this method records
        the current request: it is appended to the client's request history
        and ``last_seen`` is refreshed. When the history then holds more
        than ``max_requests`` entries inside the window, the client is
        blocked for ``block_duration`` seconds.

        Args:
            addr: Client address tuple

        Returns:
            bool: True if the request must be rejected, False otherwise
        """
        self._cleanup()
        now = datetime.now()

        # Check existing block
        block_expiry = self.blocked_clients.get(addr)
        if block_expiry is not None:
            if now < block_expiry:
                return True
            # The block ran out between cleanup and now; forget it.
            del self.blocked_clients[addr]

        client = self.clients[addr]

        # Remove expired requests, then add the current one. The window
        # length is computed once rather than once per stored timestamp.
        window = timedelta(seconds=self.rate_limit_window)
        recent = [stamp for stamp in client['last_requests'] if now - stamp < window]
        recent.append(now)
        client['last_requests'] = recent
        client['last_seen'] = now

        # Check if rate limit exceeded
        if len(recent) > self.max_requests:
            self.blocked_clients[addr] = now + timedelta(seconds=self.block_duration)
            return True
        return False

    def is_blocked(self, addr: Tuple[str, int]) -> bool:
        """
        Check if a client is currently blocked

        Unlike ``is_rate_limited`` this does not record a request.

        Args:
            addr: Client address tuple

        Returns:
            bool: True if an unexpired block exists for ``addr``
        """
        self._cleanup()
        block_expiry = self.blocked_clients.get(addr)
        return block_expiry is not None and datetime.now() < block_expiry

    def authenticate(self, addr: Tuple[str, int]) -> None:
        """
        Mark a client as authenticated

        A client that was not yet authenticated takes one connection slot;
        authenticating an already authenticated client leaves the count
        unchanged. ``last_seen`` is refreshed in both cases.

        Args:
            addr: Client address tuple
        """
        client = self.clients[addr]
        if not client['authenticated']:
            self.active_connections += 1
            client['authenticated'] = True
        client['last_seen'] = datetime.now()
        logger.info(f"Client {addr} authenticated. Active connections: {self.active_connections}")

    def disconnect(self, addr: Tuple[str, int]) -> None:
        """
        Handle client disconnection

        Releases the connection slot of an authenticated client. Unknown or
        unauthenticated addresses are ignored and nothing is logged.

        Args:
            addr: Client address tuple
        """
        # Membership test first: indexing the defaultdict directly would
        # create a record for an address that was never seen.
        if addr in self.clients and self.clients[addr]['authenticated']:
            self.active_connections -= 1
            self.clients[addr]['authenticated'] = False
            logger.info(f"Client {addr} disconnected. Active connections: {self.active_connections}")

    def _cleanup(self) -> None:
        """
        Clean up expired client and block records

        A client record is removed once more than ``idle_timeout`` seconds
        have passed since its ``last_seen`` timestamp; if that client was
        authenticated, its connection slot is released as well. Blocks
        whose end time lies in the past are removed.
        """
        now = datetime.now()

        # Computed once instead of once per client record.
        idle_limit = timedelta(seconds=self.idle_timeout)
        stale_clients = [
            addr for addr, record in self.clients.items()
            if now - record['last_seen'] > idle_limit
        ]
        for addr in stale_clients:
            record = self.clients.pop(addr)
            if record['authenticated']:
                self.active_connections -= 1

        expired_blocks = [
            addr for addr, expiry in self.blocked_clients.items()
            if now > expiry
        ]
        for addr in expired_blocks:
            del self.blocked_clients[addr]