better detection
This commit is contained in:
@@ -2,17 +2,19 @@
|
||||
|
||||
from .poem_tool import print_poem
|
||||
from .ssh_tool import SSHTool
|
||||
from .ssh_connection_manager import ssh_manager
|
||||
from langchain_community.tools.shell.tool import ShellTool
|
||||
|
||||
# Pre-configured SSH tool for your server - only connects when actually used
|
||||
# Pre-configured SSH tool for your server - uses shared connection to prevent SSH banner errors
|
||||
# TODO: Update these connection details for your actual server
|
||||
configured_remote_server = SSHTool(
|
||||
host="157.90.211.119", # Replace with your server
|
||||
port=8081,
|
||||
username="g", # Replace with your username
|
||||
key_filename="/Users/ghsioux/.ssh/id_rsa_hetzner", # Replace with your key path
|
||||
ask_human_input=True # Safety confirmation
|
||||
ask_human_input=True, # Safety confirmation
|
||||
use_shared_connection=True # Use shared connection pool to prevent banner errors
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["print_poem", "SSHTool", "ShellTool", "configured_remote_server"]
|
||||
__all__ = ["print_poem", "SSHTool", "ShellTool", "configured_remote_server", "ssh_manager"]
|
||||
|
@@ -0,0 +1,89 @@
|
||||
"""SSH Connection Manager for preventing multiple simultaneous connections."""
|
||||
|
||||
import threading
|
||||
import logging
|
||||
from typing import Optional, Dict, Tuple
|
||||
from .ssh_tool import SSHSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHConnectionManager:
|
||||
"""Manages shared SSH connections to prevent connection flooding."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._connections = {}
|
||||
cls._instance._execution_lock = threading.Lock()
|
||||
return cls._instance
|
||||
|
||||
def _get_connection_key(self, host: str, username: str, port: int) -> str:
|
||||
"""Generate a unique key for the connection."""
|
||||
return f"{username}@{host}:{port}"
|
||||
|
||||
def get_session(self, host: str, username: str, port: int = 22,
|
||||
key_filename: Optional[str] = None,
|
||||
password: Optional[str] = None) -> SSHSession:
|
||||
"""Get or create a shared SSH session."""
|
||||
connection_key = self._get_connection_key(host, username, port)
|
||||
|
||||
with self._lock:
|
||||
if connection_key not in self._connections:
|
||||
logger.info(f"Creating new shared SSH session to {connection_key}")
|
||||
session = SSHSession(
|
||||
host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
key_filename=key_filename,
|
||||
password=password
|
||||
)
|
||||
# Don't connect immediately - let it connect on first use
|
||||
self._connections[connection_key] = session
|
||||
|
||||
return self._connections[connection_key]
|
||||
|
||||
def execute_with_lock(self, session: SSHSession, commands) -> str:
|
||||
"""Execute commands with a global lock to prevent parallel SSH operations."""
|
||||
with self._execution_lock:
|
||||
logger.debug("Acquired SSH execution lock")
|
||||
try:
|
||||
result = session.run_commands(commands)
|
||||
# Add a small delay to prevent rapid successive connections
|
||||
import time
|
||||
time.sleep(0.2) # 200ms delay
|
||||
return result
|
||||
finally:
|
||||
logger.debug("Released SSH execution lock")
|
||||
|
||||
def close_all(self):
|
||||
"""Close all managed connections."""
|
||||
with self._lock:
|
||||
for connection_key, session in self._connections.items():
|
||||
try:
|
||||
session.close()
|
||||
logger.info(f"Closed SSH connection: {connection_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing connection {connection_key}: {e}")
|
||||
self._connections.clear()
|
||||
|
||||
def close_connection(self, host: str, username: str, port: int = 22):
|
||||
"""Close a specific connection."""
|
||||
connection_key = self._get_connection_key(host, username, port)
|
||||
with self._lock:
|
||||
if connection_key in self._connections:
|
||||
try:
|
||||
self._connections[connection_key].close()
|
||||
del self._connections[connection_key]
|
||||
logger.info(f"Closed SSH connection: {connection_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing connection {connection_key}: {e}")
|
||||
|
||||
|
||||
# Global connection manager instance
|
||||
ssh_manager = SSHConnectionManager()
|
@@ -173,9 +173,11 @@ class SSHTool(BaseTool):
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
key_filename: Optional[str] = Field(default=None, description="SSH key path")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
ask_human_input: bool = Field(default=False, description="Ask for human confirmation")
|
||||
|
||||
# Session management
|
||||
session: Optional[SSHSession] = Field(default=None, exclude=True)
|
||||
use_shared_connection: bool = Field(default=True, description="Use shared SSH connection")
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True
|
||||
@@ -184,20 +186,41 @@ class SSHTool(BaseTool):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize SSH tool."""
|
||||
super().__init__(**kwargs)
|
||||
# Create session but don't connect yet
|
||||
self.session = SSHSession(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
key_filename=self.key_filename,
|
||||
password=self.password
|
||||
)
|
||||
|
||||
if self.use_shared_connection:
|
||||
# Import here to avoid circular dependency
|
||||
from .ssh_connection_manager import ssh_manager
|
||||
# Use the shared connection manager
|
||||
self.session = ssh_manager.get_session(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
key_filename=self.key_filename,
|
||||
password=self.password
|
||||
)
|
||||
else:
|
||||
# Create individual session but don't connect yet
|
||||
self.session = SSHSession(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
key_filename=self.key_filename,
|
||||
password=self.password
|
||||
)
|
||||
|
||||
def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
|
||||
"""Execute commands on remote server."""
|
||||
try:
|
||||
print(f"Executing on {self.username}@{self.host}:{self.port}")
|
||||
return self.session.run_commands(commands)
|
||||
|
||||
if self.use_shared_connection:
|
||||
# Use the connection manager's execution lock
|
||||
from .ssh_connection_manager import ssh_manager
|
||||
return ssh_manager.execute_with_lock(self.session, commands)
|
||||
else:
|
||||
# Direct execution without shared lock
|
||||
return self.session.run_commands(commands)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSH execution error: {e}")
|
||||
return f"Error: {str(e)}"
|
Reference in New Issue
Block a user