90 lines
3.5 KiB
Python
90 lines
3.5 KiB
Python
"""SSH Connection Manager for preventing multiple simultaneous connections."""
|
|
|
|
import threading
|
|
import logging
|
|
from typing import Optional, Dict, Tuple
|
|
from .ssh_tool import SSHSession
|
|
|
|
# Module-level logger named after this module, so its records slot into the
# standard logging hierarchy and can be filtered/routed by module name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class SSHConnectionManager:
    """Manages shared SSH connections to prevent connection flooding.

    The class is a process-wide singleton: ``__new__`` hands back the same
    instance on every call. Sessions are cached per ``username@host:port``
    key so repeated callers share one ``SSHSession`` object, and
    ``execute_with_lock`` serializes command execution across all of them.
    """

    _instance = None
    # Guards singleton creation and every access to ``_connections``.
    _lock = threading.Lock()
    # Seconds to keep holding the execution lock after each command batch,
    # so consecutive SSH operations are spaced apart.
    _post_command_delay = 0.2

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path
        # cheap; the re-test under the lock stops two racing threads from
        # both building an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._connections = {}
                    instance._execution_lock = threading.Lock()
                    # Publish only once fully initialised: a thread taking the
                    # unlocked fast path must never see an instance that still
                    # lacks ``_connections`` / ``_execution_lock``.
                    cls._instance = instance
        return cls._instance

    def _get_connection_key(self, host: str, username: str, port: int) -> str:
        """Return the registry key ``"username@host:port"`` for a connection."""
        return f"{username}@{host}:{port}"

    def get_session(self, host: str, username: str, port: int = 22,
                    key_filename: Optional[str] = None,
                    password: Optional[str] = None) -> SSHSession:
        """Get or create the shared SSH session for ``username@host:port``.

        Args:
            host: Remote host name or address.
            username: Login user.
            port: SSH port (default 22).
            key_filename: Private key path; used only when a new session is
                created.
            password: Password; used only when a new session is created.

        Returns:
            The cached ``SSHSession`` for this key. If one already exists it
            is returned unchanged and ``key_filename``/``password`` are
            ignored. No connect call is made here; the intent is for the
            session to connect on first use (depends on ``SSHSession`` --
            not verified here).
        """
        connection_key = self._get_connection_key(host, username, port)

        with self._lock:
            session = self._connections.get(connection_key)
            if session is None:
                logger.info("Creating new shared SSH session to %s", connection_key)
                # Don't connect immediately - let it connect on first use
                session = SSHSession(
                    host=host,
                    username=username,
                    port=port,
                    key_filename=key_filename,
                    password=password,
                )
                self._connections[connection_key] = session
            # Return the reference taken under the lock instead of re-reading
            # the dict, which a concurrent close could have modified.
            return session

    def execute_with_lock(self, session: SSHSession, commands) -> str:
        """Execute commands with a global lock to prevent parallel SSH operations.

        Only one thread at a time runs commands through this manager,
        whichever session it targets. The lock is then held for a further
        ``_post_command_delay`` seconds -- also when the commands raise -- so
        a failing caller retrying in a loop is throttled as well.

        Args:
            session: Session to run the commands on.
            commands: Passed through unchanged to ``session.run_commands``.

        Returns:
            Whatever ``session.run_commands`` returns.

        Raises:
            Any exception raised by ``session.run_commands``.
        """
        import time

        with self._execution_lock:
            logger.debug("Acquired SSH execution lock")
            try:
                return session.run_commands(commands)
            finally:
                # Sleep while still holding the lock so the next caller
                # has to wait out the delay too.
                time.sleep(self._post_command_delay)
                logger.debug("Released SSH execution lock")

    def close_all(self) -> None:
        """Close every managed session and empty the registry.

        A failure closing one session is logged and does not prevent the
        remaining sessions from being closed; the registry is cleared
        regardless.
        """
        with self._lock:
            for connection_key, session in self._connections.items():
                try:
                    session.close()
                    logger.info("Closed SSH connection: %s", connection_key)
                except Exception as e:
                    logger.error("Error closing connection %s: %s", connection_key, e)
            self._connections.clear()

    def close_connection(self, host: str, username: str, port: int = 22) -> None:
        """Close and forget the session for ``username@host:port``, if any.

        The entry is removed from the registry before ``close()`` is called.
        As a result, a session whose close fails is never handed out again by
        ``get_session``. Close errors are logged, not raised. Does nothing if
        no session exists for the key.
        """
        connection_key = self._get_connection_key(host, username, port)

        with self._lock:
            session = self._connections.pop(connection_key, None)
            if session is None:
                return
            try:
                session.close()
            except Exception as e:
                logger.error("Error closing connection %s: %s", connection_key, e)
            else:
                logger.info("Closed SSH connection: %s", connection_key)
|
|
|
|
|
# Shared, process-wide manager. SSHConnectionManager.__new__ always returns
# the same object, so this name and any later SSHConnectionManager() call
# refer to a single instance.
ssh_manager: SSHConnectionManager = SSHConnectionManager()
|