#!/usr/bin/env python3
"""
Test script to verify SSH connection sharing works properly.

This script simulates multiple agents trying to use SSH simultaneously.
"""

import os
import sys
import threading
import time

from custom_tools import configured_remote_server, ssh_manager
from custom_tools.ssh_tool import SSHTool


def _shared_ssh_tool():
    """Build an SSHTool that opts into the shared connection.

    Connection details default to the values this script originally
    hard-coded. Each can be overridden through an environment variable
    (SSH_TEST_HOST, SSH_TEST_PORT, SSH_TEST_USERNAME, SSH_TEST_KEY) so the
    script is not tied to a single machine or user.
    """
    return SSHTool(
        host=os.environ.get("SSH_TEST_HOST", "157.90.211.119"),
        port=int(os.environ.get("SSH_TEST_PORT", "8081")),
        username=os.environ.get("SSH_TEST_USERNAME", "g"),
        key_filename=os.environ.get(
            "SSH_TEST_KEY", "/Users/ghsioux/.ssh/id_rsa_hetzner"
        ),
        use_shared_connection=True,
    )


def _timed_run(tool, command, name):
    """Run *command* through *tool* and print how long it took.

    Returns whatever ``tool._run`` returns. If the call raises, returns an
    ``"Error: ..."`` string instead, so one failing agent does not abort the
    others.
    """
    # perf_counter is monotonic, unlike time.time(), so intervals cannot
    # go negative or jump if the wall clock is adjusted mid-run.
    start = time.perf_counter()
    try:
        result = tool._run(command)
    except Exception as e:
        print(f" {name}: Failed in {time.perf_counter() - start:.2f}s - {e}")
        return f"Error: {e}"
    print(f" {name}: Completed in {time.perf_counter() - start:.2f}s")
    return result


def _run_agents_concurrently(commands):
    """Fire each (command, agent_name) pair on its own thread.

    Every thread uses the module-level ``configured_remote_server`` tool.
    Returns a dict mapping agent name to that agent's result.
    """
    results = {}

    def worker(command, name):
        # Each thread writes a distinct key, so no extra locking is needed.
        results[name] = _timed_run(configured_remote_server, command, name)

    threads = [
        threading.Thread(target=worker, args=(command, name), name=name)
        for command, name in commands
    ]

    # NOTE(review): the expectation is that these run one at a time because
    # of a lock inside the shared SSH tool/manager. That lock is not visible
    # here, and this script only reports timings; it does not assert ordering.
    print(" Starting multiple SSH operations...")
    start = time.perf_counter()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    print(f" Total execution time: {time.perf_counter() - start:.2f}s")
    return results


def _cleanup_connections():
    """Close every connection held by ``ssh_manager`` and report the counts."""
    print("\n3. Testing connection cleanup...")
    print(" Current connections:", len(ssh_manager._connections))
    ssh_manager.close_all()
    print(" Connections after cleanup:", len(ssh_manager._connections))


def test_ssh_connection_sharing():
    """Test that SSH connection sharing prevents multiple connections.

    1. Two shared-connection SSHTool instances should expose the same
       ``session`` object.
    2. Several commands are launched concurrently against the shared remote
       server tool, and their timings are printed.
    3. All managed connections are closed. This step runs even if steps 1
       or 2 raise, so connections are never leaked.

    Returns:
        dict mapping agent name to the result (or "Error: ..." string) of
        that agent's command.
    """
    print("🧪 Testing SSH Connection Sharing...")
    print("=" * 50)

    try:
        # Test 1: verify the shared connection is being used.
        print("\n1. Testing shared connection mechanism...")
        ssh_tool1 = _shared_ssh_tool()
        ssh_tool2 = _shared_ssh_tool()
        if ssh_tool1.session is ssh_tool2.session:
            print("✅ SSH tools are sharing the same session instance")
        else:
            print("❌ SSH tools are NOT sharing the same session instance")

        # Test 2: several "agents" hit the remote server at once.
        print("\n2. Testing sequential execution...")
        commands = [
            ("whoami", "Agent 1"),
            ("date", "Agent 2"),
            ("pwd", "Agent 3"),
        ]
        results = _run_agents_concurrently(commands)
    finally:
        # Test 3: always release connections, even when a step above failed.
        _cleanup_connections()

    print("\n" + "=" * 50)
    print("🎉 SSH Connection Sharing Test Complete!")
    return results


def _preview(result, width=50):
    """Return at most *width* characters of ``str(result)``, plus "..." if cut."""
    text = str(result)
    return text[:width] + ("..." if len(text) > width else "")


if __name__ == "__main__":
    try:
        results = test_ssh_connection_sharing()
    except Exception as e:
        print(f"Test failed: {e}")
        # Signal failure to the shell/CI instead of exiting 0.
        sys.exit(1)

    print("\nTest Results:")
    for agent, result in results.items():
        print(f" {agent}: {_preview(result)}")