diff --git a/multi-agent-supervisor/custom_tools/ssh_tool.py b/multi-agent-supervisor/custom_tools/ssh_tool.py index ce023d2..17cc5b9 100644 --- a/multi-agent-supervisor/custom_tools/ssh_tool.py +++ b/multi-agent-supervisor/custom_tools/ssh_tool.py @@ -1,66 +1,42 @@ +import getpass import logging -import warnings +import time from typing import Any, List, Optional, Type, Union from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field logger = logging.getLogger(__name__) class SSHInput(BaseModel): - """Commands for the SSH tool.""" - + """Input for SSH tool.""" commands: Union[str, List[str]] = Field( ..., - description="List of commands to run on the remote server", + description="Command(s) to run on the remote server" ) - """List of commands to run.""" + + +class SSHSession: + """Manages a persistent SSH session with sudo caching.""" - use_sudo: bool = Field( - default=False, - description="Whether to run commands with sudo privileges" - ) - """Whether to run commands with sudo.""" - - sudo_password: Optional[str] = Field( - default=None, - description="Password for sudo if required (will be used securely)" - ) - """Password for sudo if required.""" - - @model_validator(mode="before") - def _validate_commands(cls, values: dict) -> Any: - """Validate commands.""" - commands = values.get("commands") - if not isinstance(commands, list): - values["commands"] = [commands] - # Warn that the SSH tool has no safeguards - warnings.warn( - "The SSH tool has no safeguards by default. Use at your own risk." - ) - return values - - -class SSHProcess: - """Persistent SSH connection for command execution.""" - def __init__(self, host: str, username: str, port: int = 22, - password: Optional[str] = None, key_filename: Optional[str] = None, - **kwargs): - """Initialize SSH process with connection parameters.""" + key_filename: Optional[str] = None, password: Optional[str] = None): self.host = host self.username = username self.port = port - self.password = password self.key_filename = key_filename + self.password = password self.client = None - self._is_connected = False - + self._sudo_password = None + self._sudo_timestamp = None + self._sudo_timeout = 300 # 5 minutes, like default sudo + def connect(self): - """Establish SSH connection.""" - if self._is_connected: + """Establish SSH connection if not already connected.""" + if self.client: return + try: import paramiko except ImportError as e: @@ -68,7 +44,7 @@ class SSHProcess: "paramiko is required for SSH functionality. " "Install it with `pip install paramiko`" ) from e - + self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -84,151 +60,142 @@ class SSHProcess: connect_kwargs["key_filename"] = self.key_filename self.client.connect(**connect_kwargs) - self._is_connected = True logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}") - - def run(self, commands: Union[str, List[str]], use_sudo: bool = False, - sudo_password: Optional[str] = None) -> str: - """Run commands over SSH and return output.""" - if not self._is_connected: - self.connect() + def _needs_sudo_password(self) -> bool: + """Check if we need to ask for sudo password.""" + if not self._sudo_password: + return True + if not self._sudo_timestamp: + return True + # Check if sudo timeout has expired + if time.time() - self._sudo_timestamp > self._sudo_timeout: + self._sudo_password = None + self._sudo_timestamp = None + return True + return False + + def execute(self, command: str) -> str: + """Execute a single command, handling sudo automatically.""" + if not self.client: + self.connect() + + # Check if command needs sudo + needs_sudo = command.strip().startswith('sudo ') + + if needs_sudo: + # Remove 'sudo ' prefix if present + actual_command = command.strip()[5:].strip() + + # Check if we need to get sudo password + if self._needs_sudo_password(): + self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ") + self._sudo_timestamp = time.time() + + # Execute with sudo -S and pass password via stdin + full_command = f"sudo -S {actual_command}" + stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True) + + # Send password + stdin.write(f"{self._sudo_password}\n") + stdin.flush() + + # Get output + output = stdout.read().decode() + error = stderr.read().decode() + + # Clean up sudo prompt from output + lines = output.split('\n') + cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')] + output = '\n'.join(cleaned_lines) + + else: + # Regular command without sudo + stdin, stdout, stderr = self.client.exec_command(command) + output = stdout.read().decode() + error = stderr.read().decode() + + # Combine output and error + result = output + if error: + result += f"\n{error}" if result else error + + return result.strip() + + def run_commands(self, commands: Union[str, List[str]]) -> str: + """Run one or more commands and return combined output.""" if isinstance(commands, str): commands = [commands] - + outputs = [] - for command in commands: + for cmd in commands: try: - # Prepare command with sudo if needed - if use_sudo: - if sudo_password: - # Use echo to pipe password to sudo -S (read from stdin) - full_command = f"echo '{sudo_password}' | sudo -S {command}" - else: - # Try sudo without password (for passwordless sudo) - full_command = f"sudo {command}" - else: - full_command = command - - stdin, stdout, stderr = self.client.exec_command(full_command) - - # For sudo commands with password, we need to handle stdin - if use_sudo and sudo_password: - stdin.write(f"{sudo_password}\n") - stdin.flush() - - output = stdout.read().decode() - error = stderr.read().decode() - - # Filter out sudo password prompt from error output - if use_sudo and error: - error_lines = error.split('\n') - filtered_error = '\n'.join( - line for line in error_lines - if not line.startswith('[sudo]') and line.strip() - ) - error = filtered_error - - if error: - outputs.append(f"$ {command}\n{output}{error}") - else: - outputs.append(f"$ {command}\n{output}") + output = self.execute(cmd) + outputs.append(f"$ {cmd}\n{output}") except Exception as e: - outputs.append(f"$ {command}\nError: {str(e)}") - + outputs.append(f"$ {cmd}\nError: {str(e)}") + return "\n\n".join(outputs) - - def __del__(self): - """Close SSH connection when object is destroyed.""" - if self.client and self._is_connected: + + def close(self): + """Close the SSH connection.""" + if self.client: self.client.close() + self.client = None logger.info(f"SSH connection closed to {self.username}@{self.host}") - - -def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any: - """Get default SSH process with persistent connection.""" - return SSHProcess(host=host, username=username, **kwargs) + + def __del__(self): + """Ensure connection is closed when object is destroyed.""" + self.close() class SSHTool(BaseTool): - """Tool to run commands on remote servers via SSH with persistent connection.""" - - process: Optional[SSHProcess] = Field(default=None) - """SSH process with persistent connection.""" - - # Connection parameters - host: str = Field(..., description="SSH host address") - username: str = Field(..., description="SSH username") - port: int = Field(default=22, description="SSH port") - password: Optional[str] = Field(default=None, description="SSH password") - key_filename: Optional[str] = Field(default=None, description="Path to SSH key") + """Simple SSH tool that behaves like a normal terminal session.""" - # Tool configuration name: str = "ssh" - description: str = """ - Run shell commands on a remote server via SSH. + description: str = """Execute commands on a remote server via SSH. - This tool maintains a persistent SSH connection and allows executing - commands on the remote server. It supports both regular and privileged - (sudo) command execution. - - Use the 'use_sudo' parameter to run commands with sudo privileges. - If sudo requires a password, provide it via 'sudo_password'. + Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands. + The tool will automatically handle password prompts and maintain sudo session. Examples: - - Regular command: {"commands": "ls -la"} - - Sudo command: {"commands": "apt update", "use_sudo": true} - - Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]} + - {"commands": "ls -la"} + - {"commands": "sudo apt update"} + - {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]} """ args_schema: Type[BaseModel] = SSHInput + # Connection parameters + host: str = Field(..., description="SSH host") + username: str = Field(..., description="SSH username") + port: int = Field(default=22, description="SSH port") + key_filename: Optional[str] = Field(default=None, description="SSH key path") + password: Optional[str] = Field(default=None, description="SSH password") + + # Session management + session: Optional[SSHSession] = Field(default=None, exclude=True) + model_config = { "arbitrary_types_allowed": True } - + def __init__(self, **kwargs): - """Initialize SSH tool and set description.""" + """Initialize SSH tool.""" super().__init__(**kwargs) - - self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}" - # Initialize the SSH process (but don't connect yet) - self.process = SSHProcess( + # Create session but don't connect yet + self.session = SSHSession( host=self.host, username=self.username, port=self.port, - password=self.password, - key_filename=self.key_filename + key_filename=self.key_filename, + password=self.password ) - - def _run( - self, - commands: Union[str, List[str]], - use_sudo: bool = False, - sudo_password: Optional[str] = None, - **kwargs, - ) -> str: - """Run commands on remote server and return output.""" + + def _run(self, commands: Union[str, List[str]], **kwargs) -> str: + """Execute commands on remote server.""" try: - print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201 - print(f"Commands: {commands}") # noqa: T201 - if use_sudo: - print("Running with sudo privileges") # noqa: T201 - - # Safety check for privileged commands - if use_sudo: - user_input = input("Proceed with sudo command execution? (y/n): ").lower() - if user_input == "y": - return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password) - else: - logger.info("User aborted sudo command execution.") - return "Command execution aborted by user." - else: - user_input = input("Proceed with SSH command execution? (y/n): ").lower() - if user_input == "y": - return self.process.run(commands) - else: - logger.info("Invalid input. User aborted SSH command execution.") - return "Command execution aborted by user." + print(f"Executing on {self.username}@{self.host}:{self.port}") + return self.session.run_commands(commands) except Exception as e: - logger.error(f"Error during SSH command execution: {e}") + logger.error(f"SSH execution error: {e}") return f"Error: {str(e)}" \ No newline at end of file diff --git a/simple-react-agent/custom_tools/ssh_tool.py b/simple-react-agent/custom_tools/ssh_tool.py index ce023d2..17cc5b9 100644 --- a/simple-react-agent/custom_tools/ssh_tool.py +++ b/simple-react-agent/custom_tools/ssh_tool.py @@ -1,66 +1,42 @@ +import getpass import logging -import warnings +import time from typing import Any, List, Optional, Type, Union from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field logger = logging.getLogger(__name__) class SSHInput(BaseModel): - """Commands for the SSH tool.""" - + """Input for SSH tool.""" commands: Union[str, List[str]] = Field( ..., - description="List of commands to run on the remote server", + description="Command(s) to run on the remote server" ) - """List of commands to run.""" + + +class SSHSession: + """Manages a persistent SSH session with sudo caching.""" - use_sudo: bool = Field( - default=False, - description="Whether to run commands with sudo privileges" - ) - """Whether to run commands with sudo.""" - - sudo_password: Optional[str] = Field( - default=None, - description="Password for sudo if required (will be used securely)" - ) - """Password for sudo if required.""" - - @model_validator(mode="before") - def _validate_commands(cls, values: dict) -> Any: - """Validate commands.""" - commands = values.get("commands") - if not isinstance(commands, list): - values["commands"] = [commands] - # Warn that the SSH tool has no safeguards - warnings.warn( - "The SSH tool has no safeguards by default. Use at your own risk." - ) - return values - - -class SSHProcess: - """Persistent SSH connection for command execution.""" - def __init__(self, host: str, username: str, port: int = 22, - password: Optional[str] = None, key_filename: Optional[str] = None, - **kwargs): - """Initialize SSH process with connection parameters.""" + key_filename: Optional[str] = None, password: Optional[str] = None): self.host = host self.username = username self.port = port - self.password = password self.key_filename = key_filename + self.password = password self.client = None - self._is_connected = False - + self._sudo_password = None + self._sudo_timestamp = None + self._sudo_timeout = 300 # 5 minutes, like default sudo + def connect(self): - """Establish SSH connection.""" - if self._is_connected: + """Establish SSH connection if not already connected.""" + if self.client: return + try: import paramiko except ImportError as e: @@ -68,7 +44,7 @@ class SSHProcess: "paramiko is required for SSH functionality. " "Install it with `pip install paramiko`" ) from e - + self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -84,151 +60,142 @@ class SSHProcess: connect_kwargs["key_filename"] = self.key_filename self.client.connect(**connect_kwargs) - self._is_connected = True logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}") - - def run(self, commands: Union[str, List[str]], use_sudo: bool = False, - sudo_password: Optional[str] = None) -> str: - """Run commands over SSH and return output.""" - if not self._is_connected: - self.connect() + def _needs_sudo_password(self) -> bool: + """Check if we need to ask for sudo password.""" + if not self._sudo_password: + return True + if not self._sudo_timestamp: + return True + # Check if sudo timeout has expired + if time.time() - self._sudo_timestamp > self._sudo_timeout: + self._sudo_password = None + self._sudo_timestamp = None + return True + return False + + def execute(self, command: str) -> str: + """Execute a single command, handling sudo automatically.""" + if not self.client: + self.connect() + + # Check if command needs sudo + needs_sudo = command.strip().startswith('sudo ') + + if needs_sudo: + # Remove 'sudo ' prefix if present + actual_command = command.strip()[5:].strip() + + # Check if we need to get sudo password + if self._needs_sudo_password(): + self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ") + self._sudo_timestamp = time.time() + + # Execute with sudo -S and pass password via stdin + full_command = f"sudo -S {actual_command}" + stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True) + + # Send password + stdin.write(f"{self._sudo_password}\n") + stdin.flush() + + # Get output + output = stdout.read().decode() + error = stderr.read().decode() + + # Clean up sudo prompt from output + lines = output.split('\n') + cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')] + output = '\n'.join(cleaned_lines) + + else: + # Regular command without sudo + stdin, stdout, stderr = self.client.exec_command(command) + output = stdout.read().decode() + error = stderr.read().decode() + + # Combine output and error + result = output + if error: + result += f"\n{error}" if result else error + + return result.strip() + + def run_commands(self, commands: Union[str, List[str]]) -> str: + """Run one or more commands and return combined output.""" if isinstance(commands, str): commands = [commands] - + outputs = [] - for command in commands: + for cmd in commands: try: - # Prepare command with sudo if needed - if use_sudo: - if sudo_password: - # Use echo to pipe password to sudo -S (read from stdin) - full_command = f"echo '{sudo_password}' | sudo -S {command}" - else: - # Try sudo without password (for passwordless sudo) - full_command = f"sudo {command}" - else: - full_command = command - - stdin, stdout, stderr = self.client.exec_command(full_command) - - # For sudo commands with password, we need to handle stdin - if use_sudo and sudo_password: - stdin.write(f"{sudo_password}\n") - stdin.flush() - - output = stdout.read().decode() - error = stderr.read().decode() - - # Filter out sudo password prompt from error output - if use_sudo and error: - error_lines = error.split('\n') - filtered_error = '\n'.join( - line for line in error_lines - if not line.startswith('[sudo]') and line.strip() - ) - error = filtered_error - - if error: - outputs.append(f"$ {command}\n{output}{error}") - else: - outputs.append(f"$ {command}\n{output}") + output = self.execute(cmd) + outputs.append(f"$ {cmd}\n{output}") except Exception as e: - outputs.append(f"$ {command}\nError: {str(e)}") - + outputs.append(f"$ {cmd}\nError: {str(e)}") + return "\n\n".join(outputs) - - def __del__(self): - """Close SSH connection when object is destroyed.""" - if self.client and self._is_connected: + + def close(self): + """Close the SSH connection.""" + if self.client: self.client.close() + self.client = None logger.info(f"SSH connection closed to {self.username}@{self.host}") - - -def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any: - """Get default SSH process with persistent connection.""" - return SSHProcess(host=host, username=username, **kwargs) + + def __del__(self): + """Ensure connection is closed when object is destroyed.""" + self.close() class SSHTool(BaseTool): - """Tool to run commands on remote servers via SSH with persistent connection.""" - - process: Optional[SSHProcess] = Field(default=None) - """SSH process with persistent connection.""" - - # Connection parameters - host: str = Field(..., description="SSH host address") - username: str = Field(..., description="SSH username") - port: int = Field(default=22, description="SSH port") - password: Optional[str] = Field(default=None, description="SSH password") - key_filename: Optional[str] = Field(default=None, description="Path to SSH key") + """Simple SSH tool that behaves like a normal terminal session.""" - # Tool configuration name: str = "ssh" - description: str = """ - Run shell commands on a remote server via SSH. + description: str = """Execute commands on a remote server via SSH. - This tool maintains a persistent SSH connection and allows executing - commands on the remote server. It supports both regular and privileged - (sudo) command execution. - - Use the 'use_sudo' parameter to run commands with sudo privileges. - If sudo requires a password, provide it via 'sudo_password'. + Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands. + The tool will automatically handle password prompts and maintain sudo session. Examples: - - Regular command: {"commands": "ls -la"} - - Sudo command: {"commands": "apt update", "use_sudo": true} - - Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]} + - {"commands": "ls -la"} + - {"commands": "sudo apt update"} + - {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]} """ args_schema: Type[BaseModel] = SSHInput + # Connection parameters + host: str = Field(..., description="SSH host") + username: str = Field(..., description="SSH username") + port: int = Field(default=22, description="SSH port") + key_filename: Optional[str] = Field(default=None, description="SSH key path") + password: Optional[str] = Field(default=None, description="SSH password") + + # Session management + session: Optional[SSHSession] = Field(default=None, exclude=True) + model_config = { "arbitrary_types_allowed": True } - + def __init__(self, **kwargs): - """Initialize SSH tool and set description.""" + """Initialize SSH tool.""" super().__init__(**kwargs) - - self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}" - # Initialize the SSH process (but don't connect yet) - self.process = SSHProcess( + # Create session but don't connect yet + self.session = SSHSession( host=self.host, username=self.username, port=self.port, - password=self.password, - key_filename=self.key_filename + key_filename=self.key_filename, + password=self.password ) - - def _run( - self, - commands: Union[str, List[str]], - use_sudo: bool = False, - sudo_password: Optional[str] = None, - **kwargs, - ) -> str: - """Run commands on remote server and return output.""" + + def _run(self, commands: Union[str, List[str]], **kwargs) -> str: + """Execute commands on remote server.""" try: - print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201 - print(f"Commands: {commands}") # noqa: T201 - if use_sudo: - print("Running with sudo privileges") # noqa: T201 - - # Safety check for privileged commands - if use_sudo: - user_input = input("Proceed with sudo command execution? (y/n): ").lower() - if user_input == "y": - return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password) - else: - logger.info("User aborted sudo command execution.") - return "Command execution aborted by user." - else: - user_input = input("Proceed with SSH command execution? (y/n): ").lower() - if user_input == "y": - return self.process.run(commands) - else: - logger.info("Invalid input. User aborted SSH command execution.") - return "Command execution aborted by user." + print(f"Executing on {self.username}@{self.host}:{self.port}") + return self.session.run_commands(commands) except Exception as e: - logger.error(f"Error during SSH command execution: {e}") + logger.error(f"SSH execution error: {e}") return f"Error: {str(e)}" \ No newline at end of file