"""SSH Connection Manager for preventing multiple simultaneous connections."""

import threading
import logging
import time
from typing import Optional, Dict, Tuple

from .ssh_tool import SSHSession

logger = logging.getLogger(__name__)


class SSHConnectionManager:
    """Manages shared SSH connections to prevent connection flooding.

    Every ``SSHConnectionManager()`` call returns the same process-wide
    instance. Sessions are cached per ``username@host:port`` key and reused
    by later :meth:`get_session` calls, and :meth:`execute_with_lock`
    serializes command execution across all sessions behind one lock.

    NOTE: the cache lock (``_lock``) and the execution lock
    (``_execution_lock``) are independent, so closing a session is not
    blocked while commands are running on it.
    """

    _instance = None
    # Class-level lock: guards singleton creation and every access to the
    # ``_connections`` cache.
    _lock = threading.Lock()

    # Pause (seconds) after each successful command batch, taken while the
    # execution lock is still held, so the next caller cannot immediately
    # start another SSH operation.
    _POST_EXECUTION_DELAY = 0.2

    def __new__(cls):
        # The unlocked check is a fast path once the singleton exists; the
        # re-check under the lock handles threads racing on first creation.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Fully initialize *before* publishing to ``cls._instance``:
                    # the unlocked fast path above must never observe an
                    # instance that is still missing these attributes.
                    instance._connections = {}  # connection key -> SSHSession
                    instance._execution_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def _get_connection_key(self, host: str, username: str, port: int) -> str:
        """Return the cache key ``username@host:port`` for a connection."""
        return f"{username}@{host}:{port}"

    def get_session(
        self,
        host: str,
        username: str,
        port: int = 22,
        key_filename: Optional[str] = None,
        password: Optional[str] = None,
    ) -> SSHSession:
        """Get or create a shared SSH session.

        Args:
            host: Remote host name or address.
            username: Login user name.
            port: SSH port.
            key_filename: Private key path, passed to ``SSHSession``.
            password: Password, passed to ``SSHSession``.

        Returns:
            The cached session for ``username@host:port``, created on the
            first request for that key.

        Note:
            Credentials are not part of the cache key. Once a session exists
            for a key, later calls return it unchanged and their
            ``key_filename``/``password`` arguments are ignored.
        """
        connection_key = self._get_connection_key(host, username, port)
        with self._lock:
            session = self._connections.get(connection_key)
            if session is None:
                logger.info("Creating new shared SSH session to %s", connection_key)
                session = SSHSession(
                    host=host,
                    username=username,
                    port=port,
                    key_filename=key_filename,
                    password=password,
                )
                # Don't connect immediately - let it connect on first use
                self._connections[connection_key] = session
            return session

    def execute_with_lock(self, session: SSHSession, commands) -> str:
        """Execute commands with a global lock to prevent parallel SSH operations.

        Args:
            session: Session to run the commands on.
            commands: Passed unchanged to ``session.run_commands``.

        Returns:
            Whatever ``session.run_commands`` returns.

        Raises:
            Any exception from ``session.run_commands`` propagates; the
            post-execution delay is skipped in that case.
        """
        with self._execution_lock:
            logger.debug("Acquired SSH execution lock")
            try:
                result = session.run_commands(commands)
                # Small delay, still holding the lock, to prevent rapid
                # successive connections.
                time.sleep(self._POST_EXECUTION_DELAY)
                return result
            finally:
                logger.debug("Released SSH execution lock")

    def close_all(self):
        """Close and forget all managed connections.

        A failing ``close()`` is logged and does not stop the remaining
        sessions from being closed; the cache is emptied either way.
        """
        with self._lock:
            for connection_key, session in self._connections.items():
                try:
                    session.close()
                    logger.info("Closed SSH connection: %s", connection_key)
                except Exception as e:  # best-effort shutdown: log and continue
                    logger.error("Error closing connection %s: %s", connection_key, e)
            self._connections.clear()

    def close_connection(self, host: str, username: str, port: int = 22):
        """Close a specific connection and remove it from the cache.

        The session is removed from the cache *before* ``close()`` is called,
        so a failing ``close()`` is logged but never leaves the broken session
        behind for :meth:`get_session` to hand out again. Unknown connections
        are ignored.
        """
        connection_key = self._get_connection_key(host, username, port)
        with self._lock:
            session = self._connections.pop(connection_key, None)
            if session is None:
                return
            try:
                session.close()
                logger.info("Closed SSH connection: %s", connection_key)
            except Exception as e:  # best-effort close: log, entry already dropped
                logger.error("Error closing connection %s: %s", connection_key, e)


# Global connection manager instance
ssh_manager = SSHConnectionManager()