268 lines
10 KiB
Python
268 lines
10 KiB
Python
"""
|
|
Session pool for managing SSH and SFTP sessions to avoid reconnecting when switching between servers.
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
from collections import OrderedDict
|
|
from typing import Dict, Optional, Tuple
|
|
from core.ssh_client import ShellSession, SFTPSession
|
|
|
|
|
|
# Server-config fields whose change means an existing connection can no longer
# be reused and must be torn down.
_CRITICAL_KEYS = (
    'ip', 'port', 'username', 'password', 'type',
    'access_key', 'secret_key', 'use_ssl',
)


def _server_changed(old: dict, new: dict) -> bool:
    """
    Check if critical connection fields differ between two server configs.

    Args:
        old: Previously stored server configuration (``None`` is treated as empty)
        new: Incoming server configuration (``None`` is treated as empty)

    Returns:
        True if any key listed in ``_CRITICAL_KEYS`` has a different value in
        the two dicts. A key missing from a dict compares as ``None``; keys
        outside ``_CRITICAL_KEYS`` are ignored.
    """
    # A missing config is compared as an empty one instead of raising
    # AttributeError on ``None.get``.
    old = old or {}
    new = new or {}
    return any(old.get(key) != new.get(key) for key in _CRITICAL_KEYS)
|
|
|
|
|
|
class SessionData:
    """Container for session data including the actual sessions and their metadata."""

    def __init__(self, alias: str, server: dict, key_path: str):
        """
        Create an empty record for one server; no session is opened here.

        Args:
            alias: Server alias this record belongs to
            server: Server configuration dict
            key_path: Path to SSH key
        """
        self.alias = alias
        self.server = server
        self.key_path = key_path
        # Both sessions start out absent and are attached later by the owner.
        self.shell_session: Optional[ShellSession] = None
        self.sftp_session: Optional[SFTPSession] = None
        # Wall-clock timestamp (time.time()), initialised at construction.
        self.last_access_time = time.time()
        # State preservation for sessions
        self.terminal_buffer: bytes = b""
        self.remote_path: str = "/"
        self.sudo_mode: bool = False

    def cleanup(self):
        """
        Disconnect and drop both sessions.

        Both references are cleared before anything is disconnected, and the
        SFTP session is disconnected even if disconnecting the shell session
        raises, so a failure in one cannot leak the other. Any exception from
        ``disconnect()`` still propagates to the caller.
        """
        shell, self.shell_session = self.shell_session, None
        sftp, self.sftp_session = self.sftp_session, None
        try:
            if shell is not None:
                shell.disconnect()
        finally:
            if sftp is not None:
                sftp.disconnect()
|
|
|
|
|
|
class SessionPool:
    """
    Manages a pool of SSH/SFTP sessions to keep connections alive when switching between servers.

    Features:
    - Caches sessions per server alias
    - Maintains session state (terminal buffer, remote path, sudo mode)
    - LRU eviction once more than ``max_sessions`` aliases are tracked
    - Every public method runs under one reentrant lock

    NOTE(review): this class sends no keepalive traffic itself; if idle
    sessions need keepalives, that presumably lives in the session classes --
    confirm.
    """

    def __init__(self, max_sessions: int = 5):
        """
        Args:
            max_sessions: Maximum number of aliases tracked before the least
                recently used ones are evicted and disconnected.
        """
        self.max_sessions = max_sessions
        self._sessions: Dict[str, SessionData] = {}
        self._lock = threading.RLock()  # Reentrant: public methods call each other
        self._last_used_order = OrderedDict()  # alias -> timestamp, oldest first

    # ------------------------------------------------------------------
    # Internal helpers (callers must hold self._lock)
    # ------------------------------------------------------------------

    def _get_session_data(self, alias: str, server: dict, key_path: str) -> "SessionData":
        """Return the record for ``alias``, creating or invalidating it as needed.

        A new record is created when the alias is unknown. An existing record
        has its sessions torn down when a critical connection field changed.
        Either way the alias is marked as most recently used.
        """
        session_data = self._sessions.get(alias)
        if session_data is None:
            session_data = SessionData(alias, server, key_path)
            self._sessions[alias] = session_data
        elif _server_changed(session_data.server, server):
            # Invalidate if server connection data changed
            session_data.cleanup()
            session_data.server = server
            session_data.key_path = key_path

        self._update_last_access(alias)
        return session_data

    @staticmethod
    def _is_active(session_data: "SessionData") -> bool:
        """Return True if the record holds at least one connected session."""
        shell = session_data.shell_session
        sftp = session_data.sftp_session
        return bool((shell and shell.connected) or (sftp and sftp.connected))

    def _update_last_access(self, alias: str):
        """Mark ``alias`` as most recently used and enforce the LRU limit.

        The alias being touched is never evicted, even when ``max_sessions``
        is below 1: evicting it would orphan the session that is about to be
        handed to the caller.
        """
        now = time.time()
        self._last_used_order[alias] = now
        self._last_used_order.move_to_end(alias)

        touched = self._sessions.get(alias)
        if touched is not None:
            touched.last_access_time = now

        # Enforce max sessions limit using LRU
        while len(self._last_used_order) > self.max_sessions:
            oldest_alias = next(iter(self._last_used_order))
            if oldest_alias == alias:
                break
            # Untrack before cleanup so a failing disconnect cannot leave the
            # two maps out of sync.
            del self._last_used_order[oldest_alias]
            evicted = self._sessions.pop(oldest_alias, None)
            if evicted is not None:
                evicted.cleanup()

    # ------------------------------------------------------------------
    # Session acquisition
    # ------------------------------------------------------------------

    def get_or_create_shell_session(self, alias: str, server: dict, key_path: str) -> "Tuple[ShellSession, bool]":
        """
        Get existing shell session or create a new one.

        Args:
            alias: Server alias
            server: Server configuration dict
            key_path: Path to SSH key

        Returns:
            Tuple of (session, is_new_session)
        """
        with self._lock:
            session_data = self._get_session_data(alias, server, key_path)

            existing = session_data.shell_session
            if existing is not None and existing.connected:
                return existing, False

            # The stored terminal buffer is not replayed here; callers fetch
            # it with get_shell_state() since the terminal handles its own state.
            fresh = ShellSession(server, key_path)
            session_data.shell_session = fresh
            return fresh, True

    def get_or_create_sftp_session(self, alias: str, server: dict, key_path: str) -> "Tuple[SFTPSession, bool]":
        """
        Get existing SFTP session or create a new one.

        A newly created session inherits the sudo mode stored for the alias.

        Args:
            alias: Server alias
            server: Server configuration dict
            key_path: Path to SSH key

        Returns:
            Tuple of (session, is_new_session)
        """
        with self._lock:
            session_data = self._get_session_data(alias, server, key_path)

            existing = session_data.sftp_session
            if existing is not None and existing.connected:
                return existing, False

            fresh = SFTPSession(server, key_path)
            session_data.sftp_session = fresh
            fresh.sudo_mode = session_data.sudo_mode
            return fresh, True

    def activate_shell_session(self, alias: str, server: dict, key_path: str) -> "ShellSession":
        """
        Activate a shell session for the given alias (or create if needed).
        Updates access time and ensures session is ready.
        """
        with self._lock:
            session, _ = self.get_or_create_shell_session(alias, server, key_path)
            return session

    def activate_sftp_session(self, alias: str, server: dict, key_path: str) -> "SFTPSession":
        """
        Activate an SFTP session for the given alias (or create if needed).
        Updates access time and re-applies the stored sudo mode to the session.
        """
        # The lock is held across both steps so another thread cannot evict or
        # disconnect the alias between the lookup and the state restore.
        with self._lock:
            session, _ = self.get_or_create_sftp_session(alias, server, key_path)
            # Apply stored state
            session.sudo_mode = self._sessions[alias].sudo_mode
            return session

    # ------------------------------------------------------------------
    # Per-alias state
    # ------------------------------------------------------------------

    def store_shell_state(self, alias: str, terminal_buffer: bytes):
        """Store terminal state when switching away from a server (no-op for unknown aliases)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is not None:
                session_data.terminal_buffer = terminal_buffer
                self._update_last_access(alias)

    def store_sftp_state(self, alias: str, remote_path: str, sudo_mode: bool):
        """Store SFTP state when switching away from a server (no-op for unknown aliases)."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is not None:
                session_data.remote_path = remote_path
                session_data.sudo_mode = sudo_mode
                self._update_last_access(alias)

    def get_shell_state(self, alias: str) -> bytes:
        """Retrieve terminal state when switching back to a server; ``b""`` if the alias is unknown."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is None:
                return b""
            self._update_last_access(alias)
            return session_data.terminal_buffer

    def get_sftp_state(self, alias: str) -> Tuple[str, bool]:
        """Retrieve ``(remote_path, sudo_mode)`` for a server; ``("/", False)`` if the alias is unknown."""
        with self._lock:
            session_data = self._sessions.get(alias)
            if session_data is None:
                return "/", False
            self._update_last_access(alias)
            return session_data.remote_path, session_data.sudo_mode

    # ------------------------------------------------------------------
    # Teardown and bookkeeping
    # ------------------------------------------------------------------

    def disconnect_session(self, alias: str):
        """Explicitly disconnect a session and forget the alias."""
        with self._lock:
            self._last_used_order.pop(alias, None)
            session_data = self._sessions.pop(alias, None)
            if session_data is not None:
                session_data.cleanup()

    def disconnect_all(self):
        """Disconnect all sessions.

        Every entry is attempted even if an earlier one fails; the first
        failure is re-raised once all entries have been processed.
        """
        with self._lock:
            doomed = list(self._sessions.values())
            self._sessions.clear()
            self._last_used_order.clear()

            first_error = None
            for session_data in doomed:
                try:
                    session_data.cleanup()
                except Exception as exc:  # keep going so the rest are not leaked
                    if first_error is None:
                        first_error = exc
            if first_error is not None:
                raise first_error

    def cleanup_deleted_server(self, alias: str):
        """Clean up sessions when a server is deleted."""
        self.disconnect_session(alias)

    def rename_server(self, old_alias: str, new_alias: str):
        """Rename a server's session references (after alias change).

        If ``new_alias`` already has a record, that record is disconnected and
        replaced rather than silently dropped.
        """
        with self._lock:
            displaced = None
            if old_alias in self._sessions:
                session_data = self._sessions.pop(old_alias)
                session_data.alias = new_alias
                displaced = self._sessions.pop(new_alias, None)
                self._sessions[new_alias] = session_data

            if old_alias in self._last_used_order:
                ts = self._last_used_order.pop(old_alias)
                self._last_used_order.pop(new_alias, None)
                self._last_used_order[new_alias] = ts

            # Done last so a failing disconnect cannot interrupt the rename.
            if displaced is not None:
                displaced.cleanup()

    def get_active_sessions(self) -> list:
        """Get list of aliases that have at least one connected session."""
        with self._lock:
            return [
                alias
                for alias, session_data in self._sessions.items()
                if self._is_active(session_data)
            ]

    def has_active_session(self, alias: str) -> bool:
        """Return True if ``alias`` is tracked and has at least one connected session."""
        with self._lock:
            sd = self._sessions.get(alias)
            if not sd:
                return False
            return self._is_active(sd)