import logging
import warnings
from typing import Any, List, Optional, Type, Union

from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)


class SSHInput(BaseModel):
    """Commands for the SSH tool."""

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of commands to run on the remote server",
    )
    """List of commands to run."""

    use_sudo: bool = Field(
        default=False, description="Whether to run commands with sudo privileges"
    )
    """Whether to run commands with sudo."""

    sudo_password: Optional[str] = Field(
        default=None,
        description="Password for sudo if required (will be used securely)",
    )
    """Password for sudo if required."""

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: Any) -> Any:
        """Normalize a single command string to a one-element list.

        Runs before field validation. Only a ``str`` value is wrapped:

        * a missing ``commands`` key is left alone so pydantic reports
          "Field required" rather than an error about ``[None]``;
        * other iterables (e.g. a tuple of strings) are left for pydantic to
          coerce to ``List[str]`` instead of being wrapped into a list;
        * non-``dict`` input (e.g. an existing model instance) is passed
          through unchanged instead of failing on ``.get``.

        A new dict is returned so the caller's input mapping is not mutated.
        A warning is emitted on every validation because the tool executes
        arbitrary commands on the remote host.
        """
        if isinstance(values, dict) and isinstance(values.get("commands"), str):
            values = {**values, "commands": [values["commands"]]}
        # Warn that the SSH tool has no safeguards
        warnings.warn(
            "The SSH tool has no safeguards by default. Use at your own risk."
        )
        return values


class SSHProcess:
    """Persistent SSH connection for command execution.

    The connection is opened lazily by ``connect()`` (called automatically by
    ``run()``), re-opened by ``run()`` if the transport has gone away, and
    released by ``close()`` or when the object is garbage-collected.
    """

    def __init__(
        self,
        host: str,
        username: str,
        port: int = 22,
        password: Optional[str] = None,
        key_filename: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store connection parameters; no network activity happens here.

        Args:
            host: Remote host name or address.
            username: Login user name.
            port: SSH port.
            password: Login password; only sent when truthy.
            key_filename: Path to a private key file; only sent when truthy.
            **kwargs: Accepted for call-site compatibility and currently
                ignored.
        """
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        self.key_filename = key_filename
        self.client: Any = None
        self._is_connected = False

    def connect(self) -> None:
        """Establish the SSH connection if it is not already open.

        Raises:
            ImportError: If ``paramiko`` is not installed.
            Exception: Whatever ``paramiko.SSHClient.connect`` raises; the
                half-open client is closed before the error propagates.
        """
        if self._is_connected:
            return
        try:
            import paramiko
        except ImportError as e:
            raise ImportError(
                "paramiko is required for SSH functionality. "
                "Install it with `pip install paramiko`"
            ) from e

        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification, which offers no protection against MITM attacks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        connect_kwargs = {
            "hostname": self.host,
            "port": self.port,
            "username": self.username,
        }
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_filename:
            connect_kwargs["key_filename"] = self.key_filename

        try:
            client.connect(**connect_kwargs)
        except Exception:
            # Don't leak the socket/transport of a failed connection attempt.
            client.close()
            raise

        self.client = client
        self._is_connected = True
        logger.info(
            "SSH connection established to %s@%s:%s",
            self.username,
            self.host,
            self.port,
        )

    def close(self) -> None:
        """Close the SSH connection if a client exists; safe to call twice."""
        client, self.client = self.client, None
        was_connected, self._is_connected = self._is_connected, False
        if client is not None:
            client.close()
            if was_connected:
                logger.info(
                    "SSH connection closed to %s@%s", self.username, self.host
                )

    def _connection_alive(self) -> bool:
        """Return True if a connected client with an active transport exists."""
        if not self._is_connected or self.client is None:
            return False
        transport = self.client.get_transport()
        return transport is not None and transport.is_active()

    @staticmethod
    def _build_command(
        command: str, use_sudo: bool, sudo_password: Optional[str]
    ) -> str:
        """Return the remote command line, prefixed with sudo when requested.

        NOTE: the remote shell parses the result, so only the first simple
        command of a compound command (e.g. ``a && b``) runs under sudo.
        """
        if not use_sudo:
            return command
        if sudo_password:
            # -S reads the password from stdin (written by the caller) instead
            # of embedding it in the command line, where quoting could break
            # it and the remote process list would expose it; -p '' suppresses
            # the prompt text.
            return f"sudo -S -p '' {command}"
        # Passwordless sudo. exec_command is used without a PTY, so sudo
        # cannot prompt interactively if a password turns out to be needed.
        return f"sudo {command}"

    def _run_one(
        self, command: str, use_sudo: bool, sudo_password: Optional[str]
    ) -> str:
        """Execute a single command and format its transcript section."""
        try:
            full_command = self._build_command(command, use_sudo, sudo_password)
            stdin, stdout, stderr = self.client.exec_command(full_command)
            if use_sudo and sudo_password:
                # Answer ``sudo -S`` over the channel's stdin.
                stdin.write(f"{sudo_password}\n")
                stdin.flush()
            # errors="replace": non-UTF-8 bytes must not discard the output.
            output = stdout.read().decode(errors="replace")
            error = stderr.read().decode(errors="replace")
        except Exception as e:
            return f"$ {command}\nError: {e}"

        if use_sudo and error:
            # Filter out sudo prompt lines and blank lines from stderr.
            error = "\n".join(
                line
                for line in error.split("\n")
                if line.strip() and not line.startswith("[sudo]")
            )

        if not error:
            return f"$ {command}\n{output}"
        # Keep stderr on its own line when stdout lacks a trailing newline.
        separator = "" if not output or output.endswith("\n") else "\n"
        return f"$ {command}\n{output}{separator}{error}"

    def run(
        self,
        commands: Union[str, List[str]],
        use_sudo: bool = False,
        sudo_password: Optional[str] = None,
    ) -> str:
        """Run commands over SSH and return their combined output.

        Args:
            commands: One command string or a list of them, run in order.
            use_sudo: Prefix each command with ``sudo``.
            sudo_password: Password sent to ``sudo -S`` over stdin; when falsy,
                plain ``sudo`` is used.

        Returns:
            One ``"$ <command>\\n<output>"`` section per command, joined by a
            blank line. A failure while running a command is reported inline
            as ``"Error: ..."`` and does not stop the remaining commands.

        Raises:
            ImportError: If ``paramiko`` is not installed.
            Exception: Whatever ``connect()`` raises when the connection
                cannot be (re-)established.
        """
        if not self._connection_alive():
            # Drop a stale client (e.g. the server closed the session) so a
            # long-lived process transparently reconnects.
            self.close()
            self.connect()
        if isinstance(commands, str):
            commands = [commands]
        return "\n\n".join(
            self._run_one(command, use_sudo, sudo_password) for command in commands
        )

    def __del__(self) -> None:
        """Close the SSH connection when the object is destroyed."""
        try:
            self.close()
        except Exception:
            # Finalizers must never raise; during interpreter shutdown module
            # globals (e.g. ``logger``) may already be torn down.
            pass


def _get_default_ssh_process(host: str, username: str, **kwargs: Any) -> Any:
    """Get a default (not yet connected) SSH process for ``username@host``."""
    return SSHProcess(host=host, username=username, **kwargs)


class SSHTool(BaseTool):
    """Tool to run commands on remote servers via SSH with persistent connection."""

    process: Optional[SSHProcess] = Field(default=None)
    """SSH process with persistent connection."""

    # Connection parameters
    host: str = Field(..., description="SSH host address")
    username: str = Field(..., description="SSH username")
    port: int = Field(default=22, description="SSH port")
    # repr=False keeps the secret out of the pydantic-generated repr.
    password: Optional[str] = Field(
        default=None, description="SSH password", repr=False
    )
    key_filename: Optional[str] = Field(default=None, description="Path to SSH key")

    # Tool configuration
    name: str = "ssh"
    description: str = """
Run shell commands on a remote server via SSH.
This tool maintains a persistent SSH connection and allows executing commands on the remote server.
It supports both regular and privileged (sudo) command execution.
Use the 'use_sudo' parameter to run commands with sudo privileges.
If sudo requires a password, provide it via 'sudo_password'.

Examples:
- Regular command: {"commands": "ls -la"}
- Sudo command: {"commands": "apt update", "use_sudo": true}
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]}
"""
    args_schema: Type[BaseModel] = SSHInput

    model_config = {"arbitrary_types_allowed": True}

    def __init__(self, **kwargs: Any) -> None:
        """Validate fields, set a host-specific description and create the process.

        A caller-supplied ``description`` or ``process`` is kept as given.
        Otherwise the description becomes the target ``user@host:port`` line
        followed by the class's usage notes, and a new (not yet connected)
        ``SSHProcess`` is built from the connection fields.
        """
        super().__init__(**kwargs)
        if "description" not in kwargs:
            usage = type(self).model_fields["description"].default or ""
            self.description = (
                f"Run commands on remote server "
                f"{self.username}@{self.host}:{self.port}\n\n{usage.strip()}"
            )
        if self.process is None:
            # SSHProcess connects lazily on its first run() call.
            self.process = SSHProcess(
                host=self.host,
                username=self.username,
                port=self.port,
                password=self.password,
                key_filename=self.key_filename,
            )

    def _run(
        self,
        commands: Union[str, List[str]],
        use_sudo: bool = False,
        sudo_password: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Ask for interactive confirmation, then run commands on the server.

        Returns:
            The transcript from ``SSHProcess.run``, the abort message when the
            user does not answer ``y``/``yes``, or ``"Error: ..."`` if anything
            raises (errors are logged, not propagated).
        """
        try:
            print(f"Executing SSH command on {self.username}@{self.host}:{self.port}")  # noqa: T201
            print(f"Commands: {commands}")  # noqa: T201
            if use_sudo:
                print("Running with sudo privileges")  # noqa: T201

            # Safety check: every execution requires explicit confirmation.
            kind = "sudo" if use_sudo else "SSH"
            answer = input(f"Proceed with {kind} command execution? (y/n): ")
            if answer.strip().lower() not in ("y", "yes"):
                logger.info("User aborted %s command execution.", kind)
                return "Command execution aborted by user."

            if use_sudo:
                return self.process.run(
                    commands, use_sudo=True, sudo_password=sudo_password
                )
            return self.process.run(commands)
        except Exception as e:
            logger.exception("Error during SSH command execution: %s", e)
            return f"Error: {e}"