Source code for padrelay.client.client_app

"""Client application"""
import asyncio
import time
import random
import json
import ssl
from datetime import datetime
import socket
from typing import Optional
from ..core.logging_utils import get_logger
from ..core.exceptions import AuthenticationError, ConnectionError
from ..protocol.constants import HEARTBEAT_INTERVAL, PROTOCOL_VERSION
from ..protocol.tcp import TCPProtocolHandler
from ..protocol.udp import UDPProtocolHandler
from ..security.auth import Authenticator
from ..security.tls_utils import create_client_ssl_context
from .constants import RECONNECT_DELAY
from .input import GamepadInput

logger = get_logger(__name__)


class UDPClientProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that queues incoming packets.

    Every received datagram payload is appended to ``queue`` so that
    coroutines can consume packets with ``await protocol.queue.get()``.
    The sender address is not recorded.
    """

    def __init__(self) -> None:
        """Create the (unbounded) queue that buffers received payloads."""
        self.queue: asyncio.Queue[bytes] = asyncio.Queue()

    def datagram_received(self, data: bytes, addr) -> None:  # type: ignore[override]
        """Store the payload of an incoming datagram; ``addr`` is ignored."""
        self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:  # type: ignore[override]
        """Log a send/receive error reported by the transport."""
        logger.error(f"UDP client error received: {exc}")
class VirtualGamepadClient:
    """Client controller.

    Polls a ``GamepadInput`` at ``update_rate`` Hz and relays every update to
    the server over TCP (optionally wrapped in TLS) or UDP. Whenever the link
    drops, the client waits ``RECONNECT_DELAY`` seconds and connects again
    until ``shutdown()`` is called or the running task is cancelled.
    """

    def __init__(
        self,
        server_ip: str,
        server_port: int,
        protocol: str,
        gamepad: GamepadInput,
        update_rate: int = 60,
        password: Optional[str] = None,
        config: Optional[object] = None,
        enable_tls: bool = False,
    ) -> None:
        """Construct the client and configure networking.

        Args:
            server_ip: Host name or address of the server.
            server_port: Port of the server.
            protocol: ``"tcp"`` or ``"udp"`` (case-insensitive). Any value
                other than ``"tcp"`` makes ``run()`` use the UDP code path.
            gamepad: Input source; ``poll()`` is called once per update and
                ``cleanup()`` on shutdown.
            update_rate: Number of gamepad polls per second; must be positive.
            password: Secret handed to the ``Authenticator``.
            config: Opaque configuration object, stored but not used here.
            enable_tls: Request TLS. Only honoured for TCP.

        Raises:
            ValueError: If ``update_rate`` is not positive.
        """
        # A zero rate would raise ZeroDivisionError inside the reconnect loop
        # (retrying forever); a negative one would busy-loop. Fail fast.
        if update_rate <= 0:
            raise ValueError(f"update_rate must be positive, got {update_rate!r}")

        self.server_ip = server_ip
        self.server_port = server_port
        self.protocol = protocol.lower()
        self.gamepad = gamepad
        self.update_rate = update_rate
        self.config = config
        self.enable_tls = enable_tls and self.protocol == 'tcp'
        self.ssl_context: Optional[ssl.SSLContext] = None

        # Security
        self.authenticator = Authenticator(password)

        # Initialize TLS if enabled
        if self.enable_tls:
            try:
                # NOTE(review): the context is requested with verify_cert=False,
                # which presumably disables server certificate validation --
                # confirm this is intended; it leaves the link open to MITM.
                self.ssl_context = create_client_ssl_context(verify_cert=False)
                logger.info("TLS/SSL encryption enabled")
            except Exception as e:
                # NOTE(review): this silently downgrades to plaintext.
                logger.error(f"Failed to initialize TLS: {e}")
                logger.warning("Continuing without TLS encryption")
                self.enable_tls = False
        elif self.protocol == 'tcp':
            logger.warning(
                "TLS/SSL encryption is DISABLED. "
                "Network traffic will be transmitted in plaintext. "
                "Use --enable-tls to enable encryption, or ensure you're on a trusted network."
            )

        # Connection state
        self.connected = False
        self.running = False
        self.shutdown_event = asyncio.Event()

        # Tasks
        self.main_task = None
        self.heartbeat_task = None

        logger.info(f"Client initialized with {self.protocol.upper()} protocol")
        logger.info(f"Server: {self.server_ip}:{self.server_port}")
        logger.info(f"Update rate: {self.update_rate} Hz")

    async def run(self) -> bool:
        """Run the client using the chosen protocol.

        Returns:
            ``True`` when the protocol loop ended normally, ``False`` when it
            was cancelled or failed with an exception (which is logged).
        """
        self.running = True
        try:
            if self.protocol == 'tcp':
                await self._run_tcp()
            else:
                await self._run_udp()
            return True
        except asyncio.CancelledError:
            logger.info("Client cancelled")
            return False
        except Exception as e:
            logger.error(f"Client error: {e}")
            return False
        finally:
            self.running = False

    async def shutdown(self) -> bool:
        """Shut down tasks and clean up resources.

        Sets the shutdown event, cancels the heartbeat and main tasks if they
        are still running, and calls ``gamepad.cleanup()``.

        Returns:
            Always ``True``.
        """
        logger.info("Shutdown initiated...")
        self.shutdown_event.set()

        # Cancel heartbeat task
        if self.heartbeat_task and not self.heartbeat_task.done():
            self.heartbeat_task.cancel()
            try:
                await self.heartbeat_task
            except asyncio.CancelledError:
                pass

        # Cancel main task
        if self.main_task and not self.main_task.done():
            self.main_task.cancel()
            try:
                await self.main_task
            except asyncio.CancelledError:
                pass

        # Cleanup gamepad
        self.gamepad.cleanup()

        logger.info("Shutdown complete")
        return True

    async def _run_tcp(self) -> None:
        """Main loop for TCP mode: connect, authenticate, stream, reconnect."""
        while not self.shutdown_event.is_set():
            # Fresh per-iteration references so cleanup never touches a
            # handler left over from an earlier connection attempt.
            writer = None
            tcp_handler = None
            try:
                # Connect to server
                logger.info(f"Connecting to TCP server at {self.server_ip}:{self.server_port}...")
                reader, writer = await asyncio.open_connection(
                    self.server_ip,
                    self.server_port,
                    ssl=self.ssl_context if self.enable_tls else None,
                    server_hostname=self.server_ip if self.enable_tls else None,
                )
                logger.info(f"Connected to server at {self.server_ip}:{self.server_port}")

                # Log TLS connection details if enabled
                if self.enable_tls:
                    self._log_tls_details(writer)

                # Initialize protocol handler
                tcp_handler = TCPProtocolHandler(reader, writer)

                # Optimize socket settings
                self._tune_socket(writer)

                # Authenticate, then stream until the link drops
                if await self._tcp_auth(tcp_handler):
                    await self._tcp_send_loop(tcp_handler)
                else:
                    logger.error(
                        f"Authentication failed, retrying in {RECONNECT_DELAY} seconds..."
                    )
            except asyncio.CancelledError:
                logger.info("TCP send loop cancelled")
                break
            except Exception as e:
                logger.error(f"TCP connection error: {e}")
            finally:
                # Runs on every exit path (auth failure, error, cancellation)
                # so the connection and heartbeat task are never leaked.
                self.connected = False
                await self._tcp_cleanup(tcp_handler, writer)

            # Reconnect after delay
            if not self.shutdown_event.is_set():
                logger.info(f"Reconnecting in {RECONNECT_DELAY} seconds...")
                await asyncio.sleep(RECONNECT_DELAY)

    @staticmethod
    def _log_tls_details(writer) -> None:
        """Log TLS version and cipher of an established connection."""
        ssl_object = writer.get_extra_info('ssl_object')
        if not ssl_object:
            logger.warning("TLS enabled but no SSL object found")
            return
        cipher = ssl_object.cipher()
        tls_version = ssl_object.version()
        logger.info("TLS handshake successful")
        logger.debug("TLS connection details:")
        logger.debug(f" TLS version: {tls_version}")
        if cipher:
            logger.debug(f" Cipher: {cipher[0]} (bits: {cipher[2]})")
            logger.debug(f" Protocol: {cipher[1]}")

    @staticmethod
    def _tune_socket(writer) -> None:
        """Disable Nagle and set 64 KiB send/receive buffers on the socket."""
        sock = writer.get_extra_info('socket')
        if sock is None:
            # get_extra_info returns None when the transport has no socket.
            logger.debug("Transport exposes no socket; skipping socket tuning")
            return
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)

    async def _tcp_send_loop(self, tcp_handler) -> None:
        """Start the heartbeat task and stream gamepad updates until the link drops."""
        # Start heartbeat task
        self.heartbeat_task = asyncio.create_task(
            self._tcp_heartbeat_loop(tcp_handler)
        )

        # Set connection state
        self.connected = True

        send_interval = 1.0 / self.update_rate
        # Monotonic clock: interval maths must not jump with wall-clock changes.
        last_send_time = time.monotonic()
        batch_counter = 0

        while not self.shutdown_event.is_set():
            # Poll gamepad and send update
            message = self.gamepad.poll()
            if message:
                # Only every 10th message keeps its timestamp, to reduce size
                if batch_counter % 10 != 0:
                    message.pop("timestamp", None)
                batch_counter += 1

                if not await tcp_handler.send_message(message):
                    break
                if not await tcp_handler.drain():
                    break

            if tcp_handler.writer.is_closing():
                logger.warning("Server closed the connection")
                break

            # Calculate sleep time to maintain a consistent update rate
            now = time.monotonic()
            elapsed = now - last_send_time
            sleep_time = max(0.0, send_interval - elapsed)
            if elapsed > send_interval * 3:
                logger.warning("Timing drift detected, resetting schedule")
            # When behind schedule sleep_time is 0, so this restarts the
            # schedule from "now"; otherwise it marks the planned wake-up.
            last_send_time = now + sleep_time
            await asyncio.sleep(sleep_time)

            # Check if heartbeat task failed
            heartbeat = self.heartbeat_task
            if heartbeat.done():
                # Task.exception() raises CancelledError for a cancelled
                # task, so test for cancellation first.
                if not heartbeat.cancelled() and heartbeat.exception():
                    logger.error(f"Heartbeat task failed: {heartbeat.exception()}")
                else:
                    logger.warning("Heartbeat task ended")
                break

    async def _tcp_cleanup(self, tcp_handler, writer) -> None:
        """Stop the heartbeat task and close the connection, logging any error."""
        try:
            task = self.heartbeat_task
            if task and not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

        try:
            if tcp_handler is not None:
                tcp_handler.close()
                await tcp_handler.wait_closed()
            elif writer is not None:
                # Connected, but the handler was never created.
                writer.close()
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

    async def _tcp_auth(self, tcp_handler) -> bool:
        """Authenticate with the server over TCP.

        Waits for an ``auth_challenge`` message, checks the protocol version,
        answers with an ``auth_response`` and waits for the result.

        Returns:
            ``True`` only if the server replied with ``auth_success``; every
            failure (including exceptions, which are logged) yields ``False``.
        """
        try:
            # Wait for auth challenge from server
            logger.debug("Waiting for authentication challenge from server")
            challenge_msg = await tcp_handler.read_message()
            if not challenge_msg or challenge_msg.get("type") != "auth_challenge":
                logger.error("Authentication protocol error: Invalid challenge")
                return False
            logger.debug("Received authentication challenge")

            # Extract challenge and verify protocol version
            challenge = challenge_msg.get("challenge")
            salt = challenge_msg.get("salt")
            iterations = challenge_msg.get("iterations")
            if salt and iterations:
                logger.debug(f"Updating authenticator parameters (iterations: {iterations})")
                self.authenticator.set_parameters(salt, int(iterations))

            if challenge_msg.get("protocol_version") != PROTOCOL_VERSION:
                logger.error(f"Protocol version mismatch: server={challenge_msg.get('protocol_version')}, client={PROTOCOL_VERSION}")
                return False

            if challenge is None:
                # A challenge message without a challenge value cannot be answered.
                logger.error("Authentication protocol error: Invalid challenge")
                return False

            # Calculate HMAC response
            logger.debug("Generating authentication response")
            response = self.authenticator.generate_tcp_response(challenge)

            # Send auth response
            logger.debug("Sending authentication response")
            auth_response = {
                "type": "auth_response",
                "protocol_version": PROTOCOL_VERSION,
                "response": response,
            }
            # The send loop treats a falsy result from these calls as a
            # failed send; do the same here instead of waiting on a dead link.
            if not await tcp_handler.send_message(auth_response):
                logger.error("Failed to send authentication response")
                return False
            if not await tcp_handler.drain():
                logger.error("Failed to send authentication response")
                return False

            # Wait for auth result
            logger.debug("Waiting for authentication result")
            result = await tcp_handler.read_message()
            if not result:
                return False
            if result.get("type") == "auth_success":
                logger.info("Authentication successful")
                return True
            logger.error(f"Authentication failed: {result.get('message', 'Unknown error')}")
            return False
        except Exception as e:
            logger.error(f"Authentication error: {e}")
            return False

    async def _tcp_heartbeat_loop(
        self, tcp_handler, interval: int = HEARTBEAT_INTERVAL
    ) -> bool:
        """Send heartbeats at ``interval`` seconds.

        Each heartbeat must be acknowledged with a ``heartbeat_ack`` message
        within 5 seconds.

        Returns:
            ``True`` when stopped by cancellation or the shutdown event,
            ``False`` on a send failure, missing/invalid ack or other error.
        """
        while not self.shutdown_event.is_set():
            try:
                heartbeat = {
                    "type": "heartbeat",
                    "protocol_version": PROTOCOL_VERSION,
                    "timestamp": datetime.now().isoformat(),
                }
                if not await tcp_handler.send_message(heartbeat):
                    logger.warning("Failed to send heartbeat")
                    return False
                if not await tcp_handler.drain():
                    logger.warning("Failed to send heartbeat")
                    return False

                # Wait for heartbeat acknowledgment with timeout
                try:
                    response = await asyncio.wait_for(tcp_handler.read_message(), timeout=5.0)
                except asyncio.TimeoutError:
                    logger.warning("Heartbeat acknowledgment timeout")
                    return False
                if not response or response.get("type") != "heartbeat_ack":
                    logger.warning("Invalid heartbeat response")
                    return False

                await asyncio.sleep(interval)
            except asyncio.CancelledError:
                logger.info("Heartbeat task cancelled")
                return True
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")
                return False
        # Loop ended because shutdown was requested: a clean stop.
        return True

    async def _run_udp(self) -> None:
        """Main loop for UDP mode: open endpoint, stream, reconnect."""
        loop = asyncio.get_running_loop()
        while not self.shutdown_event.is_set():
            # Fresh per-iteration references so cleanup never closes a
            # handler left over from an earlier attempt.
            transport = None
            udp_handler = None
            try:
                protocol = UDPClientProtocol()
                transport, _ = await loop.create_datagram_endpoint(
                    lambda: protocol,
                    remote_addr=(self.server_ip, self.server_port),
                )

                # Initialize UDP handler
                udp_handler = UDPProtocolHandler(transport, (self.server_ip, self.server_port))
                logger.info(f"UDP client sending to {self.server_ip}:{self.server_port}")

                # Set connection state
                self.connected = True

                # Retrieve hashing parameters if the client only has the
                # plaintext password. If the server does not respond, fall
                # back to using the plaintext key.
                if self.authenticator.password_plain:
                    await self._udp_fetch_auth_params(udp_handler, protocol)

                await self._udp_send_loop(udp_handler, protocol)
            except asyncio.CancelledError:
                logger.info("UDP client cancelled")
                break
            except Exception as e:
                logger.error(f"UDP endpoint error: {e}")
            finally:
                self.connected = False
                if udp_handler is not None:
                    udp_handler.close()
                elif transport is not None:
                    # Endpoint opened, but the handler was never created.
                    transport.close()

            if not self.shutdown_event.is_set():
                logger.info(f"Reconnecting in {RECONNECT_DELAY} seconds...")
                await asyncio.sleep(RECONNECT_DELAY)

    async def _udp_fetch_auth_params(self, udp_handler, protocol) -> None:
        """Ask the server for salt/iterations; keep the plaintext key on timeout."""
        req = {
            "type": "auth_params_request",
            "protocol_version": PROTOCOL_VERSION,
            "timestamp": datetime.now().isoformat(),
        }
        udp_handler.send_message(req)
        try:
            data = await asyncio.wait_for(protocol.queue.get(), timeout=0.2)
        except asyncio.TimeoutError:
            logger.debug("No auth parameters received; assuming plaintext mode")
            return

        resp = json.loads(data.decode("utf-8"))
        if resp.get("type") != "auth_params":
            return
        salt = resp.get("salt")
        iterations = resp.get("iterations")
        if salt and iterations is not None:
            self.authenticator.set_parameters(str(salt), int(iterations))
            # Use the derived key for tokens
            self.authenticator.password_plain = None

    async def _udp_send_loop(self, udp_handler, protocol) -> None:
        """Stream gamepad updates over UDP, with a heartbeat every ``HEARTBEAT_INTERVAL``."""
        send_interval = 1.0 / self.update_rate
        # Monotonic clock: interval maths must not jump with wall-clock changes.
        last_heartbeat = time.monotonic()
        try:
            while not self.shutdown_event.is_set():
                now = time.monotonic()
                if now - last_heartbeat >= HEARTBEAT_INTERVAL:
                    await self._udp_heartbeat(udp_handler, protocol)
                    last_heartbeat = now

                message = self.gamepad.poll()
                if message:
                    auth_token = self.authenticator.generate_udp_token()
                    if auth_token:
                        message["auth_token"] = auth_token
                    udp_handler.send_message(message)

                await asyncio.sleep(send_interval)
        except asyncio.CancelledError:
            logger.info("UDP send loop cancelled")
            raise
        except Exception as e:
            logger.error(f"Error in UDP send loop: {e}")
            raise

    async def _udp_heartbeat(self, udp_handler, protocol) -> None:
        """Send one heartbeat and require a ``heartbeat_ack`` within 5 seconds.

        Raises:
            ConnectionError: If no packet arrives in time or the packet
                received is not a ``heartbeat_ack``.
        """
        heartbeat = {
            "type": "heartbeat",
            "protocol_version": PROTOCOL_VERSION,
            "timestamp": datetime.now().isoformat(),
        }
        auth_token = self.authenticator.generate_udp_token()
        if auth_token:
            heartbeat["auth_token"] = auth_token
        udp_handler.send_message(heartbeat)

        try:
            data = await asyncio.wait_for(protocol.queue.get(), timeout=5.0)
        except asyncio.TimeoutError:
            logger.warning("Heartbeat acknowledgment timeout")
            raise ConnectionError("Heartbeat timeout")

        response = json.loads(data.decode("utf-8"))
        if response.get("type") != "heartbeat_ack":
            logger.warning("Invalid heartbeat response")
            raise ConnectionError("Invalid heartbeat ack")