rewrite ssh tool for sudo support
This commit is contained in:
parent
a81dd5484d
commit
d06dabfa3c
@ -1,66 +1,42 @@
|
||||
import getpass
|
||||
import logging
|
||||
import warnings
|
||||
import time
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHInput(BaseModel):
|
||||
"""Commands for the SSH tool."""
|
||||
|
||||
"""Input for SSH tool."""
|
||||
commands: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="List of commands to run on the remote server",
|
||||
description="Command(s) to run on the remote server"
|
||||
)
|
||||
"""List of commands to run."""
|
||||
|
||||
|
||||
class SSHSession:
|
||||
"""Manages a persistent SSH session with sudo caching."""
|
||||
|
||||
use_sudo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to run commands with sudo privileges"
|
||||
)
|
||||
"""Whether to run commands with sudo."""
|
||||
|
||||
sudo_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for sudo if required (will be used securely)"
|
||||
)
|
||||
"""Password for sudo if required."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _validate_commands(cls, values: dict) -> Any:
|
||||
"""Validate commands."""
|
||||
commands = values.get("commands")
|
||||
if not isinstance(commands, list):
|
||||
values["commands"] = [commands]
|
||||
# Warn that the SSH tool has no safeguards
|
||||
warnings.warn(
|
||||
"The SSH tool has no safeguards by default. Use at your own risk."
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class SSHProcess:
|
||||
"""Persistent SSH connection for command execution."""
|
||||
|
||||
def __init__(self, host: str, username: str, port: int = 22,
|
||||
password: Optional[str] = None, key_filename: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""Initialize SSH process with connection parameters."""
|
||||
key_filename: Optional[str] = None, password: Optional[str] = None):
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.port = port
|
||||
self.password = password
|
||||
self.key_filename = key_filename
|
||||
self.password = password
|
||||
self.client = None
|
||||
self._is_connected = False
|
||||
|
||||
self._sudo_password = None
|
||||
self._sudo_timestamp = None
|
||||
self._sudo_timeout = 300 # 5 minutes, like default sudo
|
||||
|
||||
def connect(self):
|
||||
"""Establish SSH connection."""
|
||||
if self._is_connected:
|
||||
"""Establish SSH connection if not already connected."""
|
||||
if self.client:
|
||||
return
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
except ImportError as e:
|
||||
@ -68,7 +44,7 @@ class SSHProcess:
|
||||
"paramiko is required for SSH functionality. "
|
||||
"Install it with `pip install paramiko`"
|
||||
) from e
|
||||
|
||||
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
@ -84,151 +60,142 @@ class SSHProcess:
|
||||
connect_kwargs["key_filename"] = self.key_filename
|
||||
|
||||
self.client.connect(**connect_kwargs)
|
||||
self._is_connected = True
|
||||
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
|
||||
|
||||
def run(self, commands: Union[str, List[str]], use_sudo: bool = False,
|
||||
sudo_password: Optional[str] = None) -> str:
|
||||
"""Run commands over SSH and return output."""
|
||||
if not self._is_connected:
|
||||
self.connect()
|
||||
|
||||
def _needs_sudo_password(self) -> bool:
|
||||
"""Check if we need to ask for sudo password."""
|
||||
if not self._sudo_password:
|
||||
return True
|
||||
if not self._sudo_timestamp:
|
||||
return True
|
||||
# Check if sudo timeout has expired
|
||||
if time.time() - self._sudo_timestamp > self._sudo_timeout:
|
||||
self._sudo_password = None
|
||||
self._sudo_timestamp = None
|
||||
return True
|
||||
return False
|
||||
|
||||
def execute(self, command: str) -> str:
|
||||
"""Execute a single command, handling sudo automatically."""
|
||||
if not self.client:
|
||||
self.connect()
|
||||
|
||||
# Check if command needs sudo
|
||||
needs_sudo = command.strip().startswith('sudo ')
|
||||
|
||||
if needs_sudo:
|
||||
# Remove 'sudo ' prefix if present
|
||||
actual_command = command.strip()[5:].strip()
|
||||
|
||||
# Check if we need to get sudo password
|
||||
if self._needs_sudo_password():
|
||||
self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ")
|
||||
self._sudo_timestamp = time.time()
|
||||
|
||||
# Execute with sudo -S and pass password via stdin
|
||||
full_command = f"sudo -S {actual_command}"
|
||||
stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True)
|
||||
|
||||
# Send password
|
||||
stdin.write(f"{self._sudo_password}\n")
|
||||
stdin.flush()
|
||||
|
||||
# Get output
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Clean up sudo prompt from output
|
||||
lines = output.split('\n')
|
||||
cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')]
|
||||
output = '\n'.join(cleaned_lines)
|
||||
|
||||
else:
|
||||
# Regular command without sudo
|
||||
stdin, stdout, stderr = self.client.exec_command(command)
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Combine output and error
|
||||
result = output
|
||||
if error:
|
||||
result += f"\n{error}" if result else error
|
||||
|
||||
return result.strip()
|
||||
|
||||
def run_commands(self, commands: Union[str, List[str]]) -> str:
|
||||
"""Run one or more commands and return combined output."""
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
|
||||
|
||||
outputs = []
|
||||
for command in commands:
|
||||
for cmd in commands:
|
||||
try:
|
||||
# Prepare command with sudo if needed
|
||||
if use_sudo:
|
||||
if sudo_password:
|
||||
# Use echo to pipe password to sudo -S (read from stdin)
|
||||
full_command = f"echo '{sudo_password}' | sudo -S {command}"
|
||||
else:
|
||||
# Try sudo without password (for passwordless sudo)
|
||||
full_command = f"sudo {command}"
|
||||
else:
|
||||
full_command = command
|
||||
|
||||
stdin, stdout, stderr = self.client.exec_command(full_command)
|
||||
|
||||
# For sudo commands with password, we need to handle stdin
|
||||
if use_sudo and sudo_password:
|
||||
stdin.write(f"{sudo_password}\n")
|
||||
stdin.flush()
|
||||
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Filter out sudo password prompt from error output
|
||||
if use_sudo and error:
|
||||
error_lines = error.split('\n')
|
||||
filtered_error = '\n'.join(
|
||||
line for line in error_lines
|
||||
if not line.startswith('[sudo]') and line.strip()
|
||||
)
|
||||
error = filtered_error
|
||||
|
||||
if error:
|
||||
outputs.append(f"$ {command}\n{output}{error}")
|
||||
else:
|
||||
outputs.append(f"$ {command}\n{output}")
|
||||
output = self.execute(cmd)
|
||||
outputs.append(f"$ {cmd}\n{output}")
|
||||
except Exception as e:
|
||||
outputs.append(f"$ {command}\nError: {str(e)}")
|
||||
|
||||
outputs.append(f"$ {cmd}\nError: {str(e)}")
|
||||
|
||||
return "\n\n".join(outputs)
|
||||
|
||||
def __del__(self):
|
||||
"""Close SSH connection when object is destroyed."""
|
||||
if self.client and self._is_connected:
|
||||
|
||||
def close(self):
|
||||
"""Close the SSH connection."""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
logger.info(f"SSH connection closed to {self.username}@{self.host}")
|
||||
|
||||
|
||||
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
|
||||
"""Get default SSH process with persistent connection."""
|
||||
return SSHProcess(host=host, username=username, **kwargs)
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure connection is closed when object is destroyed."""
|
||||
self.close()
|
||||
|
||||
|
||||
class SSHTool(BaseTool):
|
||||
"""Tool to run commands on remote servers via SSH with persistent connection."""
|
||||
|
||||
process: Optional[SSHProcess] = Field(default=None)
|
||||
"""SSH process with persistent connection."""
|
||||
|
||||
# Connection parameters
|
||||
host: str = Field(..., description="SSH host address")
|
||||
username: str = Field(..., description="SSH username")
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
|
||||
"""Simple SSH tool that behaves like a normal terminal session."""
|
||||
|
||||
# Tool configuration
|
||||
name: str = "ssh"
|
||||
description: str = """
|
||||
Run shell commands on a remote server via SSH.
|
||||
description: str = """Execute commands on a remote server via SSH.
|
||||
|
||||
This tool maintains a persistent SSH connection and allows executing
|
||||
commands on the remote server. It supports both regular and privileged
|
||||
(sudo) command execution.
|
||||
|
||||
Use the 'use_sudo' parameter to run commands with sudo privileges.
|
||||
If sudo requires a password, provide it via 'sudo_password'.
|
||||
Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands.
|
||||
The tool will automatically handle password prompts and maintain sudo session.
|
||||
|
||||
Examples:
|
||||
- Regular command: {"commands": "ls -la"}
|
||||
- Sudo command: {"commands": "apt update", "use_sudo": true}
|
||||
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]}
|
||||
- {"commands": "ls -la"}
|
||||
- {"commands": "sudo apt update"}
|
||||
- {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]}
|
||||
"""
|
||||
args_schema: Type[BaseModel] = SSHInput
|
||||
|
||||
# Connection parameters
|
||||
host: str = Field(..., description="SSH host")
|
||||
username: str = Field(..., description="SSH username")
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
key_filename: Optional[str] = Field(default=None, description="SSH key path")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
|
||||
# Session management
|
||||
session: Optional[SSHSession] = Field(default=None, exclude=True)
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize SSH tool and set description."""
|
||||
"""Initialize SSH tool."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}"
|
||||
# Initialize the SSH process (but don't connect yet)
|
||||
self.process = SSHProcess(
|
||||
# Create session but don't connect yet
|
||||
self.session = SSHSession(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
key_filename=self.key_filename
|
||||
key_filename=self.key_filename,
|
||||
password=self.password
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
commands: Union[str, List[str]],
|
||||
use_sudo: bool = False,
|
||||
sudo_password: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Run commands on remote server and return output."""
|
||||
|
||||
def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
|
||||
"""Execute commands on remote server."""
|
||||
try:
|
||||
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201
|
||||
print(f"Commands: {commands}") # noqa: T201
|
||||
if use_sudo:
|
||||
print("Running with sudo privileges") # noqa: T201
|
||||
|
||||
# Safety check for privileged commands
|
||||
if use_sudo:
|
||||
user_input = input("Proceed with sudo command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password)
|
||||
else:
|
||||
logger.info("User aborted sudo command execution.")
|
||||
return "Command execution aborted by user."
|
||||
else:
|
||||
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands)
|
||||
else:
|
||||
logger.info("Invalid input. User aborted SSH command execution.")
|
||||
return "Command execution aborted by user."
|
||||
print(f"Executing on {self.username}@{self.host}:{self.port}")
|
||||
return self.session.run_commands(commands)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during SSH command execution: {e}")
|
||||
logger.error(f"SSH execution error: {e}")
|
||||
return f"Error: {str(e)}"
|
@ -1,66 +1,42 @@
|
||||
import getpass
|
||||
import logging
|
||||
import warnings
|
||||
import time
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHInput(BaseModel):
|
||||
"""Commands for the SSH tool."""
|
||||
|
||||
"""Input for SSH tool."""
|
||||
commands: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="List of commands to run on the remote server",
|
||||
description="Command(s) to run on the remote server"
|
||||
)
|
||||
"""List of commands to run."""
|
||||
|
||||
|
||||
class SSHSession:
|
||||
"""Manages a persistent SSH session with sudo caching."""
|
||||
|
||||
use_sudo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to run commands with sudo privileges"
|
||||
)
|
||||
"""Whether to run commands with sudo."""
|
||||
|
||||
sudo_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for sudo if required (will be used securely)"
|
||||
)
|
||||
"""Password for sudo if required."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _validate_commands(cls, values: dict) -> Any:
|
||||
"""Validate commands."""
|
||||
commands = values.get("commands")
|
||||
if not isinstance(commands, list):
|
||||
values["commands"] = [commands]
|
||||
# Warn that the SSH tool has no safeguards
|
||||
warnings.warn(
|
||||
"The SSH tool has no safeguards by default. Use at your own risk."
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class SSHProcess:
|
||||
"""Persistent SSH connection for command execution."""
|
||||
|
||||
def __init__(self, host: str, username: str, port: int = 22,
|
||||
password: Optional[str] = None, key_filename: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""Initialize SSH process with connection parameters."""
|
||||
key_filename: Optional[str] = None, password: Optional[str] = None):
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.port = port
|
||||
self.password = password
|
||||
self.key_filename = key_filename
|
||||
self.password = password
|
||||
self.client = None
|
||||
self._is_connected = False
|
||||
|
||||
self._sudo_password = None
|
||||
self._sudo_timestamp = None
|
||||
self._sudo_timeout = 300 # 5 minutes, like default sudo
|
||||
|
||||
def connect(self):
|
||||
"""Establish SSH connection."""
|
||||
if self._is_connected:
|
||||
"""Establish SSH connection if not already connected."""
|
||||
if self.client:
|
||||
return
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
except ImportError as e:
|
||||
@ -68,7 +44,7 @@ class SSHProcess:
|
||||
"paramiko is required for SSH functionality. "
|
||||
"Install it with `pip install paramiko`"
|
||||
) from e
|
||||
|
||||
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
@ -84,151 +60,142 @@ class SSHProcess:
|
||||
connect_kwargs["key_filename"] = self.key_filename
|
||||
|
||||
self.client.connect(**connect_kwargs)
|
||||
self._is_connected = True
|
||||
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
|
||||
|
||||
def run(self, commands: Union[str, List[str]], use_sudo: bool = False,
|
||||
sudo_password: Optional[str] = None) -> str:
|
||||
"""Run commands over SSH and return output."""
|
||||
if not self._is_connected:
|
||||
self.connect()
|
||||
|
||||
def _needs_sudo_password(self) -> bool:
|
||||
"""Check if we need to ask for sudo password."""
|
||||
if not self._sudo_password:
|
||||
return True
|
||||
if not self._sudo_timestamp:
|
||||
return True
|
||||
# Check if sudo timeout has expired
|
||||
if time.time() - self._sudo_timestamp > self._sudo_timeout:
|
||||
self._sudo_password = None
|
||||
self._sudo_timestamp = None
|
||||
return True
|
||||
return False
|
||||
|
||||
def execute(self, command: str) -> str:
|
||||
"""Execute a single command, handling sudo automatically."""
|
||||
if not self.client:
|
||||
self.connect()
|
||||
|
||||
# Check if command needs sudo
|
||||
needs_sudo = command.strip().startswith('sudo ')
|
||||
|
||||
if needs_sudo:
|
||||
# Remove 'sudo ' prefix if present
|
||||
actual_command = command.strip()[5:].strip()
|
||||
|
||||
# Check if we need to get sudo password
|
||||
if self._needs_sudo_password():
|
||||
self._sudo_password = getpass.getpass(f"[sudo] password for {self.username}: ")
|
||||
self._sudo_timestamp = time.time()
|
||||
|
||||
# Execute with sudo -S and pass password via stdin
|
||||
full_command = f"sudo -S {actual_command}"
|
||||
stdin, stdout, stderr = self.client.exec_command(full_command, get_pty=True)
|
||||
|
||||
# Send password
|
||||
stdin.write(f"{self._sudo_password}\n")
|
||||
stdin.flush()
|
||||
|
||||
# Get output
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Clean up sudo prompt from output
|
||||
lines = output.split('\n')
|
||||
cleaned_lines = [line for line in lines if not line.strip().startswith('[sudo]')]
|
||||
output = '\n'.join(cleaned_lines)
|
||||
|
||||
else:
|
||||
# Regular command without sudo
|
||||
stdin, stdout, stderr = self.client.exec_command(command)
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Combine output and error
|
||||
result = output
|
||||
if error:
|
||||
result += f"\n{error}" if result else error
|
||||
|
||||
return result.strip()
|
||||
|
||||
def run_commands(self, commands: Union[str, List[str]]) -> str:
|
||||
"""Run one or more commands and return combined output."""
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
|
||||
|
||||
outputs = []
|
||||
for command in commands:
|
||||
for cmd in commands:
|
||||
try:
|
||||
# Prepare command with sudo if needed
|
||||
if use_sudo:
|
||||
if sudo_password:
|
||||
# Use echo to pipe password to sudo -S (read from stdin)
|
||||
full_command = f"echo '{sudo_password}' | sudo -S {command}"
|
||||
else:
|
||||
# Try sudo without password (for passwordless sudo)
|
||||
full_command = f"sudo {command}"
|
||||
else:
|
||||
full_command = command
|
||||
|
||||
stdin, stdout, stderr = self.client.exec_command(full_command)
|
||||
|
||||
# For sudo commands with password, we need to handle stdin
|
||||
if use_sudo and sudo_password:
|
||||
stdin.write(f"{sudo_password}\n")
|
||||
stdin.flush()
|
||||
|
||||
output = stdout.read().decode()
|
||||
error = stderr.read().decode()
|
||||
|
||||
# Filter out sudo password prompt from error output
|
||||
if use_sudo and error:
|
||||
error_lines = error.split('\n')
|
||||
filtered_error = '\n'.join(
|
||||
line for line in error_lines
|
||||
if not line.startswith('[sudo]') and line.strip()
|
||||
)
|
||||
error = filtered_error
|
||||
|
||||
if error:
|
||||
outputs.append(f"$ {command}\n{output}{error}")
|
||||
else:
|
||||
outputs.append(f"$ {command}\n{output}")
|
||||
output = self.execute(cmd)
|
||||
outputs.append(f"$ {cmd}\n{output}")
|
||||
except Exception as e:
|
||||
outputs.append(f"$ {command}\nError: {str(e)}")
|
||||
|
||||
outputs.append(f"$ {cmd}\nError: {str(e)}")
|
||||
|
||||
return "\n\n".join(outputs)
|
||||
|
||||
def __del__(self):
|
||||
"""Close SSH connection when object is destroyed."""
|
||||
if self.client and self._is_connected:
|
||||
|
||||
def close(self):
|
||||
"""Close the SSH connection."""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
logger.info(f"SSH connection closed to {self.username}@{self.host}")
|
||||
|
||||
|
||||
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
|
||||
"""Get default SSH process with persistent connection."""
|
||||
return SSHProcess(host=host, username=username, **kwargs)
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure connection is closed when object is destroyed."""
|
||||
self.close()
|
||||
|
||||
|
||||
class SSHTool(BaseTool):
|
||||
"""Tool to run commands on remote servers via SSH with persistent connection."""
|
||||
|
||||
process: Optional[SSHProcess] = Field(default=None)
|
||||
"""SSH process with persistent connection."""
|
||||
|
||||
# Connection parameters
|
||||
host: str = Field(..., description="SSH host address")
|
||||
username: str = Field(..., description="SSH username")
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
|
||||
"""Simple SSH tool that behaves like a normal terminal session."""
|
||||
|
||||
# Tool configuration
|
||||
name: str = "ssh"
|
||||
description: str = """
|
||||
Run shell commands on a remote server via SSH.
|
||||
description: str = """Execute commands on a remote server via SSH.
|
||||
|
||||
This tool maintains a persistent SSH connection and allows executing
|
||||
commands on the remote server. It supports both regular and privileged
|
||||
(sudo) command execution.
|
||||
|
||||
Use the 'use_sudo' parameter to run commands with sudo privileges.
|
||||
If sudo requires a password, provide it via 'sudo_password'.
|
||||
Simply pass the commands you want to run. Use 'sudo' prefix for privileged commands.
|
||||
The tool will automatically handle password prompts and maintain sudo session.
|
||||
|
||||
Examples:
|
||||
- Regular command: {"commands": "ls -la"}
|
||||
- Sudo command: {"commands": "apt update", "use_sudo": true}
|
||||
- Multiple commands: {"commands": ["df -h", "free -m", "top -n 1"]}
|
||||
- {"commands": "ls -la"}
|
||||
- {"commands": "sudo apt update"}
|
||||
- {"commands": ["df -h", "sudo systemctl status nginx", "free -m"]}
|
||||
"""
|
||||
args_schema: Type[BaseModel] = SSHInput
|
||||
|
||||
# Connection parameters
|
||||
host: str = Field(..., description="SSH host")
|
||||
username: str = Field(..., description="SSH username")
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
key_filename: Optional[str] = Field(default=None, description="SSH key path")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
|
||||
# Session management
|
||||
session: Optional[SSHSession] = Field(default=None, exclude=True)
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize SSH tool and set description."""
|
||||
"""Initialize SSH tool."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}"
|
||||
# Initialize the SSH process (but don't connect yet)
|
||||
self.process = SSHProcess(
|
||||
# Create session but don't connect yet
|
||||
self.session = SSHSession(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
key_filename=self.key_filename
|
||||
key_filename=self.key_filename,
|
||||
password=self.password
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
commands: Union[str, List[str]],
|
||||
use_sudo: bool = False,
|
||||
sudo_password: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Run commands on remote server and return output."""
|
||||
|
||||
def _run(self, commands: Union[str, List[str]], **kwargs) -> str:
|
||||
"""Execute commands on remote server."""
|
||||
try:
|
||||
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201
|
||||
print(f"Commands: {commands}") # noqa: T201
|
||||
if use_sudo:
|
||||
print("Running with sudo privileges") # noqa: T201
|
||||
|
||||
# Safety check for privileged commands
|
||||
if use_sudo:
|
||||
user_input = input("Proceed with sudo command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands, use_sudo=use_sudo, sudo_password=sudo_password)
|
||||
else:
|
||||
logger.info("User aborted sudo command execution.")
|
||||
return "Command execution aborted by user."
|
||||
else:
|
||||
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands)
|
||||
else:
|
||||
logger.info("Invalid input. User aborted SSH command execution.")
|
||||
return "Command execution aborted by user."
|
||||
print(f"Executing on {self.username}@{self.host}:{self.port}")
|
||||
return self.session.run_commands(commands)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during SSH command execution: {e}")
|
||||
logger.error(f"SSH execution error: {e}")
|
||||
return f"Error: {str(e)}"
|
Loading…
x
Reference in New Issue
Block a user