Files
server-manager/core/server_store.py
2026-03-01 09:00:27 -05:00

451 lines
17 KiB
Python

"""
Server store — CRUD + JSON persistence + observer pattern.
Supports encryption, backups, and configurable config path.
Thread-safe with atomic writes.
"""
import hashlib
import json
import os
import shutil
import threading
import time
from datetime import datetime
from typing import Callable, Optional
from core.encryption import encrypt, decrypt, is_encrypted
from core.logger import log
# Shared config: these files are also read by ssh.py and the Claude Code /ssh skill.
SHARED_DIR = os.path.expanduser("~/.server-connections")
SETTINGS_FILE = os.path.join(SHARED_DIR, "settings.json")
DEFAULT_SERVERS_FILE = os.path.join(SHARED_DIR, "servers.json")
BACKUP_DIR = os.path.join(SHARED_DIR, "backups")

# Fallback: the "config" directory two levels above this module, which holds
# the example servers file used when no servers file exists yet.
LOCAL_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "config",
)
EXAMPLE_FILE = os.path.join(LOCAL_CONFIG_DIR, "servers.example.json")

# Default port per supported connection type. The insertion order of this
# mapping is also the order of SERVER_TYPES below.
DEFAULT_PORTS = {
    "ssh": 22,
    "telnet": 23,
    "rdp": 3389,
    "vnc": 5900,
    "winrm": 5985,
    "mariadb": 3306,
    "mssql": 1433,
    "postgresql": 5432,
    "redis": 6379,
    "grafana": 3000,
    "prometheus": 9090,
}
SERVER_TYPES = list(DEFAULT_PORTS)

# Minimum number of seconds between automatic backups (10 minutes).
_BACKUP_INTERVAL = 10 * 60
class ServerStore:
    """Server list with encrypted JSON persistence, backups and app settings.

    Server data is held in ``self._data`` (a dict with a ``"servers"`` list of
    per-server dicts identified by their ``"alias"`` and an ``"ssh_key"`` dict)
    and is written to ``self._servers_file`` through ``encrypt``. App settings
    are kept separately, as plain JSON, in ``SETTINGS_FILE``. Callbacks
    registered with ``subscribe`` run after every change to the server data.

    Concurrency: whole-file loads and saves are serialized by ``_file_lock``
    and the status map is guarded by ``_statuses_lock``; the CRUD methods
    mutate ``_data`` without taking a lock.
    """

    # Accepted values for the update-mode setting.
    _UPDATE_MODES = ("notify-only", "auto-download", "full-auto")

    def __init__(self):
        self._data: dict = self._default_data()
        self._observers: list[Callable] = []
        self._check_interval: int = 60
        self._statuses: dict[str, str] = {}  # alias -> "online" | "offline" | "unknown"
        self._statuses_lock = threading.Lock()
        self._file_lock = threading.Lock()
        self._last_backup_time: float = 0
        self._last_backup_hash: str = ""
        self._terminal_font_size: int = 11
        self._window_geometry: str = ""
        self._servers_file: str = DEFAULT_SERVERS_FILE
        # Update settings
        self._update_mode: str = "auto-download"  # one of _UPDATE_MODES
        self._last_update_check: float = 0
        self._skip_version: str = ""
        self._load_settings()
        self._load()

    # ── Small helpers ─────────────────────────────────

    @staticmethod
    def _default_data() -> dict:
        """Return a fresh, empty configuration (a new dict on every call)."""
        return {"servers": [], "ssh_key": {"type": "ed25519", "path": "~/.ssh/id_ed25519"}}

    @staticmethod
    def _clamp_int(value, low: int, high: int, default: int) -> int:
        """Return *value* as an int clamped to [low, high], or *default* if not numeric."""
        # bool is an int subclass but never a meaningful setting value here.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return default
        return max(low, min(high, int(value)))

    @staticmethod
    def _remove_quietly(path: str):
        """Best-effort removal of a leftover temp file; errors are ignored."""
        try:
            os.remove(path)
        except OSError:
            pass

    # ── Settings ──────────────────────────────────────

    def _load_settings(self):
        """Read ``SETTINGS_FILE``; missing or invalid values keep their defaults."""
        if not os.path.exists(SETTINGS_FILE):
            return
        try:
            with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
                settings = json.load(f)
            if not isinstance(settings, dict):
                raise ValueError("settings root is not a JSON object")
            path = settings.get("servers_path", "")
            # Only honour a custom servers path that still exists on disk.
            if isinstance(path, str) and path and os.path.exists(path):
                self._servers_file = path
            # Load language preference
            from core import i18n
            i18n.set_language(settings.get("language", "en"))
            # Same bounds as the corresponding setters.
            self._check_interval = self._clamp_int(settings.get("check_interval"), 10, 600, 60)
            self._terminal_font_size = self._clamp_int(settings.get("terminal_font_size"), 6, 28, 11)
            self._window_geometry = settings.get("window_geometry", "")
            mode = settings.get("update_mode", "auto-download")
            if mode in self._UPDATE_MODES:
                self._update_mode = mode
            last_check = settings.get("last_update_check", 0)
            if isinstance(last_check, (int, float)) and not isinstance(last_check, bool):
                self._last_update_check = last_check
            self._skip_version = settings.get("skip_version", "")
        except json.JSONDecodeError:
            log.warning("Corrupted settings.json, using defaults")
        except Exception as e:
            log.error(f"Failed to load settings: {e}")

    def _save_settings(self):
        """Write the current settings to ``SETTINGS_FILE`` atomically (tmp + rename).

        Write failures are logged and the temp file is removed.
        """
        os.makedirs(SHARED_DIR, exist_ok=True)
        from core import i18n
        settings = {
            "servers_path": self._servers_file,
            "language": i18n.get_language(),
            "check_interval": self._check_interval,
            "terminal_font_size": self._terminal_font_size,
            "window_geometry": self._window_geometry,
            "update_mode": self._update_mode,
            "last_update_check": self._last_update_check,
            "skip_version": self._skip_version,
        }
        tmp = SETTINGS_FILE + ".tmp"
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(settings, f, indent=2, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())  # durable contents before the rename
            os.replace(tmp, SETTINGS_FILE)
        except Exception as e:
            log.error(f"Failed to save settings: {e}")
            self._remove_quietly(tmp)

    def get_config_path(self) -> str:
        return self._servers_file

    def set_config_path(self, path: str):
        """Switch to another servers file, persist the choice, reload and notify."""
        self._servers_file = path
        self._save_settings()
        self._load()
        self._notify()

    # ── Load / Save (encrypted, thread-safe, atomic) ──

    @staticmethod
    def _decode_config(raw: bytes):
        """Parse raw file bytes (encrypted, or plain UTF-8 JSON) into Python data.

        Raises whatever ``decrypt``, ``bytes.decode`` or ``json.loads`` raise.
        """
        if is_encrypted(raw):
            return json.loads(decrypt(raw))
        return json.loads(raw.decode("utf-8"))

    @classmethod
    def _read_config_file(cls, path: str, error_prefix: str):
        """Read and decode *path*; decoding errors become ``ValueError(error_prefix: ...)``.

        Errors opening/reading the file propagate unchanged.
        """
        with open(path, "rb") as f:
            raw = f.read()
        try:
            return cls._decode_config(raw)
        except Exception as e:
            raise ValueError(f"{error_prefix}: {e}") from e

    @staticmethod
    def _require_servers_list(data, message: str):
        """Raise ``ValueError(message)`` unless *data* is a dict with a ``servers`` list."""
        if not isinstance(data, dict) or not isinstance(data.get("servers"), list):
            raise ValueError(message)

    def _load(self):
        with self._file_lock:
            self._load_unsafe()

    def _load_unsafe(self):
        """Load server data from disk; the caller must hold ``_file_lock``.

        Uses ``self._servers_file`` when it exists, otherwise the bundled
        example file. Unreadable files trigger recovery from backups.
        """
        if os.path.exists(self._servers_file):
            self._load_servers_file_unsafe()
        elif os.path.exists(EXAMPLE_FILE):
            self._load_example_unsafe()

    def _load_servers_file_unsafe(self):
        """Read ``self._servers_file`` into ``_data`` and re-save it encrypted."""
        try:
            with open(self._servers_file, "rb") as f:
                raw = f.read()
            if not raw.strip():
                return  # empty file: keep the data already in memory
            was_encrypted = is_encrypted(raw)
            data = self._decode_config(raw)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value is not an object")
        except json.JSONDecodeError as e:
            log.error(f"Corrupted servers.json: {e}")
            self._preserve_unreadable_unsafe()
            self._try_restore_from_backup()
            return
        except Exception as e:
            log.error(f"Failed to load servers: {e}")
            self._preserve_unreadable_unsafe()
            self._try_restore_from_backup()
            return

        self._data = data
        # Auto-migration: keep a plaintext copy before the file gets encrypted.
        if not was_encrypted and not self._backup_plaintext_unsafe():
            log.warning("Leaving plaintext servers file unencrypted until it can be backed up")
            return
        # Re-save exactly once: this encrypts a plaintext file and, per the
        # original migration note, re-encrypts data read with an older key.
        try:
            self._save_unsafe()
        except Exception as e:
            # The data is already loaded; a failed re-save must not discard it.
            log.error(f"Failed to save servers: {e}")

    def _backup_plaintext_unsafe(self) -> bool:
        """Ensure a one-time copy of the plaintext servers file exists; return success."""
        pre_enc = os.path.join(BACKUP_DIR, "servers_pre-encryption.json")
        if os.path.exists(pre_enc):
            return True
        try:
            os.makedirs(BACKUP_DIR, exist_ok=True)
            shutil.copy2(self._servers_file, pre_enc)
        except Exception as e:
            log.warning(f"Failed to back up plaintext servers file: {e}")
            return False
        return True

    def _load_example_unsafe(self):
        """Seed ``_data`` from ``EXAMPLE_FILE`` and save it as the servers file."""
        try:
            with open(EXAMPLE_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value is not an object")
            self._data = data
            self._save_unsafe()
        except Exception as e:
            log.error(f"Failed to load example: {e}")

    def _preserve_unreadable_unsafe(self):
        """Copy an unreadable servers file aside so a later save can't destroy it.

        The copy is named ``unreadable_<stamp>.bin`` so ``list_backups`` never
        offers it for automatic restore. Failures are logged, not raised.
        """
        if not os.path.exists(self._servers_file):
            return
        try:
            os.makedirs(BACKUP_DIR, exist_ok=True)
            stamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
            dst = os.path.join(BACKUP_DIR, f"unreadable_{stamp}.bin")
            shutil.copy2(self._servers_file, dst)
            log.warning(f"Unreadable servers file preserved as: {dst}")
        except Exception as e:
            log.error(f"Failed to preserve unreadable servers file: {e}")

    def _try_restore_from_backup(self):
        """Load the newest readable backup into ``_data``, else fall back to defaults.

        Candidates are tried in ``list_backups`` order until one decodes to a
        JSON object. A successful restore is saved as the servers file.
        """
        backups = self.list_backups()
        if not backups:
            log.warning("No backups available, starting fresh")
            self._data = self._default_data()
            return
        for name in backups:
            log.warning(f"Attempting restore from backup: {name}")
            try:
                with open(os.path.join(BACKUP_DIR, name), "rb") as f:
                    data = self._decode_config(f.read())
                if not isinstance(data, dict):
                    raise ValueError("top-level JSON value is not an object")
            except Exception as e2:
                log.error(f"Backup restore also failed: {e2}")
                continue
            self._data = data
            try:
                self._save_unsafe()
            except Exception as e:
                log.error(f"Failed to save servers: {e}")
            log.info("Restored from backup successfully")
            return
        self._data = self._default_data()

    def _save(self):
        with self._file_lock:
            self._save_unsafe()

    def _save_unsafe(self):
        """Write ``_data`` encrypted and atomically (tmp + rename); hold ``_file_lock``.

        Write/rename failures are logged and the temp file is removed;
        directory-creation, serialization and encryption errors propagate.
        A successful write may trigger an auto-backup.
        """
        directory = os.path.dirname(self._servers_file)
        if directory:  # a bare filename has no directory; makedirs("") would raise
            os.makedirs(directory, exist_ok=True)
        encrypted = encrypt(json.dumps(self._data, indent=2, ensure_ascii=False))
        tmp = self._servers_file + ".tmp"
        try:
            with open(tmp, "wb") as f:
                f.write(encrypted)
                f.flush()
                os.fsync(f.fileno())  # durable contents before the rename
            os.replace(tmp, self._servers_file)
        except Exception as e:
            log.error(f"Failed to save servers: {e}")
            self._remove_quietly(tmp)
            return
        # Auto-backup
        if time.time() - self._last_backup_time >= _BACKUP_INTERVAL:
            self._auto_backup()

    def _data_hash(self) -> str:
        """SHA-256 of the canonical (key-sorted) JSON form of ``_data``."""
        text = json.dumps(self._data, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def _auto_backup(self):
        """Create a backup unless the data is unchanged since the last one."""
        current_hash = self._data_hash()
        if current_hash == self._last_backup_hash:
            self._last_backup_time = time.time()
            return
        try:
            self.create_backup()
            self._last_backup_hash = current_hash
        except Exception as e:
            log.warning(f"Auto-backup failed: {e}")

    # ── Backups ───────────────────────────────────────

    def create_backup(self) -> str:
        """Copy the servers file to ``BACKUP_DIR`` as ``servers_<timestamp>.json``."""
        os.makedirs(BACKUP_DIR, exist_ok=True)
        stamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
        name = f"servers_{stamp}.json"
        dst = os.path.join(BACKUP_DIR, name)
        shutil.copy2(self._servers_file, dst)
        self._last_backup_time = time.time()
        self._last_backup_hash = self._data_hash()
        log.info(f"Backup created: {name}")
        return name

    @staticmethod
    def _backup_sort_key(name: str) -> tuple[bool, str]:
        """Sort key ranking timestamped ``servers_<digits>…`` names above the rest."""
        prefix_len = len("servers_")
        return (name[prefix_len:prefix_len + 1].isdigit(), name)

    def list_backups(self) -> list[str]:
        """Backup filenames: timestamped ones newest first, then any others.

        Without the grouping, ``servers_pre-encryption.json`` would sort ahead
        of every timestamped backup and be chosen for automatic restore.
        """
        if not os.path.isdir(BACKUP_DIR):
            return []
        files = [f for f in os.listdir(BACKUP_DIR) if f.startswith("servers_") and f.endswith(".json")]
        return sorted(files, key=self._backup_sort_key, reverse=True)

    def restore_backup(self, filename: str):
        """Replace the current data with backup *filename* (validated first).

        Raises:
            FileNotFoundError: the backup does not exist.
            ValueError: the backup cannot be decoded or lacks a ``servers`` list.
        """
        src = os.path.join(BACKUP_DIR, filename)
        if not os.path.exists(src):
            raise FileNotFoundError(f"Backup not found: {filename}")
        data = self._read_config_file(src, "Backup is corrupted")
        self._require_servers_list(data, "Backup is corrupted: missing 'servers' list")
        self._data = data
        self._save()
        self._notify()
        log.info(f"Restored from: {filename}")

    # ── Import / Export ──────────────────────────────

    def export_config(self, dest_path: str) -> str:
        """Write the current data as plain (unencrypted) JSON to *dest_path*."""
        text = json.dumps(self._data, indent=2, ensure_ascii=False)
        with open(dest_path, "w", encoding="utf-8") as f:
            f.write(text)
        log.info(f"Config exported to: {dest_path}")
        return dest_path

    def import_config(self, src_path: str):
        """Replace the current data with the (encrypted or plain) config at *src_path*.

        Raises:
            FileNotFoundError: *src_path* does not exist.
            ValueError: the file cannot be decoded or lacks a ``servers`` list.
        """
        if not os.path.exists(src_path):
            raise FileNotFoundError(f"File not found: {src_path}")
        data = self._read_config_file(src_path, "Invalid config file")
        self._require_servers_list(data, "Invalid config structure: missing 'servers' list")
        self._data = data
        self._save()
        self._notify()
        log.info(f"Config imported from: {src_path}")

    def export_backup(self, filename: str, dest_path: str) -> str:
        """Write backup *filename* to *dest_path* as plain JSON.

        Raises:
            FileNotFoundError: the backup does not exist.
            ValueError: the backup cannot be decoded.
        """
        src = os.path.join(BACKUP_DIR, filename)
        if not os.path.exists(src):
            raise FileNotFoundError(f"Backup not found: {filename}")
        data = self._read_config_file(src, "Backup is corrupted")
        with open(dest_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        log.info(f"Backup exported to: {dest_path}")
        return dest_path

    @staticmethod
    def _backup_name_for(basename: str) -> str:
        """Adjust *basename* to match the ``servers_*.json`` pattern ``list_backups`` shows."""
        name = basename
        if not name.startswith("servers_"):
            name = "servers_" + name
        if not name.endswith(".json"):
            name += ".json"
        return name

    def import_backup(self, src_path: str) -> str:
        """Validate the file at *src_path* and store it, encrypted, as a new backup.

        Returns the backup filename (never overwrites an existing backup).

        Raises:
            FileNotFoundError: *src_path* does not exist.
            ValueError: the file cannot be decoded or lacks a ``servers`` list.
        """
        if not os.path.exists(src_path):
            raise FileNotFoundError(f"File not found: {src_path}")
        data = self._read_config_file(src_path, "Invalid backup file")
        self._require_servers_list(data, "Invalid backup structure: missing 'servers' list")
        name = self._backup_name_for(os.path.basename(src_path))
        if os.path.exists(os.path.join(BACKUP_DIR, name)):
            stem, ext = os.path.splitext(name)
            suffix = datetime.now().strftime("_%H%M%S")
            name = stem + suffix + ext
            counter = 2
            # The time suffix alone can still collide within the same second.
            while os.path.exists(os.path.join(BACKUP_DIR, name)):
                name = f"{stem}{suffix}_{counter}{ext}"
                counter += 1
        os.makedirs(BACKUP_DIR, exist_ok=True)
        encrypted = encrypt(json.dumps(data, indent=2, ensure_ascii=False))
        with open(os.path.join(BACKUP_DIR, name), "wb") as f:
            f.write(encrypted)
        log.info(f"Backup imported: {name}")
        return name

    # ── Observer ──────────────────────────────────────

    def _notify(self):
        """Call every subscriber; a failing callback is logged and does not stop the rest."""
        for cb in list(self._observers):
            try:
                cb()
            except Exception as e:
                log.error(f"Observer callback failed: {e}")

    def subscribe(self, callback: Callable):
        """Register *callback* (no arguments) to run after server-data changes."""
        self._observers.append(callback)

    # ── CRUD ──────────────────────────────────────────

    def _index_of(self, alias: str) -> Optional[int]:
        """Position of the first server whose ``alias`` equals *alias*, or None."""
        for i, s in enumerate(self._data.get("servers", [])):
            if s.get("alias") == alias:
                return i
        return None

    def get_all(self) -> list[dict]:
        """Return copies of all server dicts (mutating them does not affect the store)."""
        return [dict(s) for s in self._data.get("servers", [])]

    def get_server(self, alias: str) -> Optional[dict]:
        """Return a copy of the server named *alias*, or None."""
        index = self._index_of(alias)
        return None if index is None else dict(self._data["servers"][index])

    def add_server(self, server: dict):
        """Add a copy of *server*; raises ValueError if its alias already exists."""
        if self.get_server(server["alias"]) is not None:
            raise ValueError(f"Server '{server['alias']}' already exists")
        self._data.setdefault("servers", []).append(dict(server))
        self._save()
        self._notify()

    def update_server(self, alias: str, updated: dict):
        """Replace server *alias* with a copy of *updated*, carrying status over on rename.

        Raises:
            ValueError: *alias* is unknown, or the new alias is already taken.
        """
        index = self._index_of(alias)
        if index is None:
            raise ValueError(f"Server '{alias}' not found")
        new_alias = updated.get("alias", alias)
        if new_alias != alias and self._index_of(new_alias) is not None:
            raise ValueError(f"Server '{new_alias}' already exists")
        entry = dict(updated)
        entry.setdefault("alias", alias)  # keep the entry addressable by alias
        self._data["servers"][index] = entry
        if new_alias != alias:
            with self._statuses_lock:
                old_status = self._statuses.pop(alias, None)
                if old_status is not None:
                    self._statuses[new_alias] = old_status
        self._save()
        self._notify()

    def remove_server(self, alias: str):
        """Remove every server named *alias* (no error if absent) and its status."""
        self._data["servers"] = [s for s in self._data.get("servers", []) if s.get("alias") != alias]
        with self._statuses_lock:
            self._statuses.pop(alias, None)
        self._save()
        self._notify()

    def get_ssh_key_path(self) -> str:
        """Configured SSH key path with ``~`` expanded."""
        ssh_key = self._data.get("ssh_key") or {}
        return os.path.expanduser(ssh_key.get("path", "~/.ssh/id_ed25519"))

    # ── Status management (thread-safe) ───────────────

    def get_check_interval(self) -> int:
        return self._check_interval

    def set_check_interval(self, seconds: int):
        """Set the status-check interval, clamped to 10..600 seconds, and persist it."""
        self._check_interval = max(10, min(600, seconds))
        self._save_settings()

    def set_status(self, alias: str, status: str):
        with self._statuses_lock:
            self._statuses[alias] = status

    def get_status(self, alias: str) -> str:
        with self._statuses_lock:
            return self._statuses.get(alias, "unknown")

    # ── Terminal font size ────────────────────────────────

    def get_terminal_font_size(self) -> int:
        return self._terminal_font_size

    def set_terminal_font_size(self, size: int):
        """Set the terminal font size, clamped to 6..28, and persist it."""
        self._terminal_font_size = max(6, min(28, size))
        self._save_settings()

    # ── Update settings ──────────────────────────────────

    def get_update_mode(self) -> str:
        return self._update_mode

    def set_update_mode(self, mode: str):
        """Set and persist the update mode; unknown values are ignored."""
        if mode in self._UPDATE_MODES:
            self._update_mode = mode
            self._save_settings()

    def get_last_update_check(self) -> float:
        return self._last_update_check

    def set_last_update_check(self):
        """Record the current time as the last update check and persist it."""
        self._last_update_check = time.time()
        self._save_settings()

    def get_skip_version(self) -> str:
        return self._skip_version

    def set_skip_version(self, version: str):
        self._skip_version = version
        self._save_settings()