#!/usr/bin/env python3
"""
KeepItTechie Homelab MCP Server (Python)

Safe-by-default MCP server that exposes whitelisted Linux automation tools.

Key goals:
- No arbitrary shell execution.
- Whitelist-only tools and arguments.
- Optional system changes (restart) disabled unless explicitly enabled.

Run:
  python3 homelab_mcp_server.py

Then connect with MCP Inspector:
  npx -y @modelcontextprotocol/inspector
  Connect to: http://<server-ip>:8000/mcp   (8000 is the default; override with MCP_PORT)
"""

from __future__ import annotations

import os
import shlex
import shutil
import subprocess
from typing import Any, Dict, List, Optional

import httpx
from mcp.server.fastmcp import FastMCP

# ----------------------------
# Config (environment variables)
# ----------------------------
def _env_flag(name: str, default: str) -> bool:
    """
    Read environment variable *name* (or *default* when unset) as a boolean.

    Only "1", "true" or "yes" enable the flag. The comparison is
    case-insensitive and ignores surrounding whitespace, so a value such as
    "true " from an env file still counts. Any other value disables the flag.
    """
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


def _env_csv_set(name: str, default: str) -> set[str]:
    """
    Read a comma-separated environment variable (or *default* when unset)
    as a set of stripped, non-empty items.
    """
    return {item.strip() for item in os.getenv(name, default).split(",") if item.strip()}


MCP_NAME = os.getenv("MCP_NAME", "keepittechie-homelab")
# 0.0.0.0 listens on every interface; this file configures no authentication,
# so keep the server on a trusted network.
MCP_HOST = os.getenv("MCP_HOST", "0.0.0.0")
MCP_PORT = int(os.getenv("MCP_PORT", "8000"))

# Ollama
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://10.10.0.60:11434")
OLLAMA_DEFAULT_MODEL = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3.2:latest")

# Safety toggles
ENABLE_SERVICE_RESTART = _env_flag("ENABLE_SERVICE_RESTART", "false")
ENABLE_DOCKER_TOOLS = _env_flag("ENABLE_DOCKER_TOOLS", "true")

# Allowlists
# Comma-separated service names you allow for systemctl status/restart.
# Example: "docker,jellyfin,nginx,ssh"
ALLOWED_SERVICES = _env_csv_set("ALLOWED_SERVICES", "docker,ssh,nginx,jellyfin")

# Docker allowlist (optional)
# If empty, docker_logs applies no container allowlist; set this to lock it down.
ALLOWED_DOCKER_CONTAINERS = _env_csv_set("ALLOWED_DOCKER_CONTAINERS", "")

# Timeouts (seconds): subprocess commands and HTTP calls to Ollama respectively.
CMD_TIMEOUT_SECONDS = float(os.getenv("CMD_TIMEOUT_SECONDS", "10"))
HTTP_TIMEOUT_SECONDS = float(os.getenv("HTTP_TIMEOUT_SECONDS", "480"))

# ----------------------------
# MCP server instance
# ----------------------------
from mcp.server.transport_security import TransportSecuritySettings

# Host headers accepted by the DNS-rebinding protection.
# Loopback access is always allowed. Extra hosts come from MCP_ALLOWED_HOSTS
# (comma-separated, same pattern syntax as the defaults, e.g. "10.10.0.106:*").
# The default keeps the original homelab address, so behavior is unchanged
# when the variable is unset.
_LOCAL_ALLOWED_HOSTS = ["127.0.0.1:*", "localhost:*"]
_EXTRA_ALLOWED_HOSTS = [
    host.strip()
    for host in os.getenv("MCP_ALLOWED_HOSTS", "10.10.0.106:8000,10.10.0.106:*").split(",")
    if host.strip()
]

mcp = FastMCP(
    MCP_NAME,
    json_response=True,
    transport_security=TransportSecuritySettings(
        # Keep protection ON, but allow your LAN host(s)
        enable_dns_rebinding_protection=True,
        allowed_hosts=_LOCAL_ALLOWED_HOSTS + _EXTRA_ALLOWED_HOSTS,
    ),
)

# ----------------------------
# Helpers
# ----------------------------
def _have_cmd(cmd: str) -> bool:
    """Return True when *cmd* resolves to an executable via shutil.which (i.e. on PATH)."""
    location = shutil.which(cmd)
    return location is not None


def _run_cmd(argv: List[str], timeout: float = CMD_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """
    Run a command safely (no shell) and return structured output.

    Args:
      argv: Program and arguments, passed directly to subprocess.run (no shell).
      timeout: Seconds to wait before the command is abandoned.

    Returns:
      Dict with keys ok, returncode, stdout, stderr (both stripped) and cmd
      (a shell-quoted rendering of argv, for display only). This never raises
      for command failures: timeouts and launch errors come back as ok=False,
      returncode=None and an explanatory stderr.
    """
    # Display-only rendering of the command, shared by every result shape.
    cmd = " ".join(shlex.quote(a) for a in argv)
    try:
        proc = subprocess.run(
            argv,
            capture_output=True,
            text=True,
            # Command output (notably `docker logs`) can contain bytes that are
            # invalid in the locale encoding. Strict decoding would raise
            # UnicodeDecodeError and lose all output, so substitute instead.
            errors="replace",
            timeout=timeout,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return {
            "ok": False,
            "returncode": None,
            "stdout": "",
            "stderr": f"Command timed out after {timeout} seconds",
            "cmd": cmd,
        }
    except Exception as e:
        # Best-effort boundary: e.g. a missing binary or a permission error
        # is reported to the caller rather than raised.
        return {
            "ok": False,
            "returncode": None,
            "stdout": "",
            "stderr": f"Command failed: {e}",
            "cmd": cmd,
        }

    return {
        "ok": proc.returncode == 0,
        "returncode": proc.returncode,
        "stdout": (proc.stdout or "").strip(),
        "stderr": (proc.stderr or "").strip(),
        "cmd": cmd,
    }


def _require_allowed(name: str, allowed: set[str], kind: str) -> None:
    """
    Raise ValueError unless *name* is a member of *allowed*.

    *kind* is a human-readable label (e.g. "service", "container") used in
    the error message, which also lists the permitted values in sorted order.
    """
    if name in allowed:
        return
    permitted = sorted(allowed)
    raise ValueError(
        f"{kind} '{name}' is not in allowlist. Allowed {kind}s: {permitted}"
    )


# ----------------------------
# Tools (Safe informational tools)
# ----------------------------
@mcp.tool()
def health_check() -> Dict[str, Any]:
    """
    Basic health check for the MCP server. Useful to confirm connectivity.

    Returns the server identity and its effective safety configuration
    (feature toggles and allowlists), so a client can tell which tools are
    usable. An empty allowed_docker_containers list means docker_logs
    applies no container allowlist.
    """
    return {
        "ok": True,
        "server": MCP_NAME,
        "host": MCP_HOST,
        "port": MCP_PORT,
        "enable_service_restart": ENABLE_SERVICE_RESTART,
        "enable_docker_tools": ENABLE_DOCKER_TOOLS,
        "allowed_services": sorted(ALLOWED_SERVICES),
        "allowed_docker_containers": sorted(ALLOWED_DOCKER_CONTAINERS),
        "ollama_base_url": OLLAMA_BASE_URL,
        "ollama_default_model": OLLAMA_DEFAULT_MODEL,
    }


@mcp.tool()
def disk_usage() -> Dict[str, Any]:
    """
    Report filesystem usage in human-readable units by running `df -h`.

    Returns the structured _run_cmd result (ok/returncode/stdout/stderr/cmd).
    """
    df_argv = ["df", "-h"]
    return _run_cmd(df_argv)

@mcp.tool()
async def summarize_disk_usage(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Run `df -h` on this MCP server and have Ollama summarize it.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict carrying the raw df
      result when the command fails or prints nothing.
    """
    result = disk_usage()
    if not result.get("ok"):
        return {"ok": False, "error": "Failed to run df -h", "df": result}

    output = result.get("stdout", "").strip()
    if not output:
        return {"ok": False, "error": "df -h output was empty", "df": result}

    prompt_lines = [
        "You are a Linux sysadmin assistant.",
        "Summarize the following df -h output in exactly 3 bullets.",
        "Mention any filesystem at or above 80% usage.",
        "If none are at/above 80%, explicitly say so.",
        "",
        "df -h output:",
        output,
    ]
    prompt = "\n".join(prompt_lines) + "\n"
    return await ollama_generate(prompt=prompt, model=model)

@mcp.tool()
def memory_usage() -> Dict[str, Any]:
    """
    Report RAM and swap usage in human-readable units by running `free -h`.

    Returns the structured _run_cmd result (ok/returncode/stdout/stderr/cmd).
    """
    free_argv = ["free", "-h"]
    return _run_cmd(free_argv)


@mcp.tool()
async def summarize_memory_usage(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Run `free -h` on this MCP server and have Ollama summarize it.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict carrying the raw free
      result when the command fails or prints nothing.
    """
    result = memory_usage()
    if not result.get("ok"):
        return {"ok": False, "error": "Failed to run free -h", "mem": result}

    output = result.get("stdout", "").strip()
    if not output:
        return {"ok": False, "error": "free -h output was empty", "mem": result}

    prompt_lines = [
        "You are a Linux sysadmin assistant.",
        "Summarize the following free -h output in exactly 3 bullets.",
        "Include total RAM, used RAM, available RAM, and swap usage if present.",
        "If swap usage is high, mention it.",
        "",
        "free -h output:",
        output,
    ]
    prompt = "\n".join(prompt_lines) + "\n"
    return await ollama_generate(prompt=prompt, model=model)

@mcp.tool()
def uptime() -> Dict[str, Any]:
    """
    Run `uptime` on this host.

    Returns the structured _run_cmd result (ok/returncode/stdout/stderr/cmd).
    """
    uptime_argv = ["uptime"]
    return _run_cmd(uptime_argv)

@mcp.tool()
async def summarize_uptime(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Runs uptime and asks Ollama to summarize uptime + load averages.

    The prompt includes this host's logical CPU count (os.cpu_count()), so the
    model judges load averages against the real core count instead of
    guessing at the machine size.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict carrying the raw uptime
      result when the command fails or prints nothing.
    """
    up_result = uptime()
    if not up_result.get("ok"):
        return {"ok": False, "error": "Failed to run uptime", "uptime": up_result}

    up_text = up_result.get("stdout", "").strip()
    if not up_text:
        return {"ok": False, "error": "uptime output was empty", "uptime": up_result}

    # Load averages only mean something relative to the number of CPUs.
    # os.cpu_count() may return None. It counts the machine's logical CPUs,
    # not any cgroup/affinity limit.
    cpu_count = os.cpu_count()
    cpu_desc = f"{cpu_count} logical CPU(s)" if cpu_count else "an unknown number of CPUs"

    prompt = (
        "You are a Linux sysadmin assistant.\n"
        "Explain this uptime output in 2-3 short bullets.\n"
        "Mention uptime duration, number of users, and load averages.\n"
        f"This host has {cpu_desc}; if load averages look high relative to that, mention it.\n\n"
        "uptime output:\n"
        f"{up_text}\n"
    )

    return await ollama_generate(prompt=prompt, model=model)

@mcp.tool()
def ip_info() -> Dict[str, Any]:
    """
    List network interfaces and their addresses by running `ip a`.

    Returns the structured _run_cmd result (ok/returncode/stdout/stderr/cmd).
    """
    ip_argv = ["ip", "a"]
    return _run_cmd(ip_argv)

@mcp.tool()
async def summarize_ip_info(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Run `ip a` and have Ollama summarize the key interfaces and addresses.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict carrying the raw ip
      result when the command fails or prints nothing.
    """
    result = ip_info()
    if not result.get("ok"):
        return {"ok": False, "error": "Failed to run ip a", "ip": result}

    output = result.get("stdout", "").strip()
    if not output:
        return {"ok": False, "error": "ip a output was empty", "ip": result}

    prompt_lines = [
        "You are a Linux sysadmin assistant.",
        "From the `ip a` output, summarize:",
        "- primary interface name",
        "- IPv4 address(es)",
        "- whether the interface looks UP",
        "Keep it to 3-5 bullets max.",
        "",
        "ip a output:",
        output,
    ]
    prompt = "\n".join(prompt_lines) + "\n"
    return await ollama_generate(prompt=prompt, model=model)

# ----------------------------
# Docker tools (optional)
# ----------------------------
@mcp.tool()
def docker_ps() -> Dict[str, Any]:
    """
    List running Docker containers via `docker ps`.

    Requires ENABLE_DOCKER_TOOLS to be on, plus a `docker` binary on PATH that
    the user running this server can access. Otherwise an ok=False dict
    explains what is missing.
    """
    if ENABLE_DOCKER_TOOLS:
        if _have_cmd("docker"):
            return _run_cmd(["docker", "ps"])
        return {"ok": False, "error": "docker command not found on this host"}
    return {"ok": False, "error": "Docker tools disabled (ENABLE_DOCKER_TOOLS=false)"}


@mcp.tool()
def docker_logs(container: str, tail: int = 200) -> Dict[str, Any]:
    """
    Fetch recent logs for a Docker container (docker logs --tail N).

    Args:
      container: Container name or ID. It must look like a Docker name/ID:
        an alphanumeric first character, then letters, digits, '_', '.' or '-'.
        If ALLOWED_DOCKER_CONTAINERS is set, it must also be in that allowlist.
      tail: Number of trailing log lines to return (1-5000).

    Returns:
      The structured _run_cmd result, or an ok=False dict when Docker tools
      are disabled/unavailable or an argument is invalid.

    Raises:
      ValueError: An allowlist is configured and *container* is not in it.
    """
    import re

    if not ENABLE_DOCKER_TOOLS:
        return {"ok": False, "error": "Docker tools disabled (ENABLE_DOCKER_TOOLS=false)"}
    if not _have_cmd("docker"):
        return {"ok": False, "error": "docker command not found on this host"}

    # Optional allowlist enforcement
    if ALLOWED_DOCKER_CONTAINERS:
        _require_allowed(container, ALLOWED_DOCKER_CONTAINERS, "container")

    # `container` comes from the tool caller and is placed in argv verbatim.
    # Without this check, a value such as "--since=..." or "-f" would be
    # parsed by the docker CLI as an option instead of a container name.
    if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_.-]*", container):
        return {"ok": False, "error": f"invalid container name: {container!r}"}

    # Basic sanity
    if tail < 1 or tail > 5000:
        return {"ok": False, "error": "tail must be between 1 and 5000"}

    return _run_cmd(["docker", "logs", "--tail", str(tail), container], timeout=CMD_TIMEOUT_SECONDS)

@mcp.tool()
async def summarize_docker_ps(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Run `docker ps` and have Ollama summarize which containers are running.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict carrying the raw
      docker_ps result when it fails or prints nothing.
    """
    result = docker_ps()
    if not result.get("ok"):
        return {"ok": False, "error": "Failed to run docker ps", "docker_ps": result}

    listing = result.get("stdout", "").strip()
    if not listing:
        return {"ok": False, "error": "docker ps output was empty", "docker_ps": result}

    prompt_lines = [
        "You are a Linux sysadmin assistant.",
        "Summarize the following `docker ps` output.",
        "List running container names, images, and mapped ports.",
        "Keep it short: 5 bullets max.",
        "",
        "docker ps output:",
        listing,
    ]
    prompt = "\n".join(prompt_lines) + "\n"
    return await ollama_generate(prompt=prompt, model=model)

# ----------------------------
# systemd tools (status + optional restart)
# ----------------------------
@mcp.tool()
def systemctl_status(service: str) -> Dict[str, Any]:
    """
    Show systemd status for an allowlisted service (`systemctl status SERVICE --no-pager`).

    Raises:
      ValueError: *service* is not in ALLOWED_SERVICES.
    """
    _require_allowed(service, ALLOWED_SERVICES, "service")
    if _have_cmd("systemctl"):
        status_argv = ["systemctl", "status", service, "--no-pager"]
        return _run_cmd(status_argv, timeout=CMD_TIMEOUT_SECONDS)
    return {"ok": False, "error": "systemctl not found on this host"}


@mcp.tool()
def systemctl_restart(service: str) -> Dict[str, Any]:
    """
    Restart an allowlisted systemd service (`systemctl restart SERVICE`).

    Off unless ENABLE_SERVICE_RESTART is enabled. After a successful restart,
    `systemctl is-active SERVICE` is run and returned under "is_active". The
    top-level "ok" reflects only the restart command, so check "is_active"
    for the post-restart state.

    Raises:
      ValueError: *service* is not in ALLOWED_SERVICES.
    """
    if not ENABLE_SERVICE_RESTART:
        disabled_msg = "Service restart disabled. Set ENABLE_SERVICE_RESTART=true to allow this tool."
        return {"ok": False, "error": disabled_msg}

    _require_allowed(service, ALLOWED_SERVICES, "service")
    if not _have_cmd("systemctl"):
        return {"ok": False, "error": "systemctl not found on this host"}

    restart_result = _run_cmd(["systemctl", "restart", service], timeout=CMD_TIMEOUT_SECONDS)
    if not restart_result.get("ok"):
        # Surface the failed restart as-is; skip the follow-up status probe.
        return restart_result

    active_check = _run_cmd(["systemctl", "is-active", service], timeout=CMD_TIMEOUT_SECONDS)
    return {"ok": True, "restart": restart_result, "is_active": active_check}


# ----------------------------
# full summarize
# ----------------------------
@mcp.tool()
async def system_health_report(model: Optional[str] = None) -> Dict[str, Any]:
    """
    Generates a short health report by combining disk, memory, and uptime outputs.

    The logical CPU count (os.cpu_count()) is included in the prompt so that
    load averages are judged against the host's actual core count.

    Args:
      model: Optional Ollama model name, forwarded to ollama_generate.

    Returns:
      The ollama_generate result, or an ok=False dict holding all three raw
      command results if any of them failed.
    """
    df_result = disk_usage()
    mem_result = memory_usage()
    up_result = uptime()

    if not df_result.get("ok") or not mem_result.get("ok") or not up_result.get("ok"):
        return {
            "ok": False,
            "error": "One or more commands failed",
            "disk": df_result,
            "memory": mem_result,
            "uptime": up_result,
        }

    df_text = (df_result.get("stdout") or "").strip()
    mem_text = (mem_result.get("stdout") or "").strip()
    up_text = (up_result.get("stdout") or "").strip()

    # os.cpu_count() may return None. It counts the machine's logical CPUs,
    # not any cgroup/affinity limit.
    cpu_count = os.cpu_count()
    cpu_desc = f"{cpu_count} logical CPU(s)" if cpu_count else "an unknown number of CPUs"

    prompt = (
        "You are a Linux sysadmin assistant.\n"
        "Create a short system health report in 6 bullets max.\n"
        "Include:\n"
        "- any disk filesystem >= 80%\n"
        "- memory summary (total/used/available) and swap notes\n"
        f"- uptime and load averages (this host has {cpu_desc}; judge load relative to that)\n"
        "End with a final line: 'Overall: OK' or 'Overall: Needs Attention' with a brief reason.\n\n"
        "df -h output:\n"
        f"{df_text}\n\n"
        "free -h output:\n"
        f"{mem_text}\n\n"
        "uptime output:\n"
        f"{up_text}\n"
    )

    return await ollama_generate(prompt=prompt, model=model)


# ----------------------------
# Ollama tool (local LLM call via HTTP)
# ----------------------------
@mcp.tool()
async def ollama_generate(prompt: str, model: Optional[str] = None) -> Dict[str, Any]:
    """
    Generate a completion using your Ollama server.

    Args:
      prompt: The prompt to send to the model. Must contain non-whitespace text.
      model: Optional model name. Defaults to OLLAMA_DEFAULT_MODEL when omitted,
        empty, or whitespace-only.

    Returns:
      On success: ok=True with model, ollama_url, response (the generated text)
      and raw (the decoded JSON object).
      On failure: ok=False with error and ollama_url. Non-2xx HTTP replies also
      include "detail", the (truncated) response body.

    Notes:
      Uses POST /api/generate with stream=false so we get one JSON response.
      Ollama API docs: /api/generate supports stream:false. (See Ollama docs)
    """
    if not prompt or not prompt.strip():
        return {"ok": False, "error": "prompt is required"}

    # Treat a blank or whitespace-only model like "not provided", so an empty
    # model name is never sent to Ollama.
    use_model = (model or "").strip() or OLLAMA_DEFAULT_MODEL.strip()

    url = f"{OLLAMA_BASE_URL.rstrip('/')}/api/generate"
    payload = {
        "model": use_model,
        "prompt": prompt,
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=HTTP_TIMEOUT_SECONDS) as client:
            r = await client.post(url, json=payload)
            r.raise_for_status()
            data = r.json()
    except httpx.HTTPStatusError as e:
        # raise_for_status() only reports the status line and URL. The response
        # body is where the server explains the failure (e.g. an unknown model).
        detail = e.response.text.strip()[:1000]
        return {
            "ok": False,
            "error": f"Ollama HTTP error: {e}",
            "detail": detail,
            "ollama_url": url,
        }
    except httpx.HTTPError as e:
        return {"ok": False, "error": f"Ollama HTTP error: {e}", "ollama_url": url}
    except Exception as e:
        # e.g. a body that is not valid JSON.
        return {"ok": False, "error": f"Unexpected error: {e}", "ollama_url": url}

    if not isinstance(data, dict):
        return {
            "ok": False,
            "error": "Unexpected error: Ollama response was not a JSON object",
            "ollama_url": url,
        }

    # Ollama returns fields like response, done, context, etc.
    return {
        "ok": True,
        "model": use_model,
        "ollama_url": url,
        "response": data.get("response", ""),
        "raw": data,
    }


# ----------------------------
# Entrypoint
# ----------------------------
if __name__ == "__main__":
    import uvicorn

    # FastMCP exposes its Streamable HTTP transport as an ASGI application.
    # Hand it straight to uvicorn, listening on the configured host and port.
    uvicorn.run(
        mcp.streamable_http_app(),
        host=MCP_HOST,
        port=MCP_PORT,
        log_level="info",
    )
