110 lines
3.3 KiB
Python
110 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify SSH connection sharing works properly.
|
|
This script simulates multiple agents trying to use SSH simultaneously.
|
|
"""
|
|
|
|
import time
|
|
import threading
|
|
from custom_tools import configured_remote_server, ssh_manager
|
|
from custom_tools.ssh_tool import SSHTool
|
|
|
|
def _run_command(tool, command, name):
    """Run ``command`` through ``tool`` and print how long it took.

    Args:
        tool: Object exposing a ``_run(command)`` method.
        command: Command string passed to ``tool._run``.
        name: Label shown in the progress output (the simulated agent).

    Returns:
        Whatever ``tool._run`` returned, or ``"Error: <exception>"`` if it
        raised, so one failing agent does not abort the others.
    """
    # perf_counter is monotonic: the measured duration is unaffected by
    # wall-clock adjustments, unlike time.time().
    start_time = time.perf_counter()
    try:
        result = tool._run(command)
    except Exception as e:  # per-agent boundary: report the failure, keep going
        elapsed = time.perf_counter() - start_time
        print(f" {name}: Failed in {elapsed:.2f}s - {e}")
        return f"Error: {e}"
    elapsed = time.perf_counter() - start_time
    print(f" {name}: Completed in {elapsed:.2f}s")
    return result


def test_ssh_connection_sharing(
    *,
    host="157.90.211.119",
    port=8081,
    username="g",
    key_filename="/Users/ghsioux/.ssh/id_rsa_hetzner",
):
    """Test that SSH connection sharing prevents multiple connections.

    Runs three checks and prints a report for each:

    1. Two ``SSHTool`` instances built with ``use_shared_connection=True``
       are compared to see whether they hold the same ``session`` object.
    2. ``whoami``, ``date`` and ``pwd`` are issued concurrently from three
       threads through ``configured_remote_server``. Per-command and total
       timings are printed.
    3. The number of entries in ``ssh_manager._connections`` is printed
       before and after ``ssh_manager.close_all()``. This step runs in a
       ``finally`` so connections are released even if step 1 or 2 raises.

    The connection parameters only affect step 1; step 2 always goes through
    ``configured_remote_server``. The defaults are machine-specific (a fixed
    server and a local key path), so override them to run elsewhere.

    Args:
        host: SSH server address for the step-1 tools.
        port: SSH server port.
        username: Login name.
        key_filename: Path to the private key file.

    Returns:
        dict mapping each simulated agent name to its command output, or to
        an ``"Error: ..."`` string if that agent's command raised.
    """
    print("🧪 Testing SSH Connection Sharing...")
    print("=" * 50)

    results = {}
    try:
        # Test 1: Verify shared connection is being used
        print("\n1. Testing shared connection mechanism...")
        tool_kwargs = {
            "host": host,
            "port": port,
            "username": username,
            "key_filename": key_filename,
            "use_shared_connection": True,
        }
        ssh_tool1 = SSHTool(**tool_kwargs)
        ssh_tool2 = SSHTool(**tool_kwargs)

        if ssh_tool1.session is ssh_tool2.session:
            print("✅ SSH tools are sharing the same session instance")
        else:
            print("❌ SSH tools are NOT sharing the same session instance")

        # Test 2: Fire several commands at once from separate threads
        print("\n2. Testing sequential execution...")
        commands = [
            ("whoami", "Agent 1"),
            ("date", "Agent 2"),
            ("pwd", "Agent 3"),
        ]

        def agent(command, agent_name):
            # Each thread writes only its own key in the shared dict.
            results[agent_name] = _run_command(
                configured_remote_server, command, agent_name
            )

        threads = [
            threading.Thread(target=agent, args=(cmd, agent_name), name=agent_name)
            for cmd, agent_name in commands
        ]

        # All threads start together; the shared connection's lock is
        # expected to make the commands execute one at a time.
        print(" Starting multiple SSH operations...")
        start_time = time.perf_counter()
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        total_time = time.perf_counter() - start_time
        print(f" Total execution time: {total_time:.2f}s")
    finally:
        # Test 3: Verify connection cleanup. Runs even when an earlier step
        # raised, so a failed run does not leave SSH connections open.
        print("\n3. Testing connection cleanup...")
        print(" Current connections:", len(ssh_manager._connections))
        ssh_manager.close_all()
        print(" Connections after cleanup:", len(ssh_manager._connections))

    print("\n" + "=" * 50)
    print("🎉 SSH Connection Sharing Test Complete!")
    return results
|
|
|
|
if __name__ == "__main__":
    import sys

    try:
        results = test_ssh_connection_sharing()
    except Exception as e:
        # Report the failure and exit non-zero so shells/CI can detect it.
        print(f"Test failed: {e}")
        sys.exit(1)

    print("\nTest Results:")
    for agent, result in results.items():
        # str() so a non-string tool result (e.g. None) is still printable
        # instead of raising on slicing/len().
        text = str(result)
        print(f" {agent}: {text[:50]}{'...' if len(text) > 50 else ''}")
|