import logging
import warnings
from typing import Any, List, Optional, Type, Union

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)


class SSHInput(BaseModel):
    """Commands for the SSH tool."""

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of commands to run on the remote server",
    )
    """List of commands to run."""

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: Any) -> Any:
        """Normalize ``commands`` to a list and warn about missing safeguards.

        Args:
            values: Raw input handed to the model before field validation.

        Returns:
            The input, with a non-list ``commands`` entry wrapped in a
            single-element list when the input is a dict containing that key.
        """
        # A "before" validator receives the raw input, which is not
        # guaranteed to be a dict; only normalize when it is one.  A missing
        # key is left missing so pydantic reports "field required" instead of
        # a confusing type error on ``[None]``.
        if isinstance(values, dict) and "commands" in values:
            commands = values["commands"]
            if not isinstance(commands, list):
                values["commands"] = [commands]

        # Warn that the SSH tool has no safeguards
        warnings.warn(
            "The SSH tool has no safeguards by default. Use at your own risk."
        )
        return values


class SSHProcess:
    """Persistent SSH connection for command execution.

    The connection is opened lazily on the first call to :meth:`run` and is
    reused for subsequent calls until :meth:`close` is called or the object is
    garbage-collected.
    """

    def __init__(
        self,
        host: str,
        username: str,
        port: int = 22,
        password: Optional[str] = None,
        key_filename: Optional[str] = None,
        timeout: float = 30.0,
        return_err_output: bool = True,
    ):
        """Initialize SSH process with connection parameters.

        Args:
            host: Hostname or address of the remote server.
            username: Account name used to log in.
            port: TCP port of the SSH server.
            password: Optional password passed to the SSH client.
            key_filename: Optional path to a private key file.
            timeout: Timeout in seconds passed to the SSH client's connect call.
            return_err_output: Whether stderr is appended to the returned
                output of each command.
        """
        # Set first so close()/__del__ are safe even if a later line fails.
        self.client: Optional[Any] = None
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        self.key_filename = key_filename
        self.timeout = timeout
        self.return_err_output = return_err_output
        # Don't connect immediately - connect when needed

    def _connect(self) -> None:
        """Establish the SSH connection and store the client on success.

        Raises:
            ImportError: If paramiko is not installed.
            Exception: Whatever the paramiko client raises when connecting.
        """
        try:
            import paramiko
        except ImportError as err:
            raise ImportError(
                "paramiko is required for SSH functionality. "
                "Install it with `pip install paramiko`"
            ) from err

        client = paramiko.SSHClient()
        # SECURITY NOTE(review): AutoAddPolicy accepts any unknown host key,
        # so the server's identity is not verified (man-in-the-middle risk).
        # Kept for backward compatibility; consider loading known hosts and
        # using a stricter policy.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        connect_kwargs = {
            "hostname": self.host,
            "username": self.username,
            "port": self.port,
            "timeout": self.timeout,
        }
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_filename:
            connect_kwargs["key_filename"] = self.key_filename

        try:
            client.connect(**connect_kwargs)
        except Exception:
            # Do not leave a half-open client behind, and do not publish it
            # on self: run() must attempt a fresh connection next time.
            client.close()
            raise

        self.client = client
        logger.info(
            "SSH connection established to %s@%s:%s",
            self.username,
            self.host,
            self.port,
        )

    def run(self, commands: Union[str, List[str]]) -> str:
        """Run commands over SSH and return their combined output.

        Args:
            commands: A single command string or a list of command strings.

        Returns:
            One ``$ <command>`` section per command, joined by newlines. A
            command that raises is reported as ``Error: <message>`` in its
            section and does not stop the remaining commands.

        Raises:
            Exception: If the connection cannot be established.
        """
        if not self.client:
            self._connect()

        if isinstance(commands, str):
            commands = [commands]

        outputs = []
        for command in commands:
            try:
                _stdin, stdout, stderr = self.client.exec_command(command)
                # NOTE(review): stdout is drained before stderr; a command
                # that writes a very large amount to stderr first may block
                # here.  Left unchanged to preserve behaviour.
                # errors="replace" keeps non-UTF-8 output readable instead of
                # turning the whole command into a decode error.
                output = stdout.read().decode("utf-8", errors="replace")
                error = stderr.read().decode("utf-8", errors="replace")

                if error and self.return_err_output:
                    outputs.append(f"$ {command}\n{output}{error}")
                else:
                    outputs.append(f"$ {command}\n{output}")
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                outputs.append(f"$ {command}\nError: {str(e)}")

        return "\n".join(outputs)

    def close(self) -> None:
        """Close the SSH connection if one is open; safe to call repeatedly."""
        client = getattr(self, "client", None)
        if not client:
            return
        self.client = None
        client.close()
        logger.info("SSH connection closed to %s@%s", self.username, self.host)

    def __del__(self):
        """Close SSH connection when object is destroyed."""
        try:
            self.close()
        except Exception:
            # Finalizers must never raise (e.g. during interpreter shutdown
            # or when __init__ did not complete).
            pass


def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
    """Get default SSH process with persistent connection.

    Args:
        host: Hostname or address of the remote server.
        username: Account name used to log in.
        **kwargs: Remaining ``SSHProcess`` constructor arguments.

    Returns:
        A new, not yet connected, ``SSHProcess``.
    """
    return SSHProcess(host=host, username=username, **kwargs)


class SSHTool(BaseTool):
    """Tool to run commands on remote servers via SSH with persistent connection."""

    process: Optional[SSHProcess] = Field(default=None)
    """SSH process with persistent connection."""

    # Connection parameters
    host: str = Field(..., description="SSH host address")
    username: str = Field(..., description="SSH username")
    port: int = Field(default=22, description="SSH port")
    # repr=False keeps the secret out of the model's repr (and thus logs).
    password: Optional[str] = Field(
        default=None, description="SSH password", repr=False
    )
    key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
    timeout: float = Field(default=30.0, description="Connection timeout")

    name: str = "ssh"
    """Name of tool."""

    description: str = Field(default="")
    """Description of tool."""

    args_schema: Type[BaseModel] = SSHInput
    """Schema for input arguments."""

    ask_human_input: bool = False
    """
    If True, prompts the user for confirmation (y/n) before executing
    commands on the remote server.
    """

    def __init__(self, **data: Any) -> None:
        """Initialize SSH tool, filling in the description and SSH process.

        Args:
            **data: Field values for the tool (``host`` and ``username`` are
                required).
        """
        super().__init__(**data)
        # Only derive a description when the caller did not supply one.
        if not self.description:
            self.description = (
                f"Run commands on remote server "
                f"{self.username}@{self.host}:{self.port}"
            )
        # Respect an injected process; otherwise build one (without
        # connecting yet) from the connection fields.
        if self.process is None:
            self.process = _get_default_ssh_process(
                host=self.host,
                username=self.username,
                port=self.port,
                password=self.password,
                key_filename=self.key_filename,
                timeout=self.timeout,
                return_err_output=True,
            )

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands on remote server and return output.

        Args:
            commands: A single command string or a list of command strings.
            run_manager: Optional callback manager; not used here.

        Returns:
            The command output, or a short message when the user declines
            the confirmation prompt or when execution fails.
        """
        print(f"Executing SSH command on {self.username}@{self.host}:{self.port}")  # noqa: T201
        print(f"Commands: {commands}")  # noqa: T201

        try:
            if self.ask_human_input:
                user_input = (
                    input("Proceed with SSH command execution? (y/n): ")
                    .strip()
                    .lower()
                )
                if user_input != "y":
                    logger.info("Invalid input. User aborted SSH command execution.")
                    # Return text (not None) so the declared str contract
                    # holds and the caller sees why nothing ran.
                    return "SSH command execution aborted by user."
            return self.process.run(commands)
        except Exception as e:
            # Best effort: report the failure to the caller instead of raising.
            logger.exception("Error during SSH command execution: %s", e)
            return f"Error during SSH command execution: {e}"