Terminal: - Auto-detect [sudo] password prompts in interactive shell output - Auto-send server password when sudo prompt detected - Reset detection flag on new command (Enter key) SSH client (SFTPSession): - Fix exec_command() sudo password timing (0.1s delay for prompt) - Fix listdir_attr_sudo() ls output parsing with proper maxsplit - Handle filenames with spaces, symlinks, and varied ls formats Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
570 lines
20 KiB
Python
570 lines
20 KiB
Python
"""
|
|
SSH client wrapper — connect, exec, sftp, key management via paramiko.
|
|
"""
|
|
|
|
import os
|
|
import platform
|
|
import socket
|
|
import threading
|
|
import time
|
|
import paramiko
|
|
from core.logger import log
|
|
|
|
|
|
def _create_bound_socket(bind_ip: str, hostname: str, port: int, timeout: int) -> socket.socket:
    """Create a TCP socket bound to a specific local IP address.

    Connects to ``(hostname, port)`` with the local end bound to ``bind_ip``
    (ephemeral local port).

    Args:
        bind_ip: Local address to bind before connecting.
        hostname: Remote host name or IP address.
        port: Remote TCP port.
        timeout: Connect timeout in seconds; it also stays set on the returned socket.

    Returns:
        The connected socket.

    Raises:
        OSError: if resolution, bind or connect fails. ``create_connection``
            closes any socket it opened before raising, so nothing leaks
            (the previous hand-rolled version left the socket open on failure).
    """
    return socket.create_connection(
        (hostname, port),
        timeout=timeout,
        source_address=(bind_ip, 0),
    )
|
|
|
|
|
|
def _connect_client(server: dict, key_path: str, timeout: int = 15) -> paramiko.SSHClient:
    """Create and authenticate a paramiko SSHClient.

    Shared by SSHClientWrapper, ShellSession and SFTPSession. Public-key auth with
    ``key_path`` is tried first (when that file exists), then password auth with
    ``server["password"]``. Each attempt uses a fresh client (and a fresh bound
    socket when ``bind_interface`` is set), and a failed attempt closes both
    before the next one starts.

    Args:
        server: Connection settings. Reads ``ip`` (required), ``port`` (default 22),
            ``user`` (default "root"), ``password``, ``bind_interface`` (local IP
            to bind to) and ``alias`` (used only in messages).
        key_path: Private key file; skipped when empty or missing.
        timeout: Seconds used for the TCP connect and the SSH banner wait.

    Returns:
        A connected client with a 30-second transport keepalive.

    Raises:
        Whatever ``SSHClient.connect`` raises for the password attempt
        (e.g. ``paramiko.AuthenticationException``, ``paramiko.SSHException``, ``OSError``).
        Exception: when no password is configured and key auth was unavailable or
            failed; chained to the key failure when there was one.
    """
    hostname = server["ip"]
    port = server.get("port", 22)
    bind_ip = server.get("bind_interface")

    base_kwargs = {
        "hostname": hostname,
        "port": port,
        "username": server.get("user", "root"),
        "timeout": timeout,
        "banner_timeout": timeout,
    }

    def _attempt(**auth_kwargs) -> paramiko.SSHClient:
        # A client whose connect() failed can't be reused, so build a new one
        # (and a new bound socket) for every attempt and clean up on failure.
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        kwargs = dict(base_kwargs, **auth_kwargs)
        sock = None
        try:
            if bind_ip:
                sock = _create_bound_socket(bind_ip, hostname, port, timeout)
                kwargs["sock"] = sock
            client.connect(**kwargs)
        except Exception:
            client.close()
            if sock is not None:
                sock.close()
            raise
        transport = client.get_transport()
        if transport is not None:
            transport.set_keepalive(30)
        return client

    # Try key first
    key_error = None
    if key_path and os.path.exists(key_path):
        try:
            return _attempt(key_filename=key_path)
        except paramiko.AuthenticationException as e:
            key_error = e
            log.debug(f"Key auth failed for {server.get('alias', '?')}, trying password")
        except Exception as e:
            key_error = e
            log.debug(f"Key connect failed: {e}")

    # Fallback to password
    password = server.get("password", "")
    if password:
        return _attempt(password=password, look_for_keys=False, allow_agent=False)

    raise Exception(f"No auth method for {server.get('alias', 'unknown')}") from key_error
|
|
|
|
|
|
class ShellSession:
    """Persistent interactive shell session over SSH.

    ``connect()`` opens an ``xterm-256color`` PTY shell and starts a daemon thread
    that forwards every chunk read from the channel to ``on_data``. When reading
    ends (EOF or error) or a send fails, the session marks itself stopped and
    calls ``on_disconnect``; an explicit ``disconnect()`` does not trigger it.
    """

    def __init__(self, server: dict, key_path: str, cols: int = 80, rows: int = 24):
        self.server = server  # connection settings handed to _connect_client
        self.key_path = key_path
        self.cols = cols
        self.rows = rows
        self._client: paramiko.SSHClient | None = None
        self._channel: paramiko.Channel | None = None
        self._running = False
        self._read_thread: threading.Thread | None = None

        # Callbacks — set by the owner
        self.on_data = None  # on_data(data: bytes)
        self.on_disconnect = None  # on_disconnect()

    @property
    def connected(self) -> bool:
        """True while the shell channel's transport is active."""
        try:
            channel = self._channel
            transport = channel.get_transport() if channel is not None else None
            return transport is not None and transport.is_active()
        except Exception:
            return False

    def connect(self):
        """Open the SSH connection and PTY shell, then start the reader thread."""
        client = _connect_client(self.server, self.key_path)
        try:
            channel = client.invoke_shell(
                term="xterm-256color",
                width=self.cols,
                height=self.rows,
            )
            # Short recv timeout so the reader re-checks ``_running`` frequently.
            channel.settimeout(0.1)
        except Exception:
            client.close()  # don't leak the authenticated client
            raise
        self._client = client
        self._channel = channel
        self._running = True
        # Each reader thread is bound to its own channel, so a thread left over
        # from before a reconnect() can never read from the new channel.
        self._read_thread = threading.Thread(target=self._read_loop, args=(channel,), daemon=True)
        self._read_thread.start()

    def _read_loop(self, channel=None):
        """Forward data from ``channel`` (default: the current one) to ``on_data`` until it ends."""
        if channel is None:
            channel = self._channel
        try:
            while self._running:
                try:
                    data = channel.recv(65536)
                except socket.timeout:
                    # socket.timeout is only an alias of TimeoutError on 3.10+;
                    # on 3.9 it would otherwise fall into the OSError branch.
                    continue
                except OSError:
                    break
                if not data:
                    break
                if self.on_data:
                    self.on_data(data)
        except Exception as e:
            log.debug(f"ShellSession read loop error: {e}")
        finally:
            # Report the drop only for the channel that is still current: an
            # explicit disconnect() cleared _running, a reconnect() replaced it.
            if self._running and self._channel is channel:
                self._running = False
                if self.on_disconnect:
                    self.on_disconnect()

    def send(self, data: bytes):
        """Write ``data`` to the shell; a socket error stops the session and fires ``on_disconnect``."""
        channel = self._channel
        if channel and self._running:
            try:
                channel.sendall(data)
            except OSError:
                self._running = False
                if self.on_disconnect:
                    self.on_disconnect()

    def resize(self, cols: int, rows: int):
        """Remember the new terminal size and forward it to the remote PTY (best effort)."""
        self.cols = cols
        self.rows = rows
        channel = self._channel
        if channel and self._running:
            try:
                channel.resize_pty(width=cols, height=rows)
            except (OSError, paramiko.SSHException):
                # A closed or refusing channel just keeps its old size.
                pass

    def disconnect(self):
        """Close the channel and client without firing ``on_disconnect``."""
        self._running = False
        if self._channel:
            try:
                self._channel.close()
            except Exception as e:
                log.debug(f"ShellSession channel close: {e}")
            self._channel = None
        if self._client:
            try:
                self._client.close()
            except Exception as e:
                log.debug(f"ShellSession client close: {e}")
            self._client = None
        # Give the reader thread a moment to exit so it cannot overlap with a
        # following connect(). Skip when called from the reader itself (e.g. in
        # an on_data / on_disconnect callback).
        thread = self._read_thread
        self._read_thread = None
        if thread is not None and thread is not threading.current_thread() and thread.is_alive():
            thread.join(timeout=1.0)

    def reconnect(self):
        """Tear down the current session and open a new one."""
        self.disconnect()
        time.sleep(0.2)
        self.connect()
|
|
|
|
|
|
class SSHClientWrapper:
    """SSH helper for one-off operations on a single server.

    ``exec_command``, ``upload`` and ``download`` each open a fresh connection via
    ``connect()`` and close it before returning. ``install_key``/``generate_key``
    manage the key pair at ``key_path`` (default ``~/.ssh/id_ed25519``).
    """

    def __init__(self, server: dict, key_path: str = ""):
        self.server = server
        self.key_path = key_path or os.path.expanduser("~/.ssh/id_ed25519")
        self._client: paramiko.SSHClient | None = None

    def connect(self) -> paramiko.SSHClient:
        """Open a new authenticated client, cache it in ``self._client`` and return it."""
        client = _connect_client(self.server, self.key_path)
        self._client = client
        return client

    def _release(self, client) -> None:
        """Close a client opened for a single operation and drop it from the cache."""
        client.close()
        # Don't leave an already-closed client cached.
        if self._client is client:
            self._client = None

    def disconnect(self):
        """Close the client cached by ``connect()``, if any (best effort)."""
        if self._client:
            try:
                self._client.close()
            except Exception as e:
                log.debug(f"SSHClientWrapper client close: {e}")
            self._client = None

    def exec_command(self, command: str, use_sudo: bool = True) -> tuple[str, str, int]:
        """Execute command. Auto-sudo if user != root and use_sudo=True.

        Runs over a fresh connection with ``get_pty=True`` and a 120 s channel
        timeout. When sudo is used, the server password is written to stdin for
        ``sudo -S``.

        Returns:
            ``(stdout, stderr, exit_code)``; output is decoded as UTF-8 with
            replacement, and stderr lines starting with "[sudo]" or containing
            "password for" are dropped.
        """
        client = self.connect()
        stdin = stdout = stderr = None
        try:
            user = self.server.get("user", "root")
            need_sudo = use_sudo and user != "root"

            if need_sudo:
                # -S: read the password from stdin; -p '': print no prompt text.
                full_cmd = f"export TERM=xterm; sudo -S -p '' bash -c {_shell_quote(command)}"
            else:
                full_cmd = f"export TERM=xterm; {command}"

            stdin, stdout, stderr = client.exec_command(full_cmd, timeout=120, get_pty=True)

            if need_sudo:
                password = self.server.get("password", "")
                stdin.write(password + "\n")
                stdin.flush()

            # Drain output BEFORE waiting for the exit status: paramiko warns that
            # recv_exit_status() can hang if unread output fills the channel window.
            out = stdout.read().decode("utf-8", errors="replace")
            err = stderr.read().decode("utf-8", errors="replace")
            exit_code = stdout.channel.recv_exit_status()

            # Strip sudo noise
            err_lines = [
                line for line in err.splitlines()
                if not line.startswith("[sudo]") and "password for" not in line.lower()
            ]
            err = "\n".join(err_lines).strip()

            return out, err, exit_code
        finally:
            # Close channels explicitly
            for ch in (stdin, stdout, stderr):
                if ch:
                    try:
                        ch.close()
                    except Exception:
                        pass
            self._release(client)

    def upload(self, local_path: str, remote_path: str, progress_cb=None):
        """Upload one file over a fresh SFTP session and set its mode to 0664."""
        client = self.connect()
        try:
            # The context manager closes the SFTP channel even when put() fails.
            with client.open_sftp() as sftp:
                sftp.put(local_path, remote_path, callback=progress_cb)
                sftp.chmod(remote_path, 0o664)
        finally:
            self._release(client)

    def download(self, remote_path: str, local_path: str, progress_cb=None):
        """Download one file over a fresh SFTP session."""
        client = self.connect()
        try:
            with client.open_sftp() as sftp:
                sftp.get(remote_path, local_path, callback=progress_cb)
        finally:
            self._release(client)

    def check_connection(self) -> bool:
        """Return True if a connection can be authenticated within 5 seconds."""
        try:
            client = _connect_client(self.server, self.key_path, timeout=5)
            client.close()
            return True
        except Exception:
            return False

    def install_key(self) -> str:
        """Append ``key_path + ".pub"`` to the remote ``~/.ssh/authorized_keys`` if missing.

        Returns:
            A status message ("Key already installed" / "Key installed successfully").

        Raises:
            FileNotFoundError: if the local public key file does not exist.
            ValueError: if the public key file is empty.
            Exception: if the remote install command does not report success.
        """
        pub_key_path = self.key_path + ".pub"
        if not os.path.exists(pub_key_path):
            raise FileNotFoundError(f"No public key at {pub_key_path}")

        with open(pub_key_path, "r") as f:
            pub_key = f.read().strip()
        if not pub_key:
            raise ValueError(f"Public key file is empty: {pub_key_path}")

        # Match on "<type> <base64>" only: the trailing comment may differ between
        # installs of the same key. -F treats the key literally (no regex meaning);
        # -q plus explicit markers avoids the old "grep -c ... || echo 0" bug, which
        # printed "0\n0" when the file existed without the key.
        key_id = " ".join(pub_key.split()[:2])
        out, _, _ = self.exec_command(
            f"grep -qF -e {_shell_quote(key_id)} ~/.ssh/authorized_keys 2>/dev/null"
            " && echo KEY_PRESENT || echo KEY_ABSENT",
            use_sudo=False,
        )
        if "KEY_PRESENT" in out:
            return "Key already installed"

        command = (
            "mkdir -p ~/.ssh && chmod 700 ~/.ssh && "
            # If the file exists and lacks a trailing newline, add one first so the
            # new key is not glued onto the last existing entry.
            "{ [ ! -s ~/.ssh/authorized_keys ] "
            '|| [ -z "$(tail -c1 ~/.ssh/authorized_keys)" ] '
            "|| echo >> ~/.ssh/authorized_keys; } && "
            f"printf '%s\\n' {_shell_quote(pub_key)} >> ~/.ssh/authorized_keys && "
            "chmod 600 ~/.ssh/authorized_keys && "
            'echo "KEY_OK"'
        )
        out, err, code = self.exec_command(command, use_sudo=False)
        if "KEY_OK" in out:
            return "Key installed successfully"
        raise Exception(f"Key install failed: {err or out}")

    def generate_key(self) -> str:
        """Create an Ed25519 key pair at ``key_path`` (+ ``.pub``) unless one already exists."""
        if os.path.exists(self.key_path):
            return f"Key already exists: {self.key_path}"

        key_dir = os.path.dirname(self.key_path)
        if key_dir:  # a bare filename has no directory; os.makedirs("") would raise
            os.makedirs(key_dir, exist_ok=True)
        # NOTE(review): verify that the installed paramiko version supports
        # Ed25519Key.generate() and writing Ed25519 private key files.
        key = paramiko.Ed25519Key.generate()
        key.write_private_key_file(self.key_path)

        # Set restrictive permissions on private key (Unix)
        if platform.system() != "Windows":
            os.chmod(self.key_path, 0o600)

        pub_key = f"ssh-ed25519 {key.get_base64()} server-manager"
        with open(self.key_path + ".pub", "w") as f:
            f.write(pub_key + "\n")

        log.info(f"SSH key generated: {self.key_path}")
        return f"Key generated: {self.key_path}"
|
|
|
|
|
|
class SFTPSession:
    """Persistent SFTP session for file browser.

    Holds one authenticated SSH client and its SFTP channel between calls. When
    ``sudo_mode`` is True and the login user is not root, ``exec_command`` and
    the ``*_sudo`` helpers run their shell commands through ``sudo -S`` using the
    server's stored password.
    """

    def __init__(self, server: dict, key_path: str):
        self.server = server
        self.key_path = key_path
        self._client: paramiko.SSHClient | None = None
        self._sftp: paramiko.SFTPClient | None = None
        self.sudo_mode: bool = False

    @property
    def connected(self) -> bool:
        """True while both the client and SFTP channel exist and the transport is active."""
        try:
            return (
                self._client is not None
                and self._sftp is not None
                and self._client.get_transport() is not None
                and self._client.get_transport().is_active()
            )
        except Exception:
            return False

    def connect(self):
        """Open the SSH client and its SFTP channel."""
        client = _connect_client(self.server, self.key_path)
        try:
            sftp = client.open_sftp()
        except Exception:
            client.close()  # don't leak the authenticated client
            raise
        self._client = client
        self._sftp = sftp

    def reconnect(self):
        """Disconnect and re-establish SFTP session."""
        self.disconnect()
        time.sleep(0.2)
        self.connect()

    def disconnect(self):
        """Close the SFTP channel and client (best effort)."""
        if self._sftp:
            try:
                self._sftp.close()
            except Exception as e:
                log.debug(f"SFTPSession sftp close: {e}")
            self._sftp = None
        if self._client:
            try:
                self._client.close()
            except Exception as e:
                log.debug(f"SFTPSession client close: {e}")
            self._client = None

    def listdir_attr(self, path: str) -> list:
        return self._sftp.listdir_attr(path)

    def stat(self, path: str):
        return self._sftp.stat(path)

    def mkdir(self, path: str):
        self._sftp.mkdir(path)

    def rmdir(self, path: str):
        self._sftp.rmdir(path)

    def remove(self, path: str):
        self._sftp.remove(path)

    def rename(self, old: str, new: str):
        self._sftp.rename(old, new)

    def upload(self, local_path: str, remote_path: str, progress_cb=None):
        """Upload one file; ``progress_cb`` (may be None) is passed to ``put``."""
        self._sftp.put(local_path, remote_path, callback=progress_cb)

    def download(self, remote_path: str, local_path: str, progress_cb=None):
        """Download one file; ``progress_cb`` (may be None) is passed to ``get``."""
        self._sftp.get(remote_path, local_path, callback=progress_cb)

    def normalize(self, path: str) -> str:
        return self._sftp.normalize(path)

    # ── Recursive operations ──

    @staticmethod
    def _join_remote(parent: str, name: str) -> str:
        """Join a remote directory and an entry name without doubling the slash."""
        return parent.rstrip("/") + "/" + name

    def upload_dir(self, local_dir: str, remote_dir: str, progress_cb=None, file_cb=None):
        """Recursively upload a local directory to remote.

        Remote directories are created as needed; mkdir errors (e.g. the directory
        already exists) are ignored. ``file_cb(index, total, basename)`` is called
        before each file and ``progress_cb`` is passed to every ``put``.
        """
        all_files = []
        for root, _dirs, files in os.walk(local_dir):
            rel = os.path.relpath(root, local_dir)
            # relpath uses the native separator; the remote side always uses "/".
            # (Replacing only os.sep keeps literal backslashes in POSIX names intact.)
            remote_sub = remote_dir if rel == "." else remote_dir + "/" + rel.replace(os.sep, "/")
            try:
                self._sftp.mkdir(remote_sub)
            except OSError:
                pass
            all_files.extend((os.path.join(root, f), remote_sub + "/" + f) for f in files)

        total = len(all_files)
        for idx, (local_file, remote_file) in enumerate(all_files, start=1):
            if file_cb:
                file_cb(idx, total, os.path.basename(local_file))
            self._sftp.put(local_file, remote_file, callback=progress_cb)

    def download_dir(self, remote_dir: str, local_dir: str, progress_cb=None, file_cb=None):
        """Recursively download a remote directory to local.

        ``local_dir`` itself is always created, even when the remote directory is
        empty. ``file_cb(index, total, basename)`` is called before each file.
        """
        all_files = []
        self._walk_remote(remote_dir, remote_dir, all_files)
        if local_dir:
            os.makedirs(local_dir, exist_ok=True)
        total = len(all_files)
        for idx, (remote_file, rel_path) in enumerate(all_files, start=1):
            local_file = os.path.join(local_dir, rel_path)
            parent = os.path.dirname(local_file)
            if parent:  # os.makedirs("") would raise
                os.makedirs(parent, exist_ok=True)
            if file_cb:
                file_cb(idx, total, os.path.basename(remote_file))
            self._sftp.get(remote_file, local_file, callback=progress_cb)

    def _walk_remote(self, base: str, current: str, result: list):
        """Recursively walk remote directory, collecting (abs_path, relative_path) tuples.

        ``relative_path`` uses the local ``os.sep``. Directories are recursed into
        (symlinks are not, since their mode is not S_ISDIR); everything else is
        collected.
        """
        import stat as stat_mod

        for attr in self._sftp.listdir_attr(current):
            full = self._join_remote(current, attr.filename)
            if stat_mod.S_ISDIR(attr.st_mode or 0):
                self._walk_remote(base, full, result)
            else:
                rel = full[len(base):].lstrip("/")
                result.append((full, rel.replace("/", os.sep)))

    def rmdir_recursive(self, path: str):
        """Recursively delete a remote directory."""
        import stat as stat_mod

        for attr in self._sftp.listdir_attr(path):
            child = self._join_remote(path, attr.filename)
            if stat_mod.S_ISDIR(attr.st_mode or 0):
                self.rmdir_recursive(child)
            else:
                self._sftp.remove(child)
        self._sftp.rmdir(path)

    # ── Sudo operations ──

    def _run_command(self, cmd: str) -> tuple[str, str, int]:
        """Run ``cmd`` (under sudo when enabled) and return ``(stdout, stderr, exit_code)``.

        Raises:
            Exception: if the session is not connected.
        """
        if not self._client:
            raise Exception("Not connected")
        password = self.server.get("password", "")
        user = self.server.get("user", "root")
        use_sudo = self.sudo_mode and user != "root"
        full_cmd = f"sudo -S -p '' bash -c {_shell_quote(cmd)}" if use_sudo else cmd

        stdin, stdout, stderr = self._client.exec_command(full_cmd, timeout=30)
        channel = stdout.channel
        try:
            if use_sudo and password:
                # Wait briefly for sudo prompt to appear before sending password
                time.sleep(0.1)
                stdin.write(password + "\n")
                stdin.flush()
            # Send EOF on stdin so neither sudo (missing/wrong password) nor the
            # command can block waiting for input until the 30 s timeout.
            channel.shutdown_write()
            out = stdout.read().decode("utf-8", errors="replace")
            err = stderr.read().decode("utf-8", errors="replace")
            code = channel.recv_exit_status()
        finally:
            channel.close()
        return out, err, code

    def exec_command(self, cmd: str) -> str:
        """Execute command via SSH with optional sudo wrapper; return its stdout."""
        return self._run_command(cmd)[0]

    @staticmethod
    def _parse_ls_output(output: str) -> list[tuple[str, str, int, int]]:
        """Parse ``ls -la --time-style=+%s`` output into ``(perms, name, size, mtime)`` tuples.

        Only directories, symlinks and regular files are kept; ".", "..", the
        "total" line and unparsable lines are skipped. Names keep internal
        whitespace; " -> target" is removed only from symlink entries.
        """
        entries = []
        for line in output.split("\n"):
            if not line or line.startswith("total "):
                continue
            # Columns: perms links owner group size mtime name[ -> target].
            # maxsplit=6 leaves the whole name (spaces included) in parts[6].
            parts = line.split(None, 6)
            if len(parts) < 7:
                continue
            perms = parts[0]
            kind = perms[:1]
            if kind not in ("d", "l", "-"):
                continue
            try:
                size = int(parts[4])
                mtime = int(parts[5])
            except ValueError:
                continue
            name = parts[6]
            if kind == "l":
                # Only symlinks carry a target; other names may contain " -> ".
                name = name.split(" -> ", 1)[0]
            if name in (".", ".."):
                continue
            entries.append((perms, name, size, mtime))
        return entries

    def listdir_attr_sudo(self, path: str) -> list:
        """List directory using sudo ls -la, returning objects with .filename/.st_size/.st_mtime/.st_mode."""
        output = self.exec_command(f"ls -la --time-style=+%s {_shell_quote(path)}")
        return [
            _SudoFileAttr(name, size, mtime, _parse_ls_perms(perms))
            for perms, name, size, mtime in self._parse_ls_output(output)
        ]

    def upload_sudo(self, local_path: str, remote_path: str, progress_cb=None):
        """Upload via SFTP to /tmp then sudo mv to destination.

        Raises:
            OSError: if the move into place fails; the staged temp file is then
                removed (best effort).
        """
        import secrets

        tmp_name = f"/tmp/.sm_upload_{secrets.token_hex(8)}"
        try:
            self._sftp.put(local_path, tmp_name, callback=progress_cb)
            _, err, code = self._run_command(
                f"mv -- {_shell_quote(tmp_name)} {_shell_quote(remote_path)}"
            )
            if code != 0:
                raise OSError(f"mv to {remote_path} failed (exit {code}): {err.strip()}")
        except Exception:
            # The staged copy is owned by the login user, so plain SFTP can remove it.
            try:
                self._sftp.remove(tmp_name)
            except Exception:
                pass
            raise

    def download_sudo(self, remote_path: str, local_path: str, progress_cb=None):
        """Copy via sudo to /tmp then download via SFTP.

        The temp copy is created under ``umask 077``, given to the login user and
        set to mode 600, so other local users cannot read it. It is always removed
        afterwards.

        Raises:
            OSError: if the copy step fails.
        """
        import secrets

        tmp_name = f"/tmp/.sm_download_{secrets.token_hex(8)}"
        tmp_q = _shell_quote(tmp_name)
        owner_q = _shell_quote(self.server.get("user", "root"))
        try:
            _, err, code = self._run_command(
                f"umask 077 && cp -- {_shell_quote(remote_path)} {tmp_q} "
                f"&& chown {owner_q} {tmp_q} && chmod 600 {tmp_q}"
            )
            if code != 0:
                raise OSError(f"copy of {remote_path} failed (exit {code}): {err.strip()}")
            self._sftp.get(tmp_name, local_path, callback=progress_cb)
        finally:
            self.exec_command(f"rm -f {tmp_q}")
|
|
|
|
|
|
class _SudoFileAttr:
    """Mimics paramiko SFTPAttributes for sudo ls output.

    Only the attributes filled from ``ls -la`` exist: ``filename``, ``st_size``,
    ``st_mtime`` (epoch seconds as printed by ``--time-style=+%s``) and
    ``st_mode`` (stat-style mode bits).
    """

    def __init__(self, filename: str, st_size: int, st_mtime: int, st_mode: int):
        self.filename = filename
        self.st_size = st_size
        self.st_mtime = st_mtime
        self.st_mode = st_mode

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(filename={self.filename!r}, st_size={self.st_size!r}, "
            f"st_mtime={self.st_mtime!r}, st_mode={self.st_mode!r})"
        )
|
|
|
|
|
|
def _parse_ls_perms(perms: str) -> int:
    """Parse ls -l permission string like 'drwxr-xr-x' into stat mode int.

    The first character selects the file type: d, l, c, b, p or s. Anything else
    is treated as a regular file. The next nine characters are the rwx triplets.
    ``s``/``S`` in the owner or group execute slot sets setuid/setgid, and
    ``t``/``T`` in the other-execute slot sets the sticky bit. The lowercase
    forms also imply execute; the uppercase forms do not. Characters after the
    tenth (ACL ``+`` / SELinux ``.`` markers) are ignored. An empty string
    yields 0.
    """
    import stat as stat_mod

    if not perms:
        return 0

    type_bits = {
        "d": stat_mod.S_IFDIR,
        "l": stat_mod.S_IFLNK,
        "c": stat_mod.S_IFCHR,
        "b": stat_mod.S_IFBLK,
        "p": stat_mod.S_IFIFO,
        "s": stat_mod.S_IFSOCK,
    }
    mode = type_bits.get(perms[0], stat_mod.S_IFREG)

    bits = [stat_mod.S_IRUSR, stat_mod.S_IWUSR, stat_mod.S_IXUSR,
            stat_mod.S_IRGRP, stat_mod.S_IWGRP, stat_mod.S_IXGRP,
            stat_mod.S_IROTH, stat_mod.S_IWOTH, stat_mod.S_IXOTH]
    # Execute slots that may also carry a special bit: index -> (letters, special bit).
    special_slots = {
        2: ("sS", stat_mod.S_ISUID),
        5: ("sS", stat_mod.S_ISGID),
        8: ("tT", stat_mod.S_ISVTX),
    }
    for idx, (ch, bit) in enumerate(zip(perms[1:10], bits)):
        letters, special_bit = special_slots.get(idx, ("", 0))
        if ch in letters:
            mode |= special_bit
            if ch.islower():  # 's' / 't' = special bit + execute
                mode |= bit
        elif ch != "-":
            mode |= bit
    return mode
|
|
|
|
|
|
def _shell_quote(s: str) -> str:
    r"""Return *s* as one single-quoted POSIX shell word.

    Nothing is special inside single quotes, so each embedded ``'`` is emitted
    as ``'\''`` (close the quote, a backslash-escaped quote, reopen the quote).
    """
    pieces = s.split("'")
    return "'" + "'\\''".join(pieces) + "'"
|