"""
Server store — CRUD + JSON persistence + observer pattern.
Supports encryption, backups, and configurable config path.
Thread-safe with atomic writes.
"""

import hashlib
import json
import os
import shutil
import threading
import time
from datetime import datetime
from typing import Callable, Optional

from core.encryption import encrypt, decrypt, is_encrypted
from core.logger import log

# Shared config — same file used by ssh.py and Claude Code /ssh skill
SHARED_DIR = os.path.expanduser("~/.server-connections")
SETTINGS_FILE = os.path.join(SHARED_DIR, "settings.json")
DEFAULT_SERVERS_FILE = os.path.join(SHARED_DIR, "servers.json")
BACKUP_DIR = os.path.join(SHARED_DIR, "backups")

# Fallback: local config dir (for example file)
LOCAL_CONFIG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config")
EXAMPLE_FILE = os.path.join(LOCAL_CONFIG_DIR, "servers.example.json")

SERVER_TYPES = ["ssh", "telnet", "rdp", "vnc", "winrm",
                "mariadb", "mssql", "postgresql", "redis",
                "grafana", "prometheus"]

DEFAULT_PORTS = {
    "ssh": 22, "telnet": 23, "rdp": 3389, "vnc": 5900, "winrm": 5985,
    "mariadb": 3306, "mssql": 1433, "postgresql": 5432, "redis": 6379,
    "grafana": 3000, "prometheus": 9090,
}

# Auto-backup interval: 10 minutes
_BACKUP_INTERVAL = 600

# One-off copy of a plaintext servers file, taken before it is first encrypted.
_PRE_ENCRYPTION_BACKUP = "servers_pre-encryption.json"

# Timestamp format used in names written by ServerStore.create_backup().
_BACKUP_STAMP_FORMAT = "%Y-%m-%d_%H%M%S"

# Values accepted by ServerStore.set_update_mode().
_UPDATE_MODES = ("notify-only", "auto-download", "full-auto")


def _default_data() -> dict:
    """Return a fresh, empty servers document."""
    return {"servers": [], "ssh_key": {"type": "ed25519", "path": "~/.ssh/id_ed25519"}}


def _decode_payload(raw: bytes):
    """Parse JSON from raw file bytes, decrypting first when the bytes are encrypted."""
    if is_encrypted(raw):
        return json.loads(decrypt(raw))
    return json.loads(raw.decode("utf-8"))


def _require_servers_list(data, kind: str) -> None:
    """Raise ValueError unless *data* is a dict holding a 'servers' list."""
    if not isinstance(data, dict) or not isinstance(data.get("servers"), list):
        raise ValueError(f"Invalid {kind} structure: missing 'servers' list")


def _normalize_loaded(data) -> dict:
    """Validate a servers document read from disk; a missing 'servers' key becomes []."""
    if isinstance(data, dict):
        data.setdefault("servers", [])
    _require_servers_list(data, "servers file")
    return data


def _clamped_int(value, low: int, high: int, default: int) -> int:
    """Coerce *value* to an int within [low, high]; return *default* if it is not numeric."""
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return max(low, min(high, number))


def _is_timestamped_backup(name: str) -> bool:
    """Return True for names produced by create_backup(): servers_<YYYY-MM-DD_HHMMSS>.json."""
    stamp = name[len("servers_"):-len(".json")]
    try:
        datetime.strptime(stamp, _BACKUP_STAMP_FORMAT)
    except ValueError:
        return False
    return True


class ServerStore:
    """Server list backed by an encrypted JSON file, plus persisted app settings.

    Observers registered with subscribe() are called (with no arguments) after
    every change to the server list or config path.
    """

    def __init__(self):
        self._data: dict = _default_data()
        self._observers: list[Callable] = []
        self._check_interval: int = 60
        self._statuses: dict[str, str] = {}  # alias -> "online" | "offline" | "unknown"
        self._statuses_lock = threading.Lock()
        self._file_lock = threading.Lock()
        self._last_backup_time: float = 0
        self._last_backup_hash: str = ""
        self._terminal_font_size: int = 11
        self._window_geometry: str = ""
        self._servers_file: str = DEFAULT_SERVERS_FILE
        # Update settings
        self._update_mode: str = "auto-download"  # "notify-only" | "auto-download" | "full-auto"
        self._last_update_check: float = 0
        self._skip_version: str = ""
        self._load_settings()
        self._load()

    # ── Settings ──────────────────────────────────────

    def _load_settings(self):
        """Apply preferences from settings.json; missing or invalid values keep defaults.

        Numeric values are clamped to the same ranges their setters enforce.
        """
        if not os.path.exists(SETTINGS_FILE):
            return
        try:
            with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
                settings = json.load(f)
            path = settings.get("servers_path", "")
            if path and os.path.exists(path):
                self._servers_file = path
            # Load language preference (imported lazily, as in _save_settings)
            from core import i18n
            i18n.set_language(settings.get("language", "en"))
            self._check_interval = _clamped_int(settings.get("check_interval"), 10, 600, 60)
            self._terminal_font_size = _clamped_int(settings.get("terminal_font_size"), 6, 28, 11)
            self._window_geometry = settings.get("window_geometry", "")
            mode = settings.get("update_mode", "auto-download")
            self._update_mode = mode if mode in _UPDATE_MODES else "auto-download"
            self._last_update_check = settings.get("last_update_check", 0)
            self._skip_version = settings.get("skip_version", "")
        except json.JSONDecodeError:
            log.warning("Corrupted settings.json, using defaults")
        except Exception as e:
            log.error(f"Failed to load settings: {e}")

    def _save_settings(self):
        """Persist preferences to settings.json via tmp file + fsync + rename."""
        os.makedirs(SHARED_DIR, exist_ok=True)
        from core import i18n
        settings = {
            "servers_path": self._servers_file,
            "language": i18n.get_language(),
            "check_interval": self._check_interval,
            "terminal_font_size": self._terminal_font_size,
            "window_geometry": self._window_geometry,
            "update_mode": self._update_mode,
            "last_update_check": self._last_update_check,
            "skip_version": self._skip_version,
        }
        tmp = SETTINGS_FILE + ".tmp"
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(settings, f, indent=2, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())  # content on disk before the rename exposes it
            os.replace(tmp, SETTINGS_FILE)
        except Exception as e:
            log.error(f"Failed to save settings: {e}")
            self._remove_quietly(tmp)

    def get_config_path(self) -> str:
        return self._servers_file

    def set_config_path(self, path: str):
        """Switch to another servers file, persist the choice, reload and notify.

        If *path* does not exist and there is no example file, the servers
        currently in memory are kept.
        """
        self._servers_file = path
        self._save_settings()
        self._load()
        self._notify()

    # ── Load / Save (encrypted, thread-safe, atomic) ──

    def _load(self):
        with self._file_lock:
            self._load_unsafe()

    def _load_unsafe(self):
        """Load the servers file into memory; called with ``self._file_lock`` held.

        A plaintext file is first copied once to BACKUP_DIR, then encrypted.
        Every successful load rewrites the file once, so data readable only
        via an older key is re-encrypted with the current one. An unreadable
        file is copied aside and the newest readable backup is restored. With
        no servers file, the example file (if present) seeds the store.
        """
        if os.path.exists(self._servers_file):
            try:
                with open(self._servers_file, "rb") as f:
                    raw = f.read()
                if not raw.strip():
                    return
                was_encrypted = is_encrypted(raw)
                data = _normalize_loaded(_decode_payload(raw))
            except json.JSONDecodeError as e:
                log.error(f"Corrupted servers.json: {e}")
                self._recover_unreadable_file()
                return
            except Exception as e:
                log.error(f"Failed to load servers: {e}")
                self._recover_unreadable_file()
                return

            self._data = data
            try:
                if not was_encrypted:
                    # Auto-migration: keep the plaintext original before encrypting.
                    self._snapshot_plaintext_once()
                # Single rewrite: encrypts plaintext / re-encrypts with the current key.
                self._save_unsafe()
            except Exception as e:
                # The data loaded fine; if the snapshot failed the plaintext file is left as-is.
                log.error(f"Failed to re-encrypt servers file: {e}")
        elif os.path.exists(EXAMPLE_FILE):
            try:
                with open(EXAMPLE_FILE, "r", encoding="utf-8") as f:
                    self._data = _normalize_loaded(json.load(f))
                self._save_unsafe()
            except Exception as e:
                log.error(f"Failed to load example: {e}")

    def _snapshot_plaintext_once(self):
        """Copy the plaintext servers file to BACKUP_DIR unless that snapshot already exists."""
        pre_enc = os.path.join(BACKUP_DIR, _PRE_ENCRYPTION_BACKUP)
        if not os.path.exists(pre_enc):
            os.makedirs(BACKUP_DIR, exist_ok=True)
            shutil.copy2(self._servers_file, pre_enc)

    def _recover_unreadable_file(self):
        """Preserve the unreadable servers file, then fall back to a backup."""
        self._preserve_unreadable_file()
        self._try_restore_from_backup()

    def _preserve_unreadable_file(self):
        """Copy an unreadable servers file aside so a later save cannot destroy it.

        The copy's name does not start with ``servers_``, so list_backups() never offers it.
        """
        try:
            os.makedirs(BACKUP_DIR, exist_ok=True)
            stamp = datetime.now().strftime(_BACKUP_STAMP_FORMAT)
            dst = os.path.join(BACKUP_DIR, f"unreadable_{stamp}_{os.path.basename(self._servers_file)}")
            shutil.copy2(self._servers_file, dst)
            log.warning(f"Unreadable servers file preserved as: {dst}")
        except OSError as e:
            log.error(f"Could not preserve unreadable servers file: {e}")

    def _try_restore_from_backup(self):
        """Restore from the newest readable backup; start fresh if none can be read."""
        backups = self.list_backups()
        if not backups:
            log.warning("No backups available, starting fresh")
            self._data = _default_data()
            return
        for name in backups:
            log.warning(f"Attempting restore from backup: {name}")
            try:
                with open(os.path.join(BACKUP_DIR, name), "rb") as f:
                    raw = f.read()
                data = _normalize_loaded(_decode_payload(raw))
            except Exception as e2:
                log.error(f"Backup restore also failed: {e2}")
                continue
            self._data = data
            try:
                self._save_unsafe()
            except Exception as e:
                log.error(f"Failed to save restored data: {e}")
            log.info("Restored from backup successfully")
            return
        log.warning("No readable backup found, starting fresh")
        self._data = _default_data()

    def _save(self):
        with self._file_lock:
            self._save_unsafe()

    def _save_unsafe(self):
        """Write encrypted data atomically (tmp + fsync + rename).

        Called with ``self._file_lock`` held (see _save/_load). Write failures
        are logged and the tmp file removed. After a successful write an
        auto-backup is attempted once _BACKUP_INTERVAL seconds have passed.
        """
        target_dir = os.path.dirname(self._servers_file)
        if target_dir:  # a bare relative filename has no directory to create
            os.makedirs(target_dir, exist_ok=True)
        text = json.dumps(self._data, indent=2, ensure_ascii=False)
        encrypted = encrypt(text)
        tmp = self._servers_file + ".tmp"
        try:
            with open(tmp, "wb") as f:
                f.write(encrypted)
                f.flush()
                os.fsync(f.fileno())  # content on disk before the rename exposes it
            os.replace(tmp, self._servers_file)
        except Exception as e:
            log.error(f"Failed to save servers: {e}")
            self._remove_quietly(tmp)
            return
        if time.time() - self._last_backup_time >= _BACKUP_INTERVAL:
            self._auto_backup()

    @staticmethod
    def _remove_quietly(path: str):
        """Best-effort removal of a leftover temp file."""
        try:
            if os.path.exists(path):
                os.remove(path)
        except OSError:
            pass

    def _data_hash(self) -> str:
        """SHA-256 of the key-sorted JSON form of the in-memory data."""
        text = json.dumps(self._data, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def _auto_backup(self):
        """Back up the servers file unless its content is unchanged since the last backup."""
        if self._data_hash() == self._last_backup_hash:
            self._last_backup_time = time.time()
            return
        try:
            self.create_backup()  # also records the new time and hash
        except Exception as e:
            log.warning(f"Auto-backup failed: {e}")

    # ── Backups ───────────────────────────────────────

    def create_backup(self) -> str:
        """Copy the current servers file to BACKUP_DIR and return the backup's name."""
        os.makedirs(BACKUP_DIR, exist_ok=True)
        stamp = datetime.now().strftime(_BACKUP_STAMP_FORMAT)
        name = f"servers_{stamp}.json"
        dst = os.path.join(BACKUP_DIR, name)
        shutil.copy2(self._servers_file, dst)
        self._last_backup_time = time.time()
        self._last_backup_hash = self._data_hash()
        log.info(f"Backup created: {name}")
        return name

    def list_backups(self) -> list[str]:
        """Return ``servers_*.json`` names in BACKUP_DIR, newest first.

        Order: timestamped backups (newest first), then other names in reverse
        order. The pre-encryption snapshot always comes last: lexically it
        would outrank every timestamp ('p' > digits) although it holds the
        oldest data.
        """
        if not os.path.isdir(BACKUP_DIR):
            return []
        names = [n for n in os.listdir(BACKUP_DIR)
                 if n.startswith("servers_") and n.endswith(".json")]
        names.sort(
            key=lambda n: (n != _PRE_ENCRYPTION_BACKUP, _is_timestamped_backup(n), n),
            reverse=True,
        )
        return names

    def restore_backup(self, filename: str):
        """Replace the current data with the backup *filename*, save and notify.

        Raises:
            FileNotFoundError: the backup does not exist.
            ValueError: the backup cannot be decoded or lacks a 'servers' list.
        """
        src = os.path.join(BACKUP_DIR, filename)
        if not os.path.exists(src):
            raise FileNotFoundError(f"Backup not found: {filename}")
        # Validate backup before restoring
        with open(src, "rb") as f:
            raw = f.read()
        try:
            data = _decode_payload(raw)
        except Exception as e:
            raise ValueError(f"Backup is corrupted: {e}")
        _require_servers_list(data, "backup")
        self._data = data
        self._save()
        self._notify()
        log.info(f"Restored from: {filename}")

    # ── Import / Export ──────────────────────────────

    def export_config(self, dest_path: str) -> str:
        """Write the current data as plaintext JSON to *dest_path* and return it."""
        text = json.dumps(self._data, indent=2, ensure_ascii=False)
        with open(dest_path, "w", encoding="utf-8") as f:
            f.write(text)
        log.info(f"Config exported to: {dest_path}")
        return dest_path

    def import_config(self, src_path: str):
        """Replace the current data with a (plain or encrypted) config file.

        Raises:
            FileNotFoundError: *src_path* does not exist.
            ValueError: the file cannot be decoded or lacks a 'servers' list.
        """
        if not os.path.exists(src_path):
            raise FileNotFoundError(f"File not found: {src_path}")
        with open(src_path, "rb") as f:
            raw = f.read()
        try:
            data = _decode_payload(raw)
        except Exception as e:
            raise ValueError(f"Invalid config file: {e}")
        _require_servers_list(data, "config")
        self._data = data
        self._save()
        self._notify()
        log.info(f"Config imported from: {src_path}")

    def export_backup(self, filename: str, dest_path: str) -> str:
        """Write backup *filename* as plaintext JSON to *dest_path* and return it.

        Raises:
            FileNotFoundError: the backup does not exist.
            ValueError: the backup cannot be decoded.
        """
        src = os.path.join(BACKUP_DIR, filename)
        if not os.path.exists(src):
            raise FileNotFoundError(f"Backup not found: {filename}")
        with open(src, "rb") as f:
            raw = f.read()
        try:
            data = _decode_payload(raw)
        except Exception as e:
            raise ValueError(f"Backup is corrupted: {e}")
        with open(dest_path, "w", encoding="utf-8") as f:
            json.dump(data, indent=2, ensure_ascii=False, fp=f)
        log.info(f"Backup exported to: {dest_path}")
        return dest_path

    def import_backup(self, src_path: str) -> str:
        """Validate *src_path* and store it encrypted in BACKUP_DIR; return the stored name.

        The stored name always matches ``servers_*.json`` so list_backups() shows it.
        It never replaces an existing backup or the pre-encryption snapshot; a
        time suffix (plus a counter if needed) keeps it unique.

        Raises:
            FileNotFoundError: *src_path* does not exist.
            ValueError: the file cannot be decoded or lacks a 'servers' list.
        """
        if not os.path.exists(src_path):
            raise FileNotFoundError(f"File not found: {src_path}")
        with open(src_path, "rb") as f:
            raw = f.read()
        try:
            data = _decode_payload(raw)
        except Exception as e:
            raise ValueError(f"Invalid backup file: {e}")
        _require_servers_list(data, "backup")

        stem = os.path.splitext(os.path.basename(src_path))[0]
        if not stem.startswith("servers_"):
            stem = "servers_" + stem
        os.makedirs(BACKUP_DIR, exist_ok=True)
        name = stem + ".json"
        if name == _PRE_ENCRYPTION_BACKUP or os.path.exists(os.path.join(BACKUP_DIR, name)):
            base = stem + datetime.now().strftime("_%H%M%S")
            name = base + ".json"
            counter = 1
            while os.path.exists(os.path.join(BACKUP_DIR, name)):
                name = f"{base}_{counter}.json"
                counter += 1
        dest = os.path.join(BACKUP_DIR, name)
        encrypted = encrypt(json.dumps(data, indent=2, ensure_ascii=False))
        with open(dest, "wb") as f:
            f.write(encrypted)
        log.info(f"Backup imported: {name}")
        return name

    # ── Observer ──────────────────────────────────────

    def _notify(self):
        """Call every subscriber; a failing callback is logged and does not stop the rest."""
        for cb in list(self._observers):  # copy: a callback may subscribe others
            try:
                cb()
            except Exception as e:
                log.warning(f"Observer {cb!r} failed: {e}")

    def subscribe(self, callback: Callable):
        self._observers.append(callback)

    # ── CRUD ──────────────────────────────────────────

    def get_all(self) -> list[dict]:
        """Return copies of all server entries (mutating them does not touch the store)."""
        return [dict(s) for s in self._data.get("servers", [])]

    def get_server(self, alias: str) -> Optional[dict]:
        """Return a copy of the entry with *alias*, or None."""
        for s in self._data.get("servers", []):
            if s["alias"] == alias:
                return dict(s)
        return None

    def add_server(self, server: dict):
        """Append a copy of *server*; raise ValueError if its alias is already used."""
        if self.get_server(server["alias"]):
            raise ValueError(f"Server '{server['alias']}' already exists")
        self._data.setdefault("servers", []).append(dict(server))
        self._save()
        self._notify()

    def update_server(self, alias: str, updated: dict):
        """Replace the entry *alias* with a copy of *updated*.

        A missing 'alias' key in *updated* keeps the old alias. On rename the
        cached status moves to the new alias.

        Raises:
            ValueError: *alias* is unknown, or the new alias belongs to another server.
        """
        servers = self._data.get("servers", [])
        new_alias = updated.get("alias", alias)
        for i, s in enumerate(servers):
            if s["alias"] != alias:
                continue
            if new_alias != alias and self.get_server(new_alias) is not None:
                raise ValueError(f"Server '{new_alias}' already exists")
            entry = dict(updated)
            entry.setdefault("alias", alias)  # every entry must stay addressable by alias
            servers[i] = entry
            if new_alias != alias:
                with self._statuses_lock:
                    old_status = self._statuses.pop(alias, None)
                    if old_status is not None:
                        self._statuses[new_alias] = old_status
            self._save()
            self._notify()
            return
        raise ValueError(f"Server '{alias}' not found")

    def remove_server(self, alias: str):
        """Drop the entry *alias* (no error if absent) and its cached status."""
        self._data["servers"] = [s for s in self._data.get("servers", []) if s["alias"] != alias]
        with self._statuses_lock:
            self._statuses.pop(alias, None)
        self._save()
        self._notify()

    def get_ssh_key_path(self) -> str:
        path = self._data.get("ssh_key", {}).get("path", "~/.ssh/id_ed25519")
        return os.path.expanduser(path)

    # ── Status management (thread-safe) ───────────────

    def get_check_interval(self) -> int:
        return self._check_interval

    def set_check_interval(self, seconds: int):
        """Store the status-check interval, clamped to 10..600 seconds."""
        self._check_interval = max(10, min(600, seconds))
        self._save_settings()

    def set_status(self, alias: str, status: str):
        with self._statuses_lock:
            self._statuses[alias] = status

    def get_status(self, alias: str) -> str:
        with self._statuses_lock:
            return self._statuses.get(alias, "unknown")

    # ── Terminal font size ────────────────────────────────

    def get_terminal_font_size(self) -> int:
        return self._terminal_font_size

    def set_terminal_font_size(self, size: int):
        """Store the terminal font size, clamped to 6..28."""
        self._terminal_font_size = max(6, min(28, size))
        self._save_settings()

    # ── Update settings ──────────────────────────────────

    def get_update_mode(self) -> str:
        return self._update_mode

    def set_update_mode(self, mode: str):
        """Store *mode* if it is one of _UPDATE_MODES; other values are ignored."""
        if mode in _UPDATE_MODES:
            self._update_mode = mode
            self._save_settings()

    def get_last_update_check(self) -> float:
        return self._last_update_check

    def set_last_update_check(self):
        """Record the current time as the last update check."""
        self._last_update_check = time.time()
        self._save_settings()

    def get_skip_version(self) -> str:
        return self._skip_version

    def set_skip_version(self, version: str):
        self._skip_version = version
        self._save_settings()