2025-06-29 15:15:27 +02:00

234 lines
8.5 KiB
Python

import logging
import warnings
from typing import Any, List, Optional, Type, Union
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator
# Module-level logger used for the connection and execution events in this file.
logger = logging.getLogger(__name__)
class SSHInput(BaseModel):
    """Commands for the SSH tool."""

    # Command(s) to run. A bare string is wrapped into a one-element list by
    # the before-validator below.
    commands: Union[str, List[str]] = Field(
        ...,
        description="List of commands to run on the remote server",
    )
    # Whether to run commands with sudo.
    use_sudo: bool = Field(
        default=False,
        description="Whether to run commands with sudo privileges"
    )
    # Password for sudo if required.
    sudo_password: Optional[str] = Field(
        default=None,
        description="Password for sudo if required (will be used securely)"
    )

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: Any) -> Any:
        """Normalize ``commands`` and warn that the tool has no safeguards.

        Runs on the raw input before field validation. When the input is a
        dict and ``commands`` is a single string, it is wrapped in a list.
        Anything else is passed through unchanged:

        - a missing ``commands`` still yields pydantic's "field required" error
          instead of a confusing error about a ``[None]`` list;
        - list/tuple values are left for normal field validation.

        The caller's dict is copied rather than mutated in place.

        Returns:
            The (possibly normalized) raw input for pydantic to validate.
        """
        if isinstance(values, dict):
            commands = values.get("commands")
            if isinstance(commands, str):
                values = {**values, "commands": [commands]}
        # Warn that the SSH tool has no safeguards
        warnings.warn(
            "The SSH tool has no safeguards by default. Use at your own risk."
        )
        return values
class SSHProcess:
    """Persistent SSH connection for command execution.

    The connection is opened lazily by :meth:`connect` (invoked from
    :meth:`run`) using paramiko. It is re-opened if the underlying transport
    is no longer active.
    """

    def __init__(self, host: str, username: str, port: int = 22,
                 password: Optional[str] = None, key_filename: Optional[str] = None,
                 **kwargs):
        """Store connection parameters; no network activity happens here.

        Args:
            host: Remote host name or address.
            username: Login user name.
            port: SSH port.
            password: Login password; forwarded to paramiko only if truthy.
            key_filename: Private key path; forwarded to paramiko only if truthy.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        self.key_filename = key_filename
        self.client = None
        self._is_connected = False

    def _is_alive(self) -> bool:
        """Return True if a connected client exists and its transport is active."""
        if not self._is_connected or self.client is None:
            return False
        transport = self.client.get_transport()
        return transport is not None and transport.is_active()

    def connect(self):
        """Establish the SSH connection unless a live one already exists.

        Raises:
            ImportError: If paramiko is not installed.
            Exception: Whatever ``paramiko.SSHClient.connect`` raises
                (authentication, socket, host-key errors). The half-open
                client is closed before the exception propagates.
        """
        if self._is_alive():
            return
        # Discard a stale client (e.g. the transport dropped) before reconnecting.
        self.close()
        try:
            import paramiko
        except ImportError as e:
            raise ImportError(
                "paramiko is required for SSH functionality. "
                "Install it with `pip install paramiko`"
            ) from e
        client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts any unknown host key, so the
        # first connection to a host is not protected against MITM attacks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        connect_kwargs = {
            "hostname": self.host,
            "port": self.port,
            "username": self.username,
        }
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_filename:
            connect_kwargs["key_filename"] = self.key_filename
        try:
            client.connect(**connect_kwargs)
        except Exception:
            # Don't leak the socket/transport of a failed attempt.
            client.close()
            raise
        self.client = client
        self._is_connected = True
        logger.info(
            "SSH connection established to %s@%s:%s",
            self.username, self.host, self.port,
        )

    def close(self) -> None:
        """Close the SSH connection if one exists; safe to call repeatedly."""
        client, self.client = self.client, None
        was_connected, self._is_connected = self._is_connected, False
        if client is not None:
            client.close()
            if was_connected:
                logger.info("SSH connection closed to %s@%s", self.username, self.host)

    def run(self, commands: Union[str, List[str]], use_sudo: bool = False,
            sudo_password: Optional[str] = None) -> str:
        """Run commands over SSH and return their combined transcript.

        Args:
            commands: A single command string or a list of them, run in order.
            use_sudo: Prefix each command with ``sudo``.
            sudo_password: When given together with ``use_sudo``, sent to
                ``sudo -S`` over the channel's stdin. It is never embedded in
                the command string.

        Returns:
            One ``"$ <command>\\n<stdout><stderr>"`` section per command,
            joined by blank lines. A failure while running an individual
            command is reported in its section as ``Error: ...`` instead of
            being raised.

        Raises:
            Whatever :meth:`connect` raises if no connection can be opened.
        """
        self.connect()
        if isinstance(commands, str):
            commands = [commands]
        outputs = []
        for command in commands:
            try:
                outputs.append(self._run_one(command, use_sudo, sudo_password))
            except Exception as e:
                outputs.append(f"$ {command}\nError: {str(e)}")
        return "\n\n".join(outputs)

    def _run_one(self, command: str, use_sudo: bool,
                 sudo_password: Optional[str]) -> str:
        """Execute a single command and format its ``$ command`` transcript."""
        send_password = bool(use_sudo and sudo_password)
        if send_password:
            # -S: read the password from stdin. -p '': empty prompt, so no
            # prompt text lands in stderr.
            # The password stays out of the command string, so quotes in it
            # can't break or inject into the shell command, and it doesn't
            # show up in the remote shell's command line (e.g. via ps).
            full_command = f"sudo -S -p '' {command}"
        elif use_sudo:
            # Passwordless sudo.
            full_command = f"sudo {command}"
        else:
            full_command = command

        stdin, stdout, stderr = self.client.exec_command(full_command)
        if send_password:
            stdin.write(f"{sudo_password}\n")
            stdin.flush()
        # Send EOF. Without it, a wrong sudo password (sudo re-prompts) or a
        # command that reads stdin would block the reads below forever.
        stdin.channel.shutdown_write()

        # errors="replace": non-UTF-8 output is still returned rather than
        # being lost to a UnicodeDecodeError.
        output = stdout.read().decode(errors="replace")
        error = stderr.read().decode(errors="replace")
        if use_sudo and error:
            # Drop sudo prompt lines and blank lines from stderr.
            error = "\n".join(
                line for line in error.split("\n")
                if not line.startswith("[sudo]") and line.strip()
            )
        return f"$ {command}\n{output}{error}"

    def __del__(self):
        """Best-effort close of the SSH connection on garbage collection."""
        try:
            self.close()
        except Exception:
            # __del__ must never raise. At interpreter shutdown, module
            # globals such as ``logger`` may already be torn down.
            pass
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
    """Build an ``SSHProcess`` for ``username@host``.

    Extra keyword arguments are forwarded unchanged to the ``SSHProcess``
    constructor. Constructing it does not open a connection yet.
    """
    process = SSHProcess(host=host, username=username, **kwargs)
    return process
class SSHTool(BaseTool):
    """Tool to run commands on remote servers via SSH with persistent connection.

    Each call prints the target and the commands, then asks for confirmation
    on stdin before delegating to :class:`SSHProcess`. Failures are returned
    as ``"Error: ..."`` strings rather than raised.
    """

    # SSH process with the persistent connection. Built from the connection
    # parameters in __init__ unless a process is passed in explicitly.
    process: Optional[SSHProcess] = Field(default=None)

    # Connection parameters
    host: str = Field(..., description="SSH host address")
    username: str = Field(..., description="SSH username")
    port: int = Field(default=22, description="SSH port")
    password: Optional[str] = Field(default=None, description="SSH password")
    key_filename: Optional[str] = Field(default=None, description="Path to SSH key")

    # Tool configuration
    name: str = "ssh"
    # Default description. __init__ replaces it with a host-specific one-liner
    # unless the caller passes ``description`` explicitly.
    description: str = """
Run shell commands on a remote server via SSH.
This tool maintains a persistent SSH connection and allows executing
commands on the remote server. It supports both regular and privileged
(sudo) command execution.
Use the 'use_sudo' parameter to run commands with sudo privileges.
If sudo requires a password, provide it via 'sudo_password'.
Examples:
- Regular command: {"commands": "ls -la"}
- Sudo command: {"commands": "apt update", "use_sudo": true}
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]}
"""
    args_schema: Type[BaseModel] = SSHInput

    model_config = {
        "arbitrary_types_allowed": True
    }

    def __init__(self, **kwargs):
        """Initialize the tool, its description, and its SSH process.

        An explicitly passed ``description`` or ``process`` is kept as given.
        Otherwise:

        - the description is set to a one-line summary naming
          ``username@host:port``;
        - a new, not-yet-connected :class:`SSHProcess` is built from the
          connection fields.
        """
        super().__init__(**kwargs)
        if "description" not in kwargs:
            self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}"
        if self.process is None:
            # Initialize the SSH process (but don't connect yet)
            self.process = SSHProcess(
                host=self.host,
                username=self.username,
                port=self.port,
                password=self.password,
                key_filename=self.key_filename
            )

    def _confirm(self, prompt: str) -> bool:
        """Ask ``prompt`` on stdin; return True only for a "y"/"yes" answer.

        Surrounding whitespace and case are ignored. End of input (e.g. a
        non-interactive stdin) is treated as "no".
        """
        try:
            answer = input(prompt)
        except EOFError:
            logger.info("No confirmation input available (EOF); treating as 'no'.")
            return False
        return answer.strip().lower() in ("y", "yes")

    def _run(
        self,
        commands: Union[str, List[str]],
        use_sudo: bool = False,
        sudo_password: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Run commands on the remote server after interactive confirmation.

        Args:
            commands: A single command or a list of commands.
            use_sudo: Run the commands with sudo.
            sudo_password: Password for sudo, if required.
            **kwargs: Ignored.

        Returns:
            The transcript from ``SSHProcess.run``, ``"Command execution
            aborted by user."`` if not confirmed, or ``"Error: ..."`` if
            anything raised.
        """
        try:
            print(f"Executing SSH command on {self.username}@{self.host}:{self.port}")  # noqa: T201
            print(f"Commands: {commands}")  # noqa: T201
            if use_sudo:
                print("Running with sudo privileges")  # noqa: T201
                prompt = "Proceed with sudo command execution? (y/n): "
            else:
                prompt = "Proceed with SSH command execution? (y/n): "

            # Safety check: nothing is sent to the server without confirmation.
            if not self._confirm(prompt):
                logger.info(
                    "User aborted %s command execution.",
                    "sudo" if use_sudo else "SSH",
                )
                return "Command execution aborted by user."
            return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password)
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            logger.exception("Error during SSH command execution: %s", e)
            return f"Error: {str(e)}"