rewrite ssh tool for sudo support

This commit is contained in:
Gaetan Hurel 2025-06-29 15:38:42 +02:00
parent a81dd5484d
commit d06dabfa3c
No known key found for this signature in database
2 changed files with 250 additions and 316 deletions

View File

@ -1,66 +1,42 @@
import getpass
import logging import logging
import warnings import time
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SSHInput(BaseModel): class SSHInput(BaseModel):
"""Commands for the SSH tool.""" """Input for SSH tool."""
commands: Union[str, List[str]] = Field( commands: Union[str, List[str]] = Field(
..., ...,
description="List of commands to run on the remote server", description="Command(s) to run on the remote server"
) )
"""List of commands to run."""
use_sudo: bool = Field(
default=False,
description="Whether to run commands with sudo privileges"
)
"""Whether to run commands with sudo."""
sudo_password: Optional[str] = Field(
default=None,
description="Password for sudo if required (will be used securely)"
)
"""Password for sudo if required."""
@model_validator(mode="before")
def _validate_commands(cls, values: dict) -> Any:
"""Validate commands."""
commands = values.get("commands")
if not isinstance(commands, list):
values["commands"] = [commands]
# Warn that the SSH tool has no safeguards
warnings.warn(
"The SSH tool has no safeguards by default. Use at your own risk."
)
return values
class SSHProcess: class SSHSession:
"""Persistent SSH connection for command execution.""" """Manages a persistent SSH session with sudo caching."""
def __init__(self, host: str, username: str, port: int = 22, def __init__(self, host: str, username: str, port: int = 22,
password: Optional[str] = None, key_filename: Optional[str] = None, key_filename: Optional[str] = None, password: Optional[str] = None):
**kwargs):
"""Initialize SSH process with connection parameters."""
self.host = host self.host = host
self.username = username self.username = username
self.port = port self.port = port
self.password = password
self.key_filename = key_filename self.key_filename = key_filename
self.password = password
self.client = None self.client = None
self._is_connected = False self._sudo_password = None
self._sudo_timestamp = None
self._sudo_timeout = 300 # 5 minutes, like default sudo
def connect(self): def connect(self):
"""Establish SSH connection.""" """Establish SSH connection if not already connected."""
if self._is_connected: if self.client:
return return
try: try:
import paramiko import paramiko
except ImportError as e: except ImportError as e:
@ -84,151 +60,142 @@ class SSHProcess:
connect_kwargs["key_filename"] = self.key_filename connect_kwargs["key_filename"] = self.key_filename
self.client.connect(**connect_kwargs) self.client.connect(**connect_kwargs)
self._is_connected = True
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}") logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
def run(self, commands: Union[str, List[str]], use_sudo: bool = False, def _needs_sudo_password(self) -> bool:
sudo_password: Optional[str] = None) -> str: """Check if we need to ask for sudo password."""
"""Run commands over SSH and return output.""" if not self._sudo_password:
if not self._is_connected: return True
if not self._sudo_timestamp:
return True
# Check if sudo timeout has expired
if time.time() - self._sudo_timestamp > self._sudo_timeout:
self._sudo_password = None
self._sudo_timestamp = None
return True
return False
def execute(self, command: str) -> str:
"""Execute a single command, handling sudo automatically."""
if not self.client:
self.connect() self.connect()
# Check if command needs sudo
needs_sudo = command.strip().startswith('sudo ')
if needs_sudo:
# Remove 'sudo ' prefix if present
actual_command = command.strip()[5:].strip()
# Check if we need to get sudo password
if self._needs_sudo_password():
self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ")
self._sudo_timestamp = time.time()
# Execute with sudo -S and pass password via stdin
full_command = f"sudo -S {actual_command}"
stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True)
# Send password
stdin.write(f"{self._sudo_password}\n")
stdin.flush()
# Get output
output = stdout.read().decode()
error = stderr.read().decode()
# Clean up sudo prompt from output
lines = output.split('\n')
cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')]
output = '\n'.join(cleaned_lines)
else:
# Regular command without sudo
stdin, stdout, stderr = self.client.exec_command(command)
output = stdout.read().decode()
error = stderr.read().decode()
# Combine output and error
result = output
if error:
result += f"\n{error}" if result else error
return result.strip()
def run_commands(self, commands: Union[str, List[str]]) -> str:
"""Run one or more commands and return combined output."""
if isinstance(commands, str): if isinstance(commands, str):
commands = [commands] commands = [commands]
outputs = [] outputs = []
for command in commands: for cmd in commands:
try: try:
# Prepare command with sudo if needed output = self.execute(cmd)
if use_sudo: outputs.append(f"$ {cmd}\n{output}")
if sudo_password:
# Use echo to pipe password to sudo -S (read from stdin)
full_command = f"echo '{sudo_password}' | sudo -S {command}"
else:
# Try sudo without password (for passwordless sudo)
full_command = f"sudo {command}"
else:
full_command = command
stdin, stdout, stderr = self.client.exec_command(full_command)
# For sudo commands with password, we need to handle stdin
if use_sudo and sudo_password:
stdin.write(f"{sudo_password}\n")
stdin.flush()
output = stdout.read().decode()
error = stderr.read().decode()
# Filter out sudo password prompt from error output
if use_sudo and error:
error_lines = error.split('\n')
filtered_error = '\n'.join(
line for line in error_lines
if not line.startswith('[sudo]') and line.strip()
)
error = filtered_error
if error:
outputs.append(f"$ {command}\n{output}{error}")
else:
outputs.append(f"$ {command}\n{output}")
except Exception as e: except Exception as e:
outputs.append(f"$ {command}\nError: {str(e)}") outputs.append(f"$ {cmd}\nError: {str(e)}")
return "\n\n".join(outputs) return "\n\n".join(outputs)
def __del__(self): def close(self):
"""Close SSH connection when object is destroyed.""" """Close the SSH connection."""
if self.client and self._is_connected: if self.client:
self.client.close() self.client.close()
self.client = None
logger.info(f"SSH connection closed to {self.username}@{self.host}") logger.info(f"SSH connection closed to {self.username}@{self.host}")
def __del__(self):
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any: """Ensure connection is closed when object is destroyed."""
"""Get default SSH process with persistent connection.""" self.close()
return SSHProcess(host=host, username=username, **kwargs)
class SSHTool(BaseTool): class SSHTool(BaseTool):
"""Tool to run commands on remote servers via SSH with persistent connection.""" """Simple SSH tool that behaves like a normal terminal session."""
process: Optional[SSHProcess] = Field(default=None)
"""SSH process with persistent connection."""
# Connection parameters
host: str = Field(..., description="SSH host address")
username: str = Field(..., description="SSH username")
port: int = Field(default=22, description="SSH port")
password: Optional[str] = Field(default=None, description="SSH password")
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
# Tool configuration
name: str = "ssh" name: str = "ssh"
description: str = """ description: str = """Execute commands on a remote server via SSH.
Run shell commands on a remote server via SSH.
This tool maintains a persistent SSH connection and allows executing Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands.
commands on the remote server. It supports both regular and privileged The tool will automatically handle password prompts and maintain sudo session.
(sudo) command execution.
Use the 'use_sudo' parameter to run commands with sudo privileges.
If sudo requires a password, provide it via 'sudo_password'.
Examples: Examples:
- Regular command: {"commands": "ls -la"} - {"commands": "ls -la"}
- Sudo command: {"commands": "apt update", "use_sudo": true} - {"commands": "sudo apt update"}
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]} - {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]}
""" """
args_schema: Type[BaseModel] = SSHInput args_schema: Type[BaseModel] = SSHInput
# Connection parameters
host: str = Field(..., description="SSH host")
username: str = Field(..., description="SSH username")
port: int = Field(default=22, description="SSH port")
key_filename: Optional[str] = Field(default=None, description="SSH key path")
password: Optional[str] = Field(default=None, description="SSH password")
# Session management
session: Optional[SSHSession] = Field(default=None, exclude=True)
model_config = { model_config = {
"arbitrary_types_allowed": True "arbitrary_types_allowed": True
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialize SSH tool and set description.""" """Initialize SSH tool."""
super().__init__(**kwargs) super().__init__(**kwargs)
# Create session but don't connect yet
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}" self.session = SSHSession(
# Initialize the SSH process (but don't connect yet)
self.process = SSHProcess(
host=self.host, host=self.host,
username=self.username, username=self.username,
port=self.port, port=self.port,
password=self.password, key_filename=self.key_filename,
key_filename=self.key_filename password=self.password
) )
def _run( def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
self, """Execute commands on remote server."""
commands: Union[str, List[str]],
use_sudo: bool = False,
sudo_password: Optional[str] = None,
**kwargs,
) -> str:
"""Run commands on remote server and return output."""
try: try:
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201 print(f"Executing on {self.username}@{self.host}:{self.port}")
print(f"Commands: {commands}") # noqa: T201 return self.session.run_commands(commands)
if use_sudo:
print("Running with sudo privileges") # noqa: T201
# Safety check for privileged commands
if use_sudo:
user_input = input("Proceed with sudo command execution? (y/n): ").lower()
if user_input == "y":
return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password)
else:
logger.info("User aborted sudo command execution.")
return "Command execution aborted by user."
else:
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
if user_input == "y":
return self.process.run(commands)
else:
logger.info("Invalid input. User aborted SSH command execution.")
return "Command execution aborted by user."
except Exception as e: except Exception as e:
logger.error(f"Error during SSH command execution: {e}") logger.error(f"SSH execution error: {e}")
return f"Error: {str(e)}" return f"Error: {str(e)}"

View File

@ -1,66 +1,42 @@
import getpass
import logging import logging
import warnings import time
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SSHInput(BaseModel): class SSHInput(BaseModel):
"""Commands for the SSH tool.""" """Input for SSH tool."""
commands: Union[str, List[str]] = Field( commands: Union[str, List[str]] = Field(
..., ...,
description="List of commands to run on the remote server", description="Command(s) to run on the remote server"
) )
"""List of commands to run."""
use_sudo: bool = Field(
default=False,
description="Whether to run commands with sudo privileges"
)
"""Whether to run commands with sudo."""
sudo_password: Optional[str] = Field(
default=None,
description="Password for sudo if required (will be used securely)"
)
"""Password for sudo if required."""
@model_validator(mode="before")
def _validate_commands(cls, values: dict) -> Any:
"""Validate commands."""
commands = values.get("commands")
if not isinstance(commands, list):
values["commands"] = [commands]
# Warn that the SSH tool has no safeguards
warnings.warn(
"The SSH tool has no safeguards by default. Use at your own risk."
)
return values
class SSHProcess: class SSHSession:
"""Persistent SSH connection for command execution.""" """Manages a persistent SSH session with sudo caching."""
def __init__(self, host: str, username: str, port: int = 22, def __init__(self, host: str, username: str, port: int = 22,
password: Optional[str] = None, key_filename: Optional[str] = None, key_filename: Optional[str] = None, password: Optional[str] = None):
**kwargs):
"""Initialize SSH process with connection parameters."""
self.host = host self.host = host
self.username = username self.username = username
self.port = port self.port = port
self.password = password
self.key_filename = key_filename self.key_filename = key_filename
self.password = password
self.client = None self.client = None
self._is_connected = False self._sudo_password = None
self._sudo_timestamp = None
self._sudo_timeout = 300 # 5 minutes, like default sudo
def connect(self): def connect(self):
"""Establish SSH connection.""" """Establish SSH connection if not already connected."""
if self._is_connected: if self.client:
return return
try: try:
import paramiko import paramiko
except ImportError as e: except ImportError as e:
@ -84,151 +60,142 @@ class SSHProcess:
connect_kwargs["key_filename"] = self.key_filename connect_kwargs["key_filename"] = self.key_filename
self.client.connect(**connect_kwargs) self.client.connect(**connect_kwargs)
self._is_connected = True
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}") logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
def run(self, commands: Union[str, List[str]], use_sudo: bool = False, def _needs_sudo_password(self) -> bool:
sudo_password: Optional[str] = None) -> str: """Check if we need to ask for sudo password."""
"""Run commands over SSH and return output.""" if not self._sudo_password:
if not self._is_connected: return True
if not self._sudo_timestamp:
return True
# Check if sudo timeout has expired
if time.time() - self._sudo_timestamp > self._sudo_timeout:
self._sudo_password = None
self._sudo_timestamp = None
return True
return False
def execute(self, command: str) -> str:
"""Execute a single command, handling sudo automatically."""
if not self.client:
self.connect() self.connect()
# Check if command needs sudo
needs_sudo = command.strip().startswith('sudo ')
if needs_sudo:
# Remove 'sudo ' prefix if present
actual_command = command.strip()[5:].strip()
# Check if we need to get sudo password
if self._needs_sudo_password():
self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ")
self._sudo_timestamp = time.time()
# Execute with sudo -S and pass password via stdin
full_command = f"sudo -S {actual_command}"
stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True)
# Send password
stdin.write(f"{self._sudo_password}\n")
stdin.flush()
# Get output
output = stdout.read().decode()
error = stderr.read().decode()
# Clean up sudo prompt from output
lines = output.split('\n')
cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')]
output = '\n'.join(cleaned_lines)
else:
# Regular command without sudo
stdin, stdout, stderr = self.client.exec_command(command)
output = stdout.read().decode()
error = stderr.read().decode()
# Combine output and error
result = output
if error:
result += f"\n{error}" if result else error
return result.strip()
def run_commands(self, commands: Union[str, List[str]]) -> str:
"""Run one or more commands and return combined output."""
if isinstance(commands, str): if isinstance(commands, str):
commands = [commands] commands = [commands]
outputs = [] outputs = []
for command in commands: for cmd in commands:
try: try:
# Prepare command with sudo if needed output = self.execute(cmd)
if use_sudo: outputs.append(f"$ {cmd}\n{output}")
if sudo_password:
# Use echo to pipe password to sudo -S (read from stdin)
full_command = f"echo '{sudo_password}' | sudo -S {command}"
else:
# Try sudo without password (for passwordless sudo)
full_command = f"sudo {command}"
else:
full_command = command
stdin, stdout, stderr = self.client.exec_command(full_command)
# For sudo commands with password, we need to handle stdin
if use_sudo and sudo_password:
stdin.write(f"{sudo_password}\n")
stdin.flush()
output = stdout.read().decode()
error = stderr.read().decode()
# Filter out sudo password prompt from error output
if use_sudo and error:
error_lines = error.split('\n')
filtered_error = '\n'.join(
line for line in error_lines
if not line.startswith('[sudo]') and line.strip()
)
error = filtered_error
if error:
outputs.append(f"$ {command}\n{output}{error}")
else:
outputs.append(f"$ {command}\n{output}")
except Exception as e: except Exception as e:
outputs.append(f"$ {command}\nError: {str(e)}") outputs.append(f"$ {cmd}\nError: {str(e)}")
return "\n\n".join(outputs) return "\n\n".join(outputs)
def __del__(self): def close(self):
"""Close SSH connection when object is destroyed.""" """Close the SSH connection."""
if self.client and self._is_connected: if self.client:
self.client.close() self.client.close()
self.client = None
logger.info(f"SSH connection closed to {self.username}@{self.host}") logger.info(f"SSH connection closed to {self.username}@{self.host}")
def __del__(self):
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any: """Ensure connection is closed when object is destroyed."""
"""Get default SSH process with persistent connection.""" self.close()
return SSHProcess(host=host, username=username, **kwargs)
class SSHTool(BaseTool): class SSHTool(BaseTool):
"""Tool to run commands on remote servers via SSH with persistent connection.""" """Simple SSH tool that behaves like a normal terminal session."""
process: Optional[SSHProcess] = Field(default=None)
"""SSH process with persistent connection."""
# Connection parameters
host: str = Field(..., description="SSH host address")
username: str = Field(..., description="SSH username")
port: int = Field(default=22, description="SSH port")
password: Optional[str] = Field(default=None, description="SSH password")
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
# Tool configuration
name: str = "ssh" name: str = "ssh"
description: str = """ description: str = """Execute commands on a remote server via SSH.
Run shell commands on a remote server via SSH.
This tool maintains a persistent SSH connection and allows executing Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands.
commands on the remote server. It supports both regular and privileged The tool will automatically handle password prompts and maintain sudo session.
(sudo) command execution.
Use the 'use_sudo' parameter to run commands with sudo privileges.
If sudo requires a password, provide it via 'sudo_password'.
Examples: Examples:
- Regular command: {"commands": "ls -la"} - {"commands": "ls -la"}
- Sudo command: {"commands": "apt update", "use_sudo": true} - {"commands": "sudo apt update"}
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]} - {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]}
""" """
args_schema: Type[BaseModel] = SSHInput args_schema: Type[BaseModel] = SSHInput
# Connection parameters
host: str = Field(..., description="SSH host")
username: str = Field(..., description="SSH username")
port: int = Field(default=22, description="SSH port")
key_filename: Optional[str] = Field(default=None, description="SSH key path")
password: Optional[str] = Field(default=None, description="SSH password")
# Session management
session: Optional[SSHSession] = Field(default=None, exclude=True)
model_config = { model_config = {
"arbitrary_types_allowed": True "arbitrary_types_allowed": True
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialize SSH tool and set description.""" """Initialize SSH tool."""
super().__init__(**kwargs) super().__init__(**kwargs)
# Create session but don't connect yet
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}" self.session = SSHSession(
# Initialize the SSH process (but don't connect yet)
self.process = SSHProcess(
host=self.host, host=self.host,
username=self.username, username=self.username,
port=self.port, port=self.port,
password=self.password, key_filename=self.key_filename,
key_filename=self.key_filename password=self.password
) )
def _run( def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
self, """Execute commands on remote server."""
commands: Union[str, List[str]],
use_sudo: bool = False,
sudo_password: Optional[str] = None,
**kwargs,
) -> str:
"""Run commands on remote server and return output."""
try: try:
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201 print(f"Executing on {self.username}@{self.host}:{self.port}")
print(f"Commands: {commands}") # noqa: T201 return self.session.run_commands(commands)
if use_sudo:
print("Running with sudo privileges") # noqa: T201
# Safety check for privileged commands
if use_sudo:
user_input = input("Proceed with sudo command execution? (y/n): ").lower()
if user_input == "y":
return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password)
else:
logger.info("User aborted sudo command execution.")
return "Command execution aborted by user."
else:
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
if user_input == "y":
return self.process.run(commands)
else:
logger.info("Invalid input. User aborted SSH command execution.")
return "Command execution aborted by user."
except Exception as e: except Exception as e:
logger.error(f"Error during SSH command execution: {e}") logger.error(f"SSH execution error: {e}")
return f"Error: {str(e)}" return f"Error: {str(e)}"