"""
Session pool for managing SSH and SFTP sessions to avoid reconnecting when switching between servers.
"""
import threading
import time
from collections import OrderedDict
from typing import Dict, Optional, Tuple

from core.ssh_client import ShellSession, SFTPSession

# Server-config fields that identify the connection endpoint and credentials.
# If any of them changes, the cached sessions for that alias are torn down.
_CRITICAL_KEYS = (
    'ip', 'port', 'username', 'password', 'type',
    'access_key', 'secret_key', 'use_ssl',
)


def _server_changed(old: dict, new: dict) -> bool:
    """Return True if any critical connection field differs between *old* and *new*.

    Keys missing from either dict compare as ``None``.
    """
    return any(old.get(k) != new.get(k) for k in _CRITICAL_KEYS)


class SessionData:
    """Container for one server's cached sessions and the UI state saved for it."""

    def __init__(self, alias: str, server: dict, key_path: str):
        self.alias = alias
        self.server = server
        self.key_path = key_path
        self.shell_session: Optional[ShellSession] = None
        self.sftp_session: Optional[SFTPSession] = None
        # time.time() of the most recent pool access for this alias
        # (kept current by SessionPool._update_last_access).
        self.last_access_time = time.time()
        # State preserved while the user is switched to another server.
        self.terminal_buffer: bytes = b""
        self.remote_path: str = "/"
        self.sudo_mode: bool = False

    def cleanup(self):
        """Disconnect and drop both sessions.

        Both references are cleared up front, and the SFTP session is still
        disconnected even if disconnecting the shell session raises; that
        exception then propagates to the caller.
        """
        shell, self.shell_session = self.shell_session, None
        sftp, self.sftp_session = self.sftp_session, None
        try:
            if shell:
                shell.disconnect()
        finally:
            if sftp:
                sftp.disconnect()


class SessionPool:
    """
    Manages a pool of SSH/SFTP sessions to keep connections alive when
    switching between servers.

    Features:
    - Caches sessions per server alias
    - Keeps the sessions of servers that are not currently shown open instead
      of reconnecting (any protocol-level keepalive is up to the session
      classes themselves)
    - Maintains per-server state (terminal buffer, remote path, sudo mode)
    - LRU eviction for the max sessions limit (the most recently used alias is
      never evicted, even if ``max_sessions`` < 1)
    - Thread-safe operations: all pool state is guarded by one re-entrant lock
    """

    def __init__(self, max_sessions: int = 5):
        self.max_sessions = max_sessions
        self._sessions: Dict[str, SessionData] = {}
        # Re-entrant so the activate_* methods can call get_or_create_* while
        # already holding it.
        self._lock = threading.RLock()
        # alias -> last access time, ordered least- to most-recently used.
        self._last_used_order: "OrderedDict[str, float]" = OrderedDict()

    def _get_session_data(self, alias: str, server: dict, key_path: str) -> SessionData:
        """Return the SessionData for *alias*, creating or invalidating it as needed.

        Must be called with ``self._lock`` held. If the cached entry was built
        from different connection settings (a critical server field or the key
        path), its sessions are disconnected so the caller rebuilds them; the
        saved UI state is kept. Also marks *alias* as most recently used.
        """
        session_data = self._sessions.get(alias)
        if session_data is None:
            session_data = SessionData(alias, server, key_path)
            self._sessions[alias] = session_data
        elif _server_changed(session_data.server, server) or session_data.key_path != key_path:
            session_data.cleanup()
            session_data.server = server
            session_data.key_path = key_path
        self._update_last_access(alias)
        return session_data

    def get_or_create_shell_session(self, alias: str, server: dict, key_path: str) -> Tuple[ShellSession, bool]:
        """
        Get existing shell session or create a new one.

        A new ShellSession is constructed (not connected here) when there is no
        cached session or the cached one reports ``connected`` as False. The
        saved terminal buffer is not replayed into it; callers fetch it via
        ``get_shell_state``.

        Args:
            alias: Server alias
            server: Server configuration dict
            key_path: Path to SSH key

        Returns:
            Tuple of (session, is_new_session)
        """
        with self._lock:
            session_data = self._get_session_data(alias, server, key_path)
            shell = session_data.shell_session
            if shell is not None and shell.connected:
                return shell, False
            shell = ShellSession(server, key_path)
            session_data.shell_session = shell
            return shell, True

    def get_or_create_sftp_session(self, alias: str, server: dict, key_path: str) -> Tuple[SFTPSession, bool]:
        """
        Get existing SFTP session or create a new one.

        A new SFTPSession is constructed (not connected here) when there is no
        cached session or the cached one reports ``connected`` as False; it
        inherits the stored ``sudo_mode``.

        Args:
            alias: Server alias
            server: Server configuration dict
            key_path: Path to SSH key

        Returns:
            Tuple of (session, is_new_session)
        """
        with self._lock:
            session_data = self._get_session_data(alias, server, key_path)
            sftp = session_data.sftp_session
            if sftp is not None and sftp.connected:
                return sftp, False
            sftp = SFTPSession(server, key_path)
            session_data.sftp_session = sftp
            sftp.sudo_mode = session_data.sudo_mode
            return sftp, True

    def activate_shell_session(self, alias: str, server: dict, key_path: str) -> ShellSession:
        """
        Activate a shell session for the given alias (or create if needed).

        Updates access time and ensures session is ready. The lock is held for
        the whole operation so the alias cannot be evicted mid-way.
        """
        with self._lock:
            session, _ = self.get_or_create_shell_session(alias, server, key_path)
        return session

    def activate_sftp_session(self, alias: str, server: dict, key_path: str) -> SFTPSession:
        """
        Activate an SFTP session for the given alias (or create if needed).

        Updates access time, re-applies the stored ``sudo_mode`` to the session
        (new or reused), and returns it. The lock is held throughout, so the
        entry created/touched by get_or_create_sftp_session is still present.
        """
        with self._lock:
            session, _ = self.get_or_create_sftp_session(alias, server, key_path)
            session.sudo_mode = self._sessions[alias].sudo_mode
        return session

    def store_shell_state(self, alias: str, terminal_buffer: bytes):
        """Store terminal state when switching away from a server (no-op for unknown aliases)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is not None:
                session_data.terminal_buffer = terminal_buffer
                self._update_last_access(alias)

    def store_sftp_state(self, alias: str, remote_path: str, sudo_mode: bool):
        """Store SFTP state when switching away from a server (no-op for unknown aliases)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is not None:
                session_data.remote_path = remote_path
                session_data.sudo_mode = sudo_mode
                self._update_last_access(alias)

    def get_shell_state(self, alias: str) -> bytes:
        """Retrieve terminal state when switching back to a server (``b""`` if unknown)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is None:
                return b""
            self._update_last_access(alias)
            return session_data.terminal_buffer

    def get_sftp_state(self, alias: str) -> Tuple[str, bool]:
        """Retrieve (remote_path, sudo_mode) when switching back to a server (``("/", False)`` if unknown)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is None:
                return "/", False
            self._update_last_access(alias)
            return session_data.remote_path, session_data.sudo_mode

    def _update_last_access(self, alias: str):
        """Mark *alias* as most recently used and evict LRU entries over the limit.

        Must be called with ``self._lock`` held. The limit is clamped to at
        least one so the alias just touched is never evicted.
        """
        now = time.time()
        self._last_used_order[alias] = now
        self._last_used_order.move_to_end(alias)
        session_data = self._sessions.get(alias)
        if session_data is not None:
            session_data.last_access_time = now

        limit = max(self.max_sessions, 1)
        while len(self._last_used_order) > limit:
            oldest_alias, _ = self._last_used_order.popitem(last=False)
            # Remove from the pool first so a failing disconnect cannot leave
            # a half-evicted entry behind.
            old_session = self._sessions.pop(oldest_alias, None)
            if old_session is not None:
                old_session.cleanup()

    def disconnect_session(self, alias: str):
        """Explicitly disconnect a session and forget all state for *alias*."""
        with self._lock:
            self._last_used_order.pop(alias, None)
            session_data = self._sessions.pop(alias, None)
            if session_data is not None:
                session_data.cleanup()

    def disconnect_all(self):
        """Disconnect all sessions.

        Every session is attempted even if one fails; the pool is emptied and
        the first exception encountered (if any) is re-raised afterwards.
        """
        with self._lock:
            first_error: Optional[Exception] = None
            for session_data in self._sessions.values():
                try:
                    session_data.cleanup()
                except Exception as exc:  # keep going; re-raised below
                    if first_error is None:
                        first_error = exc
            self._sessions.clear()
            self._last_used_order.clear()
            if first_error is not None:
                raise first_error

    def cleanup_deleted_server(self, alias: str):
        """Clean up sessions when a server is deleted."""
        self.disconnect_session(alias)

    def rename_server(self, old_alias: str, new_alias: str):
        """Rename a server's session references (after alias change).

        If *new_alias* already had its own cached entry, that entry is
        disconnected rather than silently orphaned. Renaming to the same alias
        is a no-op.
        """
        if old_alias == new_alias:
            return
        with self._lock:
            session_data = self._sessions.pop(old_alias, None)
            last_used = self._last_used_order.pop(old_alias, None)
            if session_data is None:
                return
            displaced = self._sessions.pop(new_alias, None)
            self._last_used_order.pop(new_alias, None)
            session_data.alias = new_alias
            self._sessions[new_alias] = session_data
            if last_used is not None:
                self._last_used_order[new_alias] = last_used
            if displaced is not None:
                displaced.cleanup()

    def get_active_sessions(self) -> list:
        """Get list of aliases that have at least one session reporting ``connected``."""
        with self._lock:
            return [
                alias
                for alias, data in self._sessions.items()
                if (data.shell_session is not None and data.shell_session.connected)
                or (data.sftp_session is not None and data.sftp_session.connected)
            ]