203 lines
6.9 KiB
Python
203 lines
6.9 KiB
Python
import getpass
|
|
import logging
|
|
import time
|
|
from typing import Any, List, Optional, Type, Union
|
|
|
|
from langchain_core.tools import BaseTool
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Module logger, named after this module; used for connection open/close
# messages and for tool execution errors.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class SSHInput(BaseModel):
    """Input for SSH tool."""

    # Either one shell command or a list of commands; a lone string is
    # handled by SSHSession.run_commands as a one-element list.
    # `default=...` (Ellipsis) marks the field as required.
    commands: Union[str, List[str]] = Field(
        default=...,
        description="Command(s) to run on the remote server",
    )
|
|
|
|
|
|
class SSHSession:
    """Manages a persistent SSH session with sudo caching.

    The paramiko connection is opened lazily (on the first :meth:`execute`
    call, or explicitly via :meth:`connect`) and reused for later commands.
    Commands starting with ``sudo `` are run as ``sudo -S ...`` on a
    pseudo-terminal. The sudo password is requested locally with
    :func:`getpass.getpass` and cached in memory for ``_sudo_timeout``
    seconds.
    """

    def __init__(self, host: str, username: str, port: int = 22,
                 key_filename: Optional[str] = None, password: Optional[str] = None):
        """Store connection settings; no network I/O happens here.

        Args:
            host: Remote host name or address.
            username: Remote login name (also shown in the sudo prompt).
            port: SSH port.
            key_filename: Optional private-key path forwarded to paramiko.
            password: Optional login password forwarded to paramiko.
        """
        self.host = host
        self.username = username
        self.port = port
        self.key_filename = key_filename
        self.password = password
        self.client = None  # paramiko.SSHClient, set only after a successful connect()
        self._sudo_password = None
        self._sudo_timestamp = None
        self._sudo_timeout = 300  # 5 minutes, like default sudo

    def connect(self):
        """Establish the SSH connection if not already connected.

        Raises:
            ImportError: If paramiko is not installed.
            Exception: Whatever ``paramiko.SSHClient.connect`` raises
                (authentication, network, host-key errors). On failure no
                client is kept, so the next call retries from scratch.
        """
        if self.client:
            return

        try:
            import paramiko
        except ImportError as e:
            raise ImportError(
                "paramiko is required for SSH functionality. "
                "Install it with `pip install paramiko`"
            ) from e

        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts any unknown host key, so the
        # first connection is open to man-in-the-middle attacks. Consider
        # load_system_host_keys() plus RejectPolicy for production use.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        connect_kwargs = {
            "hostname": self.host,
            "port": self.port,
            "username": self.username,
        }
        if self.password:
            connect_kwargs["password"] = self.password
        if self.key_filename:
            connect_kwargs["key_filename"] = self.key_filename

        try:
            client.connect(**connect_kwargs)
        except Exception:
            # Do not keep a half-initialised client: the early-return guard
            # above would otherwise block every future reconnect attempt.
            client.close()
            raise

        self.client = client
        logger.info("SSH connection established to %s@%s:%s",
                    self.username, self.host, self.port)

    def _needs_sudo_password(self) -> bool:
        """Return True if no sudo password is cached or the cache has expired.

        An expired cache is cleared as a side effect.
        """
        if not self._sudo_password:
            return True
        if not self._sudo_timestamp:
            return True
        # Check if sudo timeout has expired
        if time.time() - self._sudo_timestamp > self._sudo_timeout:
            self._sudo_password = None
            self._sudo_timestamp = None
            return True
        return False

    def _strip_sudo_noise(self, output: str) -> str:
        """Drop sudo prompt lines and any echoed password line from pty output."""
        password = self._sudo_password
        kept = []
        for line in output.split('\n'):
            if line.strip().startswith('[sudo]'):
                continue
            # With get_pty=True the terminal may echo the password we wrote
            # to stdin; never hand that back to the caller.
            if password and line.rstrip('\r') == password:
                continue
            kept.append(line)
        return '\n'.join(kept)

    def execute(self, command: str) -> str:
        """Execute a single command, handling sudo automatically.

        Args:
            command: Shell command line to run remotely. If it starts with
                ``sudo `` it is rewritten to ``sudo -S <rest>`` and the cached
                sudo password is written to its stdin (prompting locally
                when the cache is empty or expired).

        Returns:
            Stdout followed by stderr (joined by a newline when both are
            non-empty), stripped of surrounding whitespace. Bytes that are
            not valid UTF-8 are replaced with U+FFFD instead of raising.
        """
        print(f"🔧 Executing command: {command}")

        if not self.client:
            self.connect()

        stripped = command.strip()
        if stripped.startswith('sudo '):
            # Remove the 'sudo ' prefix; it is re-added below with -S.
            actual_command = stripped[5:].strip()

            if self._needs_sudo_password():
                self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ")
                self._sudo_timestamp = time.time()

            # NOTE(review): if the password is wrong, sudo re-prompts and waits
            # on stdin, so the reads below may block. If the remote sudo does
            # not prompt at all (e.g. NOPASSWD), the password line goes to the
            # command's own stdin instead.
            full_command = f"sudo -S {actual_command}"
            stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True)
            stdin.write(f"{self._sudo_password}\n")
            stdin.flush()

            output = stdout.read().decode(errors="replace")
            error = stderr.read().decode(errors="replace")
            output = self._strip_sudo_noise(output)
        else:
            # Regular command without sudo
            stdin, stdout, stderr = self.client.exec_command(command)
            output = stdout.read().decode(errors="replace")
            error = stderr.read().decode(errors="replace")

        # Combine output and error
        result = output
        if error:
            result += f"\n{error}" if result else error

        return result.strip()

    def run_commands(self, commands: Union[str, List[str]]) -> str:
        """Run one or more commands and return a transcript of their output.

        Each entry is ``$ <cmd>`` followed by the command's output, or by
        ``Error: <message>`` if :meth:`execute` raised. Entries are separated
        by a blank line, and one failing command does not stop the rest.
        """
        if isinstance(commands, str):
            commands = [commands]

        outputs = []
        for cmd in commands:
            try:
                output = self.execute(cmd)
            except Exception as e:
                outputs.append(f"$ {cmd}\nError: {str(e)}")
            else:
                outputs.append(f"$ {cmd}\n{output}")

        return "\n\n".join(outputs)

    def close(self):
        """Close the SSH connection if one is open; safe to call repeatedly."""
        if self.client:
            self.client.close()
            self.client = None
            logger.info("SSH connection closed to %s@%s", self.username, self.host)

    def __del__(self):
        """Ensure connection is closed when object is destroyed."""
        self.close()
|
|
|
|
|
|
class SSHTool(BaseTool):
    """Simple SSH tool that behaves like a normal terminal session."""

    name: str = "ssh"
    description: str = """Execute commands on a remote server via SSH.

    Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands.
    The tool will automatically handle password prompts and maintain sudo session.

    Examples:
    - {"commands": "ls -la"}
    - {"commands": "sudo apt update"}
    - {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]}
    """
    args_schema: Type[BaseModel] = SSHInput

    # Connection parameters
    host: str = Field(..., description="SSH host")
    username: str = Field(..., description="SSH username")
    port: int = Field(default=22, description="SSH port")
    key_filename: Optional[str] = Field(default=None, description="SSH key path")
    # repr=False keeps the secret out of repr()/str() of the tool, which
    # commonly ends up in logs and agent traces.
    password: Optional[str] = Field(default=None, description="SSH password", repr=False)

    # Session management (excluded from serialization)
    session: Optional[SSHSession] = Field(default=None, exclude=True)

    model_config = {
        "arbitrary_types_allowed": True
    }

    def __init__(self, **kwargs):
        """Validate fields via pydantic, then build the (not yet connected) session."""
        super().__init__(**kwargs)
        # Create session but don't connect yet
        self.session = SSHSession(
            host=self.host,
            username=self.username,
            port=self.port,
            key_filename=self.key_filename,
            password=self.password
        )

    def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
        """Execute commands on the remote server via the shared session.

        Any exception is logged (with traceback) and returned as an
        ``"Error: ..."`` string, so the agent receives the failure as tool
        output instead of an exception.
        """
        try:
            print(f"Executing on {self.username}@{self.host}:{self.port}")
            return self.session.run_commands(commands)
        except Exception as e:
            logger.exception("SSH execution error: %s", e)
            return f"Error: {str(e)}"