2025-06-29 14:49:07 +02:00

186 lines
6.4 KiB
Python

import logging
import warnings
from typing import Any, List, Optional, Type, Union
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator
# Module-level logger for this file; connection lifecycle events and tool
# execution errors are reported through it.
logger: logging.Logger = logging.getLogger(__name__)
class SSHInput(BaseModel):
    """Commands for the SSH tool."""

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of commands to run on the remote server",
    )
    """List of commands to run."""

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: Any) -> Any:
        """Wrap a single command string in a list and warn about missing safeguards.

        Runs on the raw input before field validation. A lone string command is
        turned into a one-element list. Lists, tuples, missing keys, and invalid
        values are left for pydantic's own validation, so a missing ``commands``
        is reported as "field required" rather than as a confusing ``[None]``.

        Args:
            values: Raw model input. Only dicts are inspected. Anything else
                (e.g. an existing model instance) is passed through unchanged.

        Returns:
            A new dict with ``commands`` normalised when a rewrite was needed,
            otherwise ``values`` itself. The caller's dict is never mutated.
        """
        if isinstance(values, dict):
            commands = values.get("commands")
            if isinstance(commands, str):
                # Build a copy instead of assigning into ``values``. In "before"
                # mode the raw input can be the caller's own dict (e.g. via
                # ``model_validate``), and mutating it would leak the change.
                values = {**values, "commands": [commands]}
        # Warn that the SSH tool has no safeguards
        warnings.warn(
            "The SSH tool has no safeguards by default. Use at your own risk."
        )
        return values
class SSHProcess:
    """Persistent SSH connection for command execution.

    No connection is opened at construction time. The first call to
    :meth:`run` opens it, and later calls reuse it. If the connection is not
    usable (never opened, a previous attempt failed, or the transport was
    dropped), it is (re)opened before the next batch of commands runs.
    """

    def __init__(self, host: str, username: str, port: int = 22,
                 password: Optional[str] = None, key_filename: Optional[str] = None,
                 timeout: float = 30.0, return_err_output: bool = True):
        """Store connection parameters. No network I/O happens here.

        Args:
            host: Remote host name or address.
            username: Login user on the remote host.
            port: SSH port.
            password: Password, passed to paramiko only when truthy.
            key_filename: Private key path, passed to paramiko only when truthy.
            timeout: Connection timeout in seconds, passed to ``SSHClient.connect``.
            return_err_output: When True, a command's stderr is appended to
                its stdout in the transcript returned by :meth:`run`.
        """
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        self.key_filename = key_filename
        self.timeout = timeout
        self.return_err_output = return_err_output
        self.client = None
        # Don't connect immediately - connect when needed

    def _is_connected(self) -> bool:
        """Return True when a client exists and its transport is still active."""
        if self.client is None:
            return False
        transport = self.client.get_transport()
        return transport is not None and transport.is_active()

    def _connect(self):
        """Open a fresh SSH connection and store it in ``self.client``.

        Any previous client is closed first. ``self.client`` is assigned only
        after ``connect`` succeeds. A failed attempt therefore leaves
        ``self.client`` as ``None``, so the next :meth:`run` retries instead of
        reusing a dead client.

        Raises:
            ImportError: If paramiko is not installed.
            Exception: Whatever ``paramiko.SSHClient.connect`` raises
                (authentication, socket, or host-key errors).
        """
        try:
            import paramiko
        except ImportError as err:
            raise ImportError(
                "paramiko is required for SSH functionality. "
                "Install it with `pip install paramiko`"
            ) from err

        self.close()

        client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts unknown host keys without any
        # verification, which allows man-in-the-middle attacks on first connect.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        connect_kwargs = {
            "hostname": self.host,
            "username": self.username,
            "port": self.port,
            "timeout": self.timeout,
        }
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_filename:
            connect_kwargs["key_filename"] = self.key_filename
        try:
            client.connect(**connect_kwargs)
        except Exception:
            # Release any partially opened socket before propagating.
            client.close()
            raise
        self.client = client
        logger.info(
            "SSH connection established to %s@%s:%s",
            self.username, self.host, self.port,
        )

    def run(self, commands: Union[str, List[str]]) -> str:
        """Run commands over SSH and return the combined transcript.

        Args:
            commands: A single command string or a list of command strings,
                executed one after another on the same connection.

        Returns:
            One ``"$ <command>\\n<output>"`` section per command, joined with
            newlines. A command that fails is reported inline as
            ``"$ <command>\\nError: <message>"`` instead of being raised.

        Raises:
            ImportError: If a connection is needed and paramiko is missing.
            Exception: Connection errors from ``paramiko.SSHClient.connect``.
        """
        if not self._is_connected():
            self._connect()
        if isinstance(commands, str):
            commands = [commands]
        outputs = []
        for command in commands:
            try:
                _stdin, stdout, stderr = self.client.exec_command(command)
                # errors="replace": a stray non-UTF-8 byte should not discard
                # the whole output of the command.
                output = stdout.read().decode("utf-8", errors="replace")
                error = stderr.read().decode("utf-8", errors="replace")
            except Exception as e:
                outputs.append(f"$ {command}\nError: {str(e)}")
                continue
            if error and self.return_err_output:
                outputs.append(f"$ {command}\n{output}{error}")
            else:
                outputs.append(f"$ {command}\n{output}")
        return "\n".join(outputs)

    def close(self) -> None:
        """Close the SSH connection if one is open. Safe to call repeatedly."""
        # getattr: __del__ may run on an object whose __init__ never finished.
        client = getattr(self, "client", None)
        if client is None:
            return
        self.client = None
        client.close()
        logger.info("SSH connection closed to %s@%s", self.username, self.host)

    def __del__(self):
        """Close SSH connection when object is destroyed (best effort)."""
        try:
            self.close()
        except Exception:
            # Exceptions raised in __del__ cannot reach any caller (Python only
            # prints them). During interpreter shutdown, module globals such as
            # ``logger`` may already be gone, so cleanup here stays best-effort.
            pass
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
    """Build an ``SSHProcess`` for ``username@host``.

    Extra keyword arguments (port, password, key_filename, timeout,
    return_err_output) are forwarded unchanged. No connection is opened here;
    the process connects on its first ``run``.
    """
    process = SSHProcess(host=host, username=username, **kwargs)
    return process
class SSHTool(BaseTool):
    """Tool to run commands on remote servers via SSH with persistent connection.

    Each tool instance owns one ``SSHProcess``, so repeated tool calls reuse
    the same connection parameters and the same SSH client.
    """

    process: Optional[SSHProcess] = Field(default=None)
    """SSH process with persistent connection."""

    # Connection parameters
    host: str = Field(..., description="SSH host address")
    username: str = Field(..., description="SSH username")
    port: int = Field(default=22, description="SSH port")
    # repr=False keeps the secret out of repr()/str() of the tool, which
    # otherwise tends to end up in logs and tracebacks.
    password: Optional[str] = Field(
        default=None, description="SSH password", repr=False
    )
    key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
    timeout: float = Field(default=30.0, description="Connection timeout")

    name: str = "ssh"
    """Name of tool."""

    # Left empty by default; __init__ then fills in one naming user@host:port.
    description: str = Field(default="")
    """Description of tool."""

    args_schema: Type[BaseModel] = SSHInput
    """Schema for input arguments."""

    ask_human_input: bool = False
    """
    If True, prompts the user for confirmation (y/n) before executing
    commands on the remote server.
    """

    def __init__(self, **data):
        """Initialize the tool and fill in a default description and process.

        A ``description`` or ``process`` passed by the caller is kept as-is.
        Only missing ones are filled in from the connection parameters.
        """
        super().__init__(**data)
        if not self.description:
            self.description = (
                f"Run commands on remote server {self.username}@{self.host}:{self.port}"
            )
        # Initialize the SSH process (but don't connect yet)
        if self.process is None:
            self.process = SSHProcess(
                host=self.host,
                username=self.username,
                port=self.port,
                password=self.password,
                key_filename=self.key_filename,
                timeout=self.timeout,
                return_err_output=True,
            )

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands on the remote server and return their transcript.

        When ``ask_human_input`` is set, the user must answer "y" or "yes"
        (case-insensitive, surrounding whitespace ignored) before anything runs.

        Returns:
            The transcript from ``SSHProcess.run``. If the user declines or an
            error occurs, a short explanatory message is returned instead, so
            the caller always gets a string, as the signature promises.
        """
        print(f"Executing SSH command on {self.username}@{self.host}:{self.port}")  # noqa: T201
        print(f"Commands: {commands}")  # noqa: T201
        try:
            if self.ask_human_input:
                user_input = input("Proceed with SSH command execution? (y/n): ")
                if user_input.strip().lower() not in ("y", "yes"):
                    logger.info("User aborted SSH command execution.")
                    return "SSH command execution aborted by user."
            return self.process.run(commands)
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis. The returned
            # message gives the agent/caller a reason instead of "None".
            logger.exception("Error during SSH command execution: %s", e)
            return f"Error during SSH command execution: {e}"