""" SSH client wrapper — connect, exec, sftp, key management via paramiko. """ import os import platform import socket import threading import time import paramiko from core.logger import log def _create_bound_socket(bind_ip: str, hostname: str, port: int, timeout: int) -> socket.socket: """Create a TCP socket bound to a specific local IP address.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) sock.bind((bind_ip, 0)) sock.connect((hostname, port)) return sock def _connect_client(server: dict, key_path: str, timeout: int = 15) -> paramiko.SSHClient: """Create and authenticate a paramiko SSHClient. Shared by SSHClientWrapper and ShellSession.""" client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) hostname = server["ip"] port = server.get("port", 22) bind_ip = server.get("bind_interface") kwargs = { "hostname": hostname, "port": port, "username": server.get("user", "root"), "timeout": timeout, "banner_timeout": timeout, } if bind_ip: kwargs["sock"] = _create_bound_socket(bind_ip, hostname, port, timeout) # Try key first if key_path and os.path.exists(key_path): try: kwargs["key_filename"] = key_path client.connect(**kwargs) transport = client.get_transport() if transport is not None: transport.set_keepalive(30) return client except paramiko.AuthenticationException: log.debug(f"Key auth failed for {server.get('alias', '?')}, trying password") del kwargs["key_filename"] client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if bind_ip: kwargs["sock"] = _create_bound_socket(bind_ip, hostname, port, timeout) except Exception as e: log.debug(f"Key connect failed: {e}") del kwargs["key_filename"] client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if bind_ip: kwargs["sock"] = _create_bound_socket(bind_ip, hostname, port, timeout) # Fallback to password password = server.get("password", "") if password: kwargs["password"] = password kwargs["look_for_keys"] = False kwargs["allow_agent"] = False client.connect(**kwargs) transport = client.get_transport() if transport is not None: transport.set_keepalive(30) return client raise Exception(f"No auth method for {server.get('alias', 'unknown')}") class ShellSession: """Persistent interactive shell session over SSH.""" def __init__(self, server: dict, key_path: str, cols: int = 80, rows: int = 24): self.server = server self.key_path = key_path self.cols = cols self.rows = rows self._client: paramiko.SSHClient | None = None self._channel: paramiko.Channel | None = None self._running = False self._read_thread: threading.Thread | None = None # Callbacks — set by the owner self.on_data = None # on_data(data: bytes) self.on_disconnect = None # on_disconnect() @property def connected(self) -> bool: try: return ( self._channel is not None and self._channel.get_transport() is not None and self._channel.get_transport().is_active() ) except Exception: return False def connect(self): self._client = _connect_client(self.server, self.key_path) self._channel = self._client.invoke_shell( term="xterm-256color", width=self.cols, height=self.rows, ) self._channel.settimeout(0.1) self._running = True self._read_thread = threading.Thread(target=self._read_loop, daemon=True) self._read_thread.start() def _read_loop(self): try: while self._running: try: data = self._channel.recv(65536) if not data: break if self.on_data: self.on_data(data) except TimeoutError: continue except OSError: break except Exception as e: log.debug(f"ShellSession read loop error: {e}") finally: if self._running: self._running = False if self.on_disconnect: self.on_disconnect() def send(self, data: bytes): if self._channel and self._running: try: self._channel.sendall(data) except OSError: self._running = False if self.on_disconnect: self.on_disconnect() def resize(self, cols: int, rows: int): self.cols = cols self.rows = rows if self._channel and self._running: try: self._channel.resize_pty(width=cols, height=rows) except OSError: pass def disconnect(self): self._running = False if self._channel: try: self._channel.close() except Exception as e: log.debug(f"ShellSession channel close: {e}") self._channel = None if self._client: try: self._client.close() except Exception as e: log.debug(f"ShellSession client close: {e}") self._client = None def reconnect(self): self.disconnect() time.sleep(0.2) self.connect() class SSHClientWrapper: def __init__(self, server: dict, key_path: str = ""): self.server = server self.key_path = key_path or os.path.expanduser("~/.ssh/id_ed25519") self._client: paramiko.SSHClient | None = None def connect(self) -> paramiko.SSHClient: client = _connect_client(self.server, self.key_path) self._client = client return client def disconnect(self): if self._client: try: self._client.close() except Exception: pass self._client = None def exec_command(self, command: str, use_sudo: bool = True) -> tuple[str, str, int]: """Execute command. Auto-sudo if user != root and use_sudo=True.""" client = self.connect() stdin = stdout = stderr = None try: user = self.server.get("user", "root") need_sudo = use_sudo and user != "root" if need_sudo: full_cmd = f"export TERM=xterm; sudo -S -p '' bash -c {_shell_quote(command)}" else: full_cmd = f"export TERM=xterm; {command}" stdin, stdout, stderr = client.exec_command(full_cmd, timeout=120, get_pty=True) if need_sudo: password = self.server.get("password", "") stdin.write(password + "\n") stdin.flush() exit_code = stdout.channel.recv_exit_status() out = stdout.read().decode("utf-8", errors="replace") err = stderr.read().decode("utf-8", errors="replace") # Strip sudo noise err_lines = [l for l in err.splitlines() if not l.startswith("[sudo]") and "password for" not in l.lower()] err = "\n".join(err_lines).strip() return out, err, exit_code finally: # Close channels explicitly for ch in (stdin, stdout, stderr): if ch: try: ch.close() except Exception: pass client.close() def upload(self, local_path: str, remote_path: str, progress_cb=None): client = self.connect() try: sftp = client.open_sftp() if progress_cb: sftp.put(local_path, remote_path, callback=progress_cb) else: sftp.put(local_path, remote_path) sftp.chmod(remote_path, 0o664) sftp.close() finally: client.close() def download(self, remote_path: str, local_path: str, progress_cb=None): client = self.connect() try: sftp = client.open_sftp() if progress_cb: sftp.get(remote_path, local_path, callback=progress_cb) else: sftp.get(remote_path, local_path) sftp.close() finally: client.close() def check_connection(self) -> bool: try: client = _connect_client(self.server, self.key_path, timeout=5) client.close() return True except Exception: return False def install_key(self) -> str: pub_key_path = self.key_path + ".pub" if not os.path.exists(pub_key_path): raise FileNotFoundError(f"No public key at {pub_key_path}") with open(pub_key_path, "r") as f: pub_key = f.read().strip() out, _, _ = self.exec_command( f'grep -c "{pub_key}" ~/.ssh/authorized_keys 2>/dev/null || echo 0', use_sudo=False ) if out.strip() != "0": return "Key already installed" command = ( f'mkdir -p ~/.ssh && chmod 700 ~/.ssh && ' f'echo "{pub_key}" >> ~/.ssh/authorized_keys && ' f'chmod 600 ~/.ssh/authorized_keys && ' f'echo "KEY_OK"' ) out, err, code = self.exec_command(command, use_sudo=False) if "KEY_OK" in out: return "Key installed successfully" raise Exception(f"Key install failed: {err or out}") def generate_key(self) -> str: if os.path.exists(self.key_path): return f"Key already exists: {self.key_path}" os.makedirs(os.path.dirname(self.key_path), exist_ok=True) key = paramiko.Ed25519Key.generate() key.write_private_key_file(self.key_path) # Set restrictive permissions on private key (Unix) if platform.system() != "Windows": os.chmod(self.key_path, 0o600) pub_key = f"ssh-ed25519 {key.get_base64()} server-manager" with open(self.key_path + ".pub", "w") as f: f.write(pub_key + "\n") log.info(f"SSH key generated: {self.key_path}") return f"Key generated: {self.key_path}" class SFTPSession: """Persistent SFTP session for file browser.""" def __init__(self, server: dict, key_path: str): self.server = server self.key_path = key_path self._client: paramiko.SSHClient | None = None self._sftp: paramiko.SFTPClient | None = None self.sudo_mode: bool = False @property def connected(self) -> bool: try: return ( self._client is not None and self._sftp is not None and self._client.get_transport() is not None and self._client.get_transport().is_active() ) except Exception: return False def connect(self): self._client = _connect_client(self.server, self.key_path) self._sftp = self._client.open_sftp() def reconnect(self): """Disconnect and re-establish SFTP session.""" self.disconnect() time.sleep(0.2) self.connect() def disconnect(self): if self._sftp: try: self._sftp.close() except Exception as e: log.debug(f"SFTPSession sftp close: {e}") self._sftp = None if self._client: try: self._client.close() except Exception as e: log.debug(f"SFTPSession client close: {e}") self._client = None def listdir_attr(self, path: str) -> list: return self._sftp.listdir_attr(path) def stat(self, path: str): return self._sftp.stat(path) def mkdir(self, path: str): self._sftp.mkdir(path) def rmdir(self, path: str): self._sftp.rmdir(path) def remove(self, path: str): self._sftp.remove(path) def rename(self, old: str, new: str): self._sftp.rename(old, new) def upload(self, local_path: str, remote_path: str, progress_cb=None): if progress_cb: self._sftp.put(local_path, remote_path, callback=progress_cb) else: self._sftp.put(local_path, remote_path) def download(self, remote_path: str, local_path: str, progress_cb=None): if progress_cb: self._sftp.get(remote_path, local_path, callback=progress_cb) else: self._sftp.get(remote_path, local_path) def normalize(self, path: str) -> str: return self._sftp.normalize(path) # ── Recursive operations ── def upload_dir(self, local_dir: str, remote_dir: str, progress_cb=None, file_cb=None): """Recursively upload a local directory to remote.""" all_files = [] for root, dirs, files in os.walk(local_dir): rel = os.path.relpath(root, local_dir) remote_sub = remote_dir if rel == "." else remote_dir + "/" + rel.replace("\\", "/") try: self._sftp.mkdir(remote_sub) except IOError: pass for f in files: local_file = os.path.join(root, f) remote_file = remote_sub + "/" + f all_files.append((local_file, remote_file)) for idx, (local_file, remote_file) in enumerate(all_files): if file_cb: file_cb(idx + 1, len(all_files), os.path.basename(local_file)) self._sftp.put(local_file, remote_file, callback=progress_cb) def download_dir(self, remote_dir: str, local_dir: str, progress_cb=None, file_cb=None): """Recursively download a remote directory to local.""" all_files = [] self._walk_remote(remote_dir, remote_dir, all_files) for idx, (remote_file, rel_path) in enumerate(all_files): local_file = os.path.join(local_dir, rel_path) os.makedirs(os.path.dirname(local_file), exist_ok=True) if file_cb: file_cb(idx + 1, len(all_files), os.path.basename(remote_file)) self._sftp.get(remote_file, local_file, callback=progress_cb) def _walk_remote(self, base: str, current: str, result: list): """Recursively walk remote directory, collecting (abs_path, relative_path) tuples.""" import stat as stat_mod for attr in self._sftp.listdir_attr(current): full = current + "/" + attr.filename rel = full[len(base):].lstrip("/") if stat_mod.S_ISDIR(attr.st_mode or 0): self._walk_remote(base, full, result) else: result.append((full, rel.replace("/", os.sep))) def rmdir_recursive(self, path: str): """Recursively delete a remote directory.""" import stat as stat_mod for attr in self._sftp.listdir_attr(path): child = path + "/" + attr.filename if stat_mod.S_ISDIR(attr.st_mode or 0): self.rmdir_recursive(child) else: self._sftp.remove(child) self._sftp.rmdir(path) # ── Sudo operations ── def exec_command(self, cmd: str) -> str: """Execute command via SSH with optional sudo wrapper.""" if not self._client: raise Exception("Not connected") password = self.server.get("password", "") user = self.server.get("user", "root") if self.sudo_mode and user != "root": full_cmd = f"sudo -S -p '' bash -c {_shell_quote(cmd)}" else: full_cmd = cmd stdin, stdout, stderr = self._client.exec_command(full_cmd, timeout=30) if self.sudo_mode and user != "root" and password: # Wait briefly for sudo prompt to appear before sending password time.sleep(0.1) stdin.write(password + "\n") stdin.flush() return stdout.read().decode("utf-8", errors="replace") def listdir_attr_sudo(self, path: str) -> list: """List directory using sudo ls -la, returning objects with .filename/.st_size/.st_mtime/.st_mode.""" output = self.exec_command(f"ls -la --time-style=+%s {_shell_quote(path)}") results = [] for line in output.strip().splitlines(): if line.startswith("total "): continue # With --time-style=+%s the columns are: # perms links owner group size mtime name [-> target if symlink] # Use maxsplit=8 to preserve spaces in filenames, giving us: # [0=perms, 1=links, 2=owner, 3=group, 4=size, 5=mtime, 6=name...] parts = line.split(None, 8) if len(parts) < 7: continue perms = parts[0] if not (perms.startswith("d") or perms.startswith("l") or perms.startswith("-")): continue try: size = int(parts[4]) except ValueError: size = 0 continue try: mtime = int(parts[5]) except ValueError: mtime = 0 continue name = parts[6] # Handle cases where name contains " -> " (symlinks) or has spaces if len(parts) > 7: # This means the filename itself contained spaces and was split name = parts[6] + " " + parts[7] # Strip symlink target (e.g. "name -> target") name = name.split(" -> ")[0].strip() if name in (".", ".."): continue mode = _parse_ls_perms(perms) entry = _SudoFileAttr(name, size, mtime, mode) results.append(entry) return results def upload_sudo(self, local_path: str, remote_path: str, progress_cb=None): """Upload via SFTP to /tmp then sudo mv to destination.""" import random tmp_name = f"/tmp/.sm_upload_{random.randint(100000, 999999)}" if progress_cb: self._sftp.put(local_path, tmp_name, callback=progress_cb) else: self._sftp.put(local_path, tmp_name) self.exec_command(f"mv {_shell_quote(tmp_name)} {_shell_quote(remote_path)}") def download_sudo(self, remote_path: str, local_path: str, progress_cb=None): """Copy via sudo to /tmp then download via SFTP.""" import random tmp_name = f"/tmp/.sm_download_{random.randint(100000, 999999)}" self.exec_command(f"cp {_shell_quote(remote_path)} {_shell_quote(tmp_name)} && chmod 644 {_shell_quote(tmp_name)}") try: if progress_cb: self._sftp.get(tmp_name, local_path, callback=progress_cb) else: self._sftp.get(tmp_name, local_path) finally: self.exec_command(f"rm -f {_shell_quote(tmp_name)}") class _SudoFileAttr: """Mimics paramiko SFTPAttributes for sudo ls output.""" def __init__(self, filename: str, st_size: int, st_mtime: int, st_mode: int): self.filename = filename self.st_size = st_size self.st_mtime = st_mtime self.st_mode = st_mode def _parse_ls_perms(perms: str) -> int: """Parse ls -l permission string like 'drwxr-xr-x' into stat mode int.""" import stat as stat_mod mode = 0 if perms[0] == "d": mode |= stat_mod.S_IFDIR elif perms[0] == "l": mode |= stat_mod.S_IFLNK else: mode |= stat_mod.S_IFREG mapping = "rwxrwxrwx" bits = [stat_mod.S_IRUSR, stat_mod.S_IWUSR, stat_mod.S_IXUSR, stat_mod.S_IRGRP, stat_mod.S_IWGRP, stat_mod.S_IXGRP, stat_mod.S_IROTH, stat_mod.S_IWOTH, stat_mod.S_IXOTH] for i, (ch, bit) in enumerate(zip(perms[1:10], bits)): if ch != "-": mode |= bit return mode def _shell_quote(s: str) -> str: return "'" + s.replace("'", "'\\''") + "'"