2025-06-30 07:58:13 +02:00

90 lines
3.5 KiB
Python

"""SSH Connection Manager for preventing multiple simultaneous connections."""
import threading
import logging
from typing import Optional, Dict, Tuple
from .ssh_tool import SSHSession
logger = logging.getLogger(__name__)
class SSHConnectionManager:
    """Manages shared SSH connections to prevent connection flooding.

    The class is a process-wide singleton: every ``SSHConnectionManager()``
    call returns the same instance. Sessions are cached per
    ``username@host:port`` key, and command execution through
    :meth:`execute_with_lock` is serialized by a single global lock.
    """

    _instance = None
    # Guards singleton creation in __new__ and all access to _connections.
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Fully initialize a local object before publishing it to
                    # cls._instance. Otherwise a thread passing the unlocked
                    # check above could receive an instance that has no
                    # _connections / _execution_lock yet.
                    instance = super().__new__(cls)
                    instance._connections = {}
                    instance._execution_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def _get_connection_key(self, host: str, username: str, port: int) -> str:
        """Return the cache key ``"username@host:port"`` for a connection."""
        return f"{username}@{host}:{port}"

    def get_session(self, host: str, username: str, port: int = 22,
                    key_filename: Optional[str] = None,
                    password: Optional[str] = None) -> SSHSession:
        """Get or create a shared SSH session.

        Args:
            host: Remote host name or address.
            username: Login user.
            port: SSH port (default 22).
            key_filename: Private key path, used only when a new session is
                created.
            password: Password, used only when a new session is created.

        Returns:
            The cached ``SSHSession`` for ``username@host:port``. The session
            is constructed here but not explicitly connected.

        Note:
            The cache key does not include credentials. If a session for the
            key already exists, ``key_filename`` and ``password`` are ignored
            and the existing session is returned.
        """
        connection_key = self._get_connection_key(host, username, port)
        with self._lock:
            session = self._connections.get(connection_key)
            if session is None:
                logger.info("Creating new shared SSH session to %s", connection_key)
                session = SSHSession(
                    host=host,
                    username=username,
                    port=port,
                    key_filename=key_filename,
                    password=password,
                )
                # Don't connect immediately - let it connect on first use
                self._connections[connection_key] = session
            return session

    def execute_with_lock(self, session: SSHSession, commands,
                          delay: float = 0.2) -> str:
        """Execute commands with a global lock to prevent parallel SSH operations.

        Args:
            session: Session whose ``run_commands`` is invoked.
            commands: Passed unchanged to ``session.run_commands``. The
                accepted shape is defined by ``SSHSession``, which is not
                visible here.
            delay: Seconds to sleep after a successful run while still
                holding the lock, throttling back-to-back SSH operations.
                Defaults to 0.2 (the previous hard-coded value). Values
                <= 0 disable the pause.

        Returns:
            Whatever ``session.run_commands(commands)`` returns.

        Raises:
            Any exception from ``session.run_commands`` propagates. In that
            case no delay is applied.
        """
        import time

        try:
            with self._execution_lock:
                logger.debug("Acquired SSH execution lock")
                result = session.run_commands(commands)
                # Sleep inside the lock so the next caller cannot start a new
                # SSH operation immediately after this one finishes.
                if delay > 0:
                    time.sleep(delay)
        finally:
            # Logged after the `with` has exited, i.e. after the lock is
            # actually released (also on the exception path).
            logger.debug("Released SSH execution lock")
        return result

    def close_all(self):
        """Close all managed connections and empty the cache.

        Errors from individual ``close()`` calls are logged and do not stop
        the remaining sessions from being closed. The cache is cleared in
        every case.
        """
        with self._lock:
            for connection_key, session in self._connections.items():
                try:
                    session.close()
                    logger.info("Closed SSH connection: %s", connection_key)
                except Exception as e:
                    logger.error("Error closing connection %s: %s", connection_key, e)
            self._connections.clear()

    def close_connection(self, host: str, username: str, port: int = 22):
        """Close a specific connection and remove it from the cache.

        The entry is removed even when ``close()`` raises, matching
        :meth:`close_all`. A session whose close failed is therefore never
        handed out again by :meth:`get_session`. Close errors are logged,
        not raised. Unknown keys are ignored.
        """
        connection_key = self._get_connection_key(host, username, port)
        with self._lock:
            session = self._connections.pop(connection_key, None)
            if session is None:
                return
            try:
                session.close()
                logger.info("Closed SSH connection: %s", connection_key)
            except Exception as e:
                logger.error("Error closing connection %s: %s", connection_key, e)
# Shared, process-wide manager. SSHConnectionManager is a singleton (see its
# __new__), so any later SSHConnectionManager() call yields this same object.
ssh_manager: SSHConnectionManager = SSHConnectionManager()