add custom ssh tool
This commit is contained in:
@@ -1 +1,16 @@
|
||||
"""Custom tools package for the LangGraph demo agent."""
|
||||
|
||||
from .poem_tool import print_poem
|
||||
from .ssh_tool import SSHTool
|
||||
from langchain_community.tools.shell.tool import ShellTool
|
||||
|
||||
# Pre-configured SSH tool for your server - only connects when actually used
|
||||
# TODO: Update these connection details for your actual server
|
||||
configured_ssh_tool = SSHTool(
|
||||
host="your-server.example.com", # Replace with your server
|
||||
username="admin", # Replace with your username
|
||||
key_filename="~/.ssh/id_rsa", # Replace with your key path
|
||||
ask_human_input=True # Safety confirmation
|
||||
)
|
||||
|
||||
__all__ = ["print_poem", "SSHTool", "ShellTool", "configured_ssh_tool"]
|
||||
|
0
simple-react-agent/custom_tools/ssh_config.py
Normal file
0
simple-react-agent/custom_tools/ssh_config.py
Normal file
185
simple-react-agent/custom_tools/ssh_tool.py
Normal file
185
simple-react-agent/custom_tools/ssh_tool.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHInput(BaseModel):
|
||||
"""Commands for the SSH tool."""
|
||||
|
||||
commands: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="List of commands to run on the remote server",
|
||||
)
|
||||
"""List of commands to run."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_commands(cls, values: dict) -> Any:
|
||||
"""Validate commands."""
|
||||
commands = values.get("commands")
|
||||
if not isinstance(commands, list):
|
||||
values["commands"] = [commands]
|
||||
# Warn that the SSH tool has no safeguards
|
||||
warnings.warn(
|
||||
"The SSH tool has no safeguards by default. Use at your own risk."
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class SSHProcess:
|
||||
"""Persistent SSH connection for command execution."""
|
||||
|
||||
def __init__(self, host: str, username: str, port: int = 22,
|
||||
password: Optional[str] = None, key_filename: Optional[str] = None,
|
||||
timeout: float = 30.0, return_err_output: bool = True):
|
||||
"""Initialize SSH process with connection parameters."""
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.port = port
|
||||
self.password = password
|
||||
self.key_filename = key_filename
|
||||
self.timeout = timeout
|
||||
self.return_err_output = return_err_output
|
||||
self.client = None
|
||||
# Don't connect immediately - connect when needed
|
||||
|
||||
def _connect(self):
|
||||
"""Establish SSH connection."""
|
||||
try:
|
||||
import paramiko
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"paramiko is required for SSH functionality. "
|
||||
"Install it with `pip install paramiko`"
|
||||
)
|
||||
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
connect_kwargs = {
|
||||
"hostname": self.host,
|
||||
"username": self.username,
|
||||
"port": self.port,
|
||||
"timeout": self.timeout,
|
||||
}
|
||||
|
||||
if self.password:
|
||||
connect_kwargs["password"] = self.password
|
||||
if self.key_filename:
|
||||
connect_kwargs["key_filename"] = self.key_filename
|
||||
|
||||
self.client.connect(**connect_kwargs)
|
||||
logger.info(f"SSH connection established to {self.username}@{self.host}:{self.port}")
|
||||
|
||||
def run(self, commands: Union[str, List[str]]) -> str:
|
||||
"""Run commands over SSH and return output."""
|
||||
if not self.client:
|
||||
self._connect()
|
||||
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
|
||||
outputs = []
|
||||
for command in commands:
|
||||
try:
|
||||
stdin, stdout, stderr = self.client.exec_command(command)
|
||||
|
||||
output = stdout.read().decode('utf-8')
|
||||
error = stderr.read().decode('utf-8')
|
||||
|
||||
if error and self.return_err_output:
|
||||
outputs.append(f"$ {command}\n{output}{error}")
|
||||
else:
|
||||
outputs.append(f"$ {command}\n{output}")
|
||||
except Exception as e:
|
||||
outputs.append(f"$ {command}\nError: {str(e)}")
|
||||
|
||||
return "\n".join(outputs)
|
||||
|
||||
def __del__(self):
|
||||
"""Close SSH connection when object is destroyed."""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
logger.info(f"SSH connection closed to {self.username}@{self.host}")
|
||||
|
||||
|
||||
def _get_default_ssh_process(host: str, username: str, **kwargs) -> Any:
|
||||
"""Get default SSH process with persistent connection."""
|
||||
return SSHProcess(host=host, username=username, **kwargs)
|
||||
|
||||
|
||||
class SSHTool(BaseTool):
|
||||
"""Tool to run commands on remote servers via SSH with persistent connection."""
|
||||
|
||||
process: Optional[SSHProcess] = Field(default=None)
|
||||
"""SSH process with persistent connection."""
|
||||
|
||||
# Connection parameters
|
||||
host: str = Field(..., description="SSH host address")
|
||||
username: str = Field(..., description="SSH username")
|
||||
port: int = Field(default=22, description="SSH port")
|
||||
password: Optional[str] = Field(default=None, description="SSH password")
|
||||
key_filename: Optional[str] = Field(default=None, description="Path to SSH key")
|
||||
timeout: float = Field(default=30.0, description="Connection timeout")
|
||||
|
||||
name: str = "ssh"
|
||||
"""Name of tool."""
|
||||
|
||||
description: str = Field(default="")
|
||||
"""Description of tool."""
|
||||
|
||||
args_schema: Type[BaseModel] = SSHInput
|
||||
"""Schema for input arguments."""
|
||||
|
||||
ask_human_input: bool = False
|
||||
"""
|
||||
If True, prompts the user for confirmation (y/n) before executing
|
||||
commands on the remote server.
|
||||
"""
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize SSH tool and set description."""
|
||||
super().__init__(**data)
|
||||
# Set description after initialization
|
||||
self.description = f"Run commands on remote server {self.username}@{self.host}:{self.port}"
|
||||
# Initialize the SSH process (but don't connect yet)
|
||||
self.process = SSHProcess(
|
||||
host=self.host,
|
||||
username=self.username,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
key_filename=self.key_filename,
|
||||
timeout=self.timeout,
|
||||
return_err_output=True
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
commands: Union[str, List[str]],
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Run commands on remote server and return output."""
|
||||
|
||||
print(f"Executing SSH command on {self.username}@{self.host}:{self.port}") # noqa: T201
|
||||
print(f"Commands: {commands}") # noqa: T201
|
||||
|
||||
try:
|
||||
if self.ask_human_input:
|
||||
user_input = input("Proceed with SSH command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands)
|
||||
else:
|
||||
logger.info("Invalid input. User aborted SSH command execution.")
|
||||
return None # type: ignore[return-value]
|
||||
else:
|
||||
return self.process.run(commands)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during SSH command execution: {e}")
|
||||
return None # type: ignore[return-value]
|
@@ -4,10 +4,11 @@ from langchain.chat_models import init_chat_model
|
||||
from langchain_community.tools.shell.tool import ShellTool
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_core.messages import HumanMessage
|
||||
from custom_tools.poem_tool import print_poem
|
||||
from custom_tools import print_poem, configured_ssh_tool
|
||||
|
||||
# Suppress the shell tool warning since we're using it intentionally for sysadmin tasks
|
||||
warnings.filterwarnings("ignore", message="The shell tool has no safeguards by default. Use at your own risk.")
|
||||
warnings.filterwarnings("ignore", message="The SSH tool has no safeguards by default. Use at your own risk.")
|
||||
|
||||
|
||||
def create_agent():
|
||||
@@ -19,21 +20,27 @@ def create_agent():
|
||||
|
||||
# Define the tools available to the agent
|
||||
shell_tool = ShellTool()
|
||||
tools = [shell_tool, print_poem]
|
||||
tools = [shell_tool, configured_ssh_tool, print_poem]
|
||||
|
||||
|
||||
# Create a ReAct agent with system administration debugging focus
|
||||
system_prompt = """You are an expert system administrator debugging agent with deep knowledge of Linux, macOS, BSD, and Windows systems.
|
||||
|
||||
## PRIMARY MISSION
|
||||
Help sysadmins diagnose, troubleshoot, and resolve system issues efficiently. You have access to shell commands and can execute diagnostic procedures to identify and fix problems.
|
||||
Help sysadmins diagnose, troubleshoot, and resolve system issues efficiently. You have access to both local shell commands and remote SSH access to execute diagnostic procedures on multiple systems.
|
||||
|
||||
## CORE CAPABILITIES
|
||||
1. **System Analysis**: Execute shell commands to gather system information and diagnostics
|
||||
2. **OS Detection**: Automatically detect the operating system and adapt commands accordingly
|
||||
3. **Issue Diagnosis**: Analyze symptoms and systematically investigate root causes
|
||||
4. **Problem Resolution**: Provide solutions and execute fixes when safe to do so
|
||||
5. **Easter Egg**: Generate poems when users need a morale boost (use print_poem tool)
|
||||
1. **Local System Analysis**: Execute shell commands on the local machine (terminal tool)
|
||||
2. **Remote System Analysis**: Execute commands on remote servers via SSH (configured_ssh_tool)
|
||||
3. **OS Detection**: Automatically detect the operating system and adapt commands accordingly
|
||||
4. **Issue Diagnosis**: Analyze symptoms and systematically investigate root causes
|
||||
5. **Problem Resolution**: Provide solutions and execute fixes when safe to do so
|
||||
6. **Easter Egg**: Generate poems when users need a morale boost (use print_poem tool)
|
||||
|
||||
## AVAILABLE TOOLS
|
||||
- **terminal**: Execute commands on the local machine
|
||||
- **configured_ssh_tool**: Execute commands on the pre-configured remote server
|
||||
- **print_poem**: Generate motivational poems for debugging sessions
|
||||
|
||||
## OPERATING SYSTEM AWARENESS
|
||||
- **First interaction**: Always detect the OS using appropriate commands (uname, systeminfo, etc.)
|
||||
|
Reference in New Issue
Block a user