add custom ssh tool

This commit is contained in:
Gaetan Hurel
2025-06-29 14:49:07 +02:00
parent bf2cc1a409
commit 38b77a657f
17 changed files with 1292 additions and 19 deletions

View File

@@ -1 +1,16 @@
"""Custom tools package for the LangGraph demo agent."""
from .poem_tool import print_poem
from .ssh_tool import SSHTool
from langchain_community.tools.shell.tool import ShellTool
# Pre-configured SSH tool for your server - only connects when actually used
# TODO: Update these connection details for your actual server
configured_ssh_tool = SSHTool(
host="your-server.example.com", # Replace with your server
username="admin", # Replace with your username
key_filename="~/.ssh/id_rsa", # Replace with your key path
ask_human_input=True # Safety confirmation
)
__all__ = ["print_poem", "SSHTool", "ShellTool", "configured_ssh_tool"]

View File

@@ -0,0 +1,185 @@
import logging
import warnings
from typing import Any, List, Optional, Type, Union
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator
logger = logging.getLogger(__name__)
class SSHInput(BaseModel):
"""Commands for the SSH tool."""
commands: Union[str, List[str]] = Field(
...,
description="List of commands to run on the remote server",
)
"""List of commands to run."""
@model_validator(mode="before")
@classmethod
def _validate_commands(cls, values: dict) -> Any:
"""Validate commands."""
commands = values.get("commands")
if not isinstance(commands, list):
values["commands"] = [commands]
# Warn that the SSH tool has no safeguards
warnings.warn(
"The SSH tool has no safeguards by default. Use at your own risk."
)
return values
class SSHProcess:
"""Persistent SSH connection for command execution."""
def __init__(self, host: str, username: str, port: int = 22,
password: Optional[str] = None, key_filename: Optional[str] = None,
timeout: float = 30.0, return_err_output: bool = True):
"""Initialize SSH process with connection parameters."""
self.host = host
self.username = username
self.port = port
self.password = password
self.key_filename = key_filename
self.timeout = timeout
self.return_err_output = return_err_output
self.client = None
# Don't connect immediately - connect when needed
def _connect(self):
"""Establish SSH connection."""
try:
import paramiko
except ImportError:
raise ImportError(
"paramiko is required for SSH functionality. "
"Install it with `pip install paramiko`"
)
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
connect_kwargs = {
"hostname": self.host,
"username": self.username,
"port": self.port,
"timeout": self.timeout,
}
if self.password:
connect_kwargs["password"] = self.password
if self.key_filename:
connect_kwargs["key_filename"] = self.key_filename
self.client.connect(**connect_kwargs)
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
def run(self, commands: Union[str, List[str]]) -> str:
"""Run commands over SSH and return output."""
if not self.client:
self._connect()
if isinstance(commands, str):
commands = [commands]
outputs = []
for command in commands:
try:
stdin, stdout, stderr = self.client.exec_command(command)
output = stdout.read().decode('utf-8')
error = stderr.read().decode('utf-8')
if error and self.return_err_output:
outputs.append(f"$ {command}\n{output}{error}")
else:
outputs.append(f"$ {command}\n{output}")
except Exception as e:
outputs.append(f"$ {command}\nError: {str(e)}")
return "\n".join(outputs)
def __del__(self):
"""Close SSH connection when object is destroyed."""
if self.client:
self.client.close()
logger.info(f"SSH connection closed to {self.username}@{self.host}")
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
"""Get default SSH process with persistent connection."""
return SSHProcess(host=host, username=username, **kwargs)
class SSHTool(BaseTool):
"""Tool to run commands on remote servers via SSH with persistent connection."""
process: Optional[SSHProcess] = Field(default=None)
"""SSH process with persistent connection."""
# Connection parameters
host: str = Field(..., description="SSH host address")
username: str = Field(..., description="SSH username")
port: int = Field(default=22, description="SSH port")
password: Optional[str] = Field(default=None, description="SSH password")
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
timeout: float = Field(default=30.0, description="Connection timeout")
name: str = "ssh"
"""Name of tool."""
description: str = Field(default="")
"""Description of tool."""
args_schema: Type[BaseModel] = SSHInput
"""Schema for input arguments."""
ask_human_input: bool = False
"""
If True, prompts the user for confirmation (y/n) before executing
commands on the remote server.
"""
def __init__(self, **data):
"""Initialize SSH tool and set description."""
super().__init__(**data)
# Set description after initialization
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}"
# Initialize the SSH process (but don't connect yet)
self.process = SSHProcess(
host=self.host,
username=self.username,
port=self.port,
password=self.password,
key_filename=self.key_filename,
timeout=self.timeout,
return_err_output=True
)
def _run(
self,
commands: Union[str, List[str]],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Run commands on remote server and return output."""
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201
print(f"Commands: {commands}") # noqa: T201
try:
if self.ask_human_input:
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
if user_input == "y":
return self.process.run(commands)
else:
logger.info("Invalid input. User aborted SSH command execution.")
return None # type: ignore[return-value]
else:
return self.process.run(commands)
except Exception as e:
logger.error(f"Error during SSH command execution: {e}")
return None # type: ignore[return-value]

View File

@@ -4,10 +4,11 @@ from langchain.chat_models import init_chat_model
from langchain_community.tools.shell.tool import ShellTool
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
from custom_tools.poem_tool import print_poem
from custom_tools import print_poem, configured_ssh_tool
# Suppress the shell tool warning since we're using it intentionally for sysadmin tasks
warnings.filterwarnings("ignore", message="The shell tool has no safeguards by default. Use at your own risk.")
warnings.filterwarnings("ignore", message="The SSH tool has no safeguards by default. Use at your own risk.")
def create_agent():
@@ -19,21 +20,27 @@ def create_agent():
# Define the tools available to the agent
shell_tool = ShellTool()
tools = [shell_tool, print_poem]
tools = [shell_tool, configured_ssh_tool, print_poem]
# Create a ReAct agent with system administration debugging focus
system_prompt = """You are an expert system administrator debugging agent with deep knowledge of Linux, macOS, BSD, and Windows systems.
## PRIMARY MISSION
Help sysadmins diagnose, troubleshoot, and resolve system issues efficiently. You have access to shell commands and can execute diagnostic procedures to identify and fix problems.
Help sysadmins diagnose, troubleshoot, and resolve system issues efficiently. You have access to both local shell commands and remote SSH access to execute diagnostic procedures on multiple systems.
## CORE CAPABILITIES
1. **System Analysis**: Execute shell commands to gather system information and diagnostics
2. **OS Detection**: Automatically detect the operating system and adapt commands accordingly
3. **Issue Diagnosis**: Analyze symptoms and systematically investigate root causes
4. **Problem Resolution**: Provide solutions and execute fixes when safe to do so
5. **Easter Egg**: Generate poems when users need a morale boost (use print_poem tool)
1. **Local System Analysis**: Execute shell commands on the local machine (terminal tool)
2. **Remote System Analysis**: Execute commands on remote servers via SSH (configured_ssh_tool)
3. **OS Detection**: Automatically detect the operating system and adapt commands accordingly
4. **Issue Diagnosis**: Analyze symptoms and systematically investigate root causes
5. **Problem Resolution**: Provide solutions and execute fixes when safe to do so
6. **Easter Egg**: Generate poems when users need a morale boost (use print_poem tool)
## AVAILABLE TOOLS
- **terminal**: Execute commands on the local machine
- **configured_ssh_tool**: Execute commands on the pre-configured remote server
- **print_poem**: Generate motivational poems for debugging sessions
## OPERATING SYSTEM AWARENESS
- **First interaction**: Always detect the OS using appropriate commands (uname, systeminfo, etc.)