#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
OTA 认证网关 (最终整合版)
- 无状态 Basic Auth / 自定义 Header 认证
- 动态 Manifest 下发 (双行 stable 文件支持灰度)
- 极速灰度范围判定 (SQL LIMIT 1 + 联合索引 + 结果缓存)
"""

import base64
import hashlib
import hmac
import logging
import os
import re
from contextlib import asynccontextmanager
from typing import Optional

import aiomysql
from cachetools import TTLCache
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
from pathlib import Path

# ═══════════════════════════════════════════════════════
#  Configuration
# ═══════════════════════════════════════════════════════
# Every setting is read from the environment once, at import time; the second
# argument of each os.environ.get() call is the fallback used when the
# variable is unset.
# NOTE(review): the MYSQL_HOST and MYSQL_PASSWORD fallbacks look like real
# connection details committed to source — consider requiring these variables
# to be set explicitly and rotating the password.
MYSQL_HOST = os.environ.get("MYSQL_HOST", "rm-uf6i08orz6jc9xw71o.mysql.rds.aliyuncs.com")
MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))
MYSQL_DB = os.environ.get("MYSQL_DB", "emqx")
MYSQL_USER = os.environ.get("MYSQL_USER", "ota_auth")
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "Lwgetin1")

# Directory under which per-product firmware folders are looked up.
FIRMWARE_ROOT = os.environ.get("FIRMWARE_ROOT", "/var/www/ota/firmware")

CREDENTIAL_TTL = int(os.environ.get("CREDENTIAL_TTL", "300"))  # credential cache lifetime (default: 5 minutes)
MAX_CREDENTIALS = int(os.environ.get("MAX_CREDENTIALS", "50000"))  # maxsize of the credential cache

OTA_GRAY_RESULT_TTL = int(os.environ.get("OTA_GRAY_RESULT_TTL", "180"))  # gray-result cache lifetime (default: 3 minutes)
MAX_OTA_GRAY_RESULTS = int(os.environ.get("MAX_OTA_GRAY_RESULTS", "500000"))  # maxsize of the gray-result cache

# ═══════════════════════════════════════════════════════
#  Logging and caches
# ═══════════════════════════════════════════════════════
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)-7s %(name)s  %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger("ota-auth")

# device_id -> the device secret that last passed verification against the database
credential_cache: TTLCache[str, str] = TTLCache(maxsize=MAX_CREDENTIALS, ttl=CREDENTIAL_TTL)
# "<device_id>:<product_id>" -> whether the device matched a gray-release rule
ota_gray_result_cache: TTLCache[str, bool] = TTLCache(maxsize=MAX_OTA_GRAY_RESULTS, ttl=OTA_GRAY_RESULT_TTL)

# Shared aiomysql connection pool; stays None until lifespan() creates it.
db_pool: Optional[aiomysql.Pool] = None


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Own the MySQL connection pool for the lifetime of the application.

    The pool is created before the ``yield`` and published through the
    module-level ``db_pool``; the code after the ``yield`` closes it again.

    Args:
        _app: The FastAPI application (unused).
    """
    global db_pool
    db_pool = await aiomysql.create_pool(
        host=MYSQL_HOST, port=MYSQL_PORT, db=MYSQL_DB,
        user=MYSQL_USER, password=MYSQL_PASSWORD,
        minsize=2, maxsize=50, charset="utf8mb4", autocommit=True
    )
    logger.info("MySQL 连接池已建立  host=%s  db=%s", MYSQL_HOST, MYSQL_DB)
    try:
        yield
    finally:
        # Runs even when an exception is thrown into the generator at the
        # yield point, so the pool's connections are always released.
        if db_pool is not None:
            db_pool.close()
            await db_pool.wait_closed()


# ASGI application object; lifespan() is registered to manage the MySQL pool.
app = FastAPI(title="OTA Auth Gateway", lifespan=lifespan)


# ═══════════════════════════════════════════════════════
#  核心工具函数
# ═══════════════════════════════════════════════════════
def sha256_suffix(plaintext: str, salt: str) -> str:
    """Return the hex SHA-256 digest of ``plaintext`` with ``salt`` appended.

    The two strings are concatenated (salt last) and UTF-8 encoded before
    hashing.
    """
    salted_bytes = (plaintext + salt).encode("utf-8")
    hasher = hashlib.sha256()
    hasher.update(salted_bytes)
    return hasher.hexdigest()


def extract_credentials(request: Request):
    """Extract the device ID and secret from the request headers.

    An ``Authorization: Basic <base64(id:secret)>`` header is tried first.
    If it is absent, malformed, or yields an empty device ID, both values are
    taken from the ``X-Device-ID`` / ``X-Device-Secret`` headers instead.
    Surrounding whitespace is stripped from both values.

    Args:
        request: The incoming request; only ``request.headers`` is read.

    Returns:
        A ``(device_id, device_secret)`` tuple of strings. Either element is
        ``""`` when no usable value was supplied.
    """
    device_id = device_secret = None
    auth_header = request.headers.get("Authorization", "")
    # HTTP authentication scheme names are case-insensitive, so "basic" and
    # "BASIC" must be accepted as well as "Basic".
    if auth_header[:6].lower() == "basic ":
        try:
            decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
            device_id, device_secret = decoded.split(":", 1)
        except ValueError:
            # Covers invalid base64 (binascii.Error), non-UTF-8 payloads
            # (UnicodeDecodeError) and a payload without ":" (unpacking
            # error). A malformed Basic header is deliberately not fatal:
            # fall through to the custom headers below.
            device_id = device_secret = None
        else:
            device_id, device_secret = device_id.strip(), device_secret.strip()
    if not device_id:
        device_id = request.headers.get("X-Device-ID", "").strip()
        device_secret = request.headers.get("X-Device-Secret", "").strip()
    return device_id, device_secret


async def authenticate_device(device_id: str, device_secret: str) -> None:
    """Verify a device's credentials; return normally only when they are valid.

    The credential cache is consulted first. On a cache miss the salted
    SHA-256 hash of the secret is compared with the ``password_hash`` stored
    in ``mqtt_user`` for this ``clientid``; after a successful check the
    secret is cached.

    Args:
        device_id: Device identifier, matched against ``mqtt_user.clientid``.
        device_secret: Secret presented by the device.

    Raises:
        HTTPException: 401 when credentials are missing, the device has no
            row in ``mqtt_user``, or the secret does not match.
    """
    if not device_id or not device_secret:
        raise HTTPException(401, "Missing credentials")

    # hmac.compare_digest() raises TypeError for str arguments that contain
    # non-ASCII characters. The secret is client-controlled, so compare UTF-8
    # bytes to turn such input into a plain 401 instead of a server error.
    secret_bytes = device_secret.encode("utf-8")

    cached_secret = credential_cache.get(device_id)
    if cached_secret is not None:
        if hmac.compare_digest(secret_bytes, cached_secret.encode("utf-8")):
            return
        # NOTE(review): a cached entry short-circuits the database check, so a
        # secret changed in the database is rejected until the entry expires
        # or is removed from the cache.
        raise HTTPException(401, "Invalid credentials")

    async with db_pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            await cur.execute("SELECT password_hash, salt FROM mqtt_user WHERE clientid = %s", (device_id,))
            row = await cur.fetchone()

    if not row:
        raise HTTPException(401, "Device not found")

    # NULL columns are treated as empty strings.
    # Assumes both columns come back as str — TODO confirm column types.
    stored_hash, salt = row["password_hash"] or "", row["salt"] or ""
    computed_hash = sha256_suffix(device_secret, salt)
    if not hmac.compare_digest(computed_hash.encode("utf-8"), stored_hash.encode("utf-8")):
        raise HTTPException(401, "Invalid credentials")

    credential_cache[device_id] = device_secret


async def check_device_ota_gray(device_id: str, product_id: str) -> bool:
    """Report whether a device falls inside a gray-release range for a product.

    The result cache is consulted first. On a miss, ``ota_gray_rule`` is
    queried for any row of this product whose
    ``[start_device_id, end_device_id]`` range contains ``device_id``, and
    the answer is cached.

    Args:
        device_id: Device identifier to test against the rule ranges.
        product_id: Product whose rules are consulted.

    Returns:
        True when a matching rule exists. False when none exists or when the
        query failed; a failed query is logged and is not cached.
    """
    cache_key = f"{device_id}:{product_id}"

    # One lookup instead of "in" followed by indexing: the entry could expire
    # between the two operations and make the indexing raise KeyError.
    # Cached values are always bool, so None unambiguously means "missing".
    cached = ota_gray_result_cache.get(cache_key)
    if cached is not None:
        return cached

    try:
        async with db_pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "SELECT 1 FROM ota_gray_rule "
                    "WHERE product_id = %s AND start_device_id <= %s AND end_device_id >= %s "
                    "LIMIT 1",
                    (product_id, device_id, device_id)
                )
                is_ota_gray = await cur.fetchone() is not None
    except Exception as exc:
        # Best effort: fall back to "not gray" for this request, but do not
        # cache the answer, so the next request retries the query instead of
        # keeping the device out of the gray channel for the whole TTL.
        logger.exception("MySQL 灰度查询异常: %s", exc)
        return False

    ota_gray_result_cache[cache_key] = is_ota_gray
    return is_ota_gray


# ═══════════════════════════════════════════════════════
#  路由 1：静态文件下载鉴权 (Caddy forward_auth)
# ═══════════════════════════════════════════════════════
@app.api_route("/auth", methods=["GET", "HEAD", "POST"])
async def verify_for_file_download(request: Request):
    """Check the device credentials carried by a file-download request.

    Responds with an empty 200 when the credentials are valid; otherwise the
    HTTPException raised by authenticate_device() propagates.
    """
    credentials = extract_credentials(request)
    await authenticate_device(*credentials)
    return Response(status_code=200)


# ═══════════════════════════════════════════════════════
#  路由 2：动态 Manifest 下发 (灰度核心)
# ═══════════════════════════════════════════════════════
# ★ 修复：FastAPI 不允许路径中出现重复的参数名，因此使用 manifest_name 接收文件名
@app.api_route("/{product_id}/{manifest_name}.manifest", methods=["GET", "HEAD"])
async def get_dynamic_manifest(request: Request, product_id: str, manifest_name: str):
    """Serve the manifest of the firmware version the calling device should run.

    ``<FIRMWARE_ROOT>/<product_id>/stable`` holds the baseline version on its
    first line and, optionally, the gray-release version on its second line.
    A device that matches a gray rule receives the gray version when one is
    published; every other device receives the baseline version.

    Args:
        request: Incoming request; supplies the credentials and the optional
            ``X-Current-Version`` header.
        product_id: Product directory name taken from the URL.
        manifest_name: Manifest file stem from the URL; must equal
            ``product_id``.

    Returns:
        204 with no body when the device already runs the target version,
        otherwise the manifest file content as ``application/octet-stream``.

    Raises:
        HTTPException: 401 on failed authentication, 400 for an invalid
            product ID, mismatching manifest name or malformed target
            version, 404 when the stable file or the manifest is missing,
            500 when the stable file has no baseline version.
    """
    # 1. Authenticate
    device_id, device_secret = extract_credentials(request)
    await authenticate_device(device_id, device_secret)

    # Path-traversal guard. fullmatch() is used because "$" in match() also
    # matches just before a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', product_id):
        raise HTTPException(400, "Invalid product ID")

    # The requested manifest name must be the product directory name
    if product_id != manifest_name:
        raise HTTPException(400, "Manifest name must match product ID")

    current_version = request.headers.get("X-Current-Version", "").strip()

    # 2. Parse the stable file (line 1: baseline version, line 2: gray version).
    # Read directly instead of exists()-then-read so that a file removed in
    # between still yields 404 rather than an unhandled error.
    product_dir = Path(FIRMWARE_ROOT) / product_id
    try:
        stable_text = (product_dir / "stable").read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(404, "stable file not found") from None

    lines = stable_text.strip().splitlines()
    baseline_ver = lines[0].strip() if lines else None
    ota_gray_ver = lines[1].strip() if len(lines) > 1 else None

    if not baseline_ver:
        raise HTTPException(500, "stable file format error: missing baseline version")

    # 3. Gray status. The rule lookup is skipped when no gray version is
    # published, since its result could not change the outcome.
    is_ota_gray_device = bool(ota_gray_ver) and await check_device_ota_gray(device_id, product_id)

    # 4. Target version (baseline_ver is known to be non-empty here)
    target_version = ota_gray_ver if is_ota_gray_device else baseline_ver

    # 5. Nothing to do when the device is already on the target version
    if current_version == target_version:
        return Response(status_code=204)

    # 6. Load and return the manifest. The version becomes a path component,
    # so "." and ".." are rejected in addition to the character whitelist.
    if target_version in (".", "..") or not re.fullmatch(r'[a-zA-Z0-9._-]+', target_version):
        raise HTTPException(400, "Invalid target version format")

    manifest_path = product_dir / target_version / f"{product_id}.manifest"
    try:
        manifest_bytes = manifest_path.read_bytes()
    except (FileNotFoundError, NotADirectoryError):
        logger.error("Manifest 文件缺失: %s", manifest_path)
        raise HTTPException(404, "Manifest file not found on server") from None

    # The channel reflects the version actually served: "ota_gray" only when
    # the gray version was selected.
    channel = "ota_gray" if is_ota_gray_device else "stable"
    logger.info("下发 Manifest: 设备=%s, 产品=%s, 通道=%s, 目标版本=%s",
                device_id, product_id, channel, target_version)

    return Response(content=manifest_bytes, media_type="application/octet-stream")


# ═══════════════════════════════════════════════════════
#  管理与监控接口
# ═══════════════════════════════════════════════════════
@app.delete("/admin/cache/credentials/{device_id}")
async def revoke_credential_cache(device_id: str):
    """Remove one device's entry from the credential cache."""
    # NOTE(review): this handler performs no authentication itself — confirm
    # that access to it is restricted elsewhere.
    # A device that is not cached is not an error.
    credential_cache.pop(device_id, None)
    payload = {"status": "revoked", "device_id": device_id}
    return JSONResponse(payload)


@app.post("/admin/cache/ota_gray/clear")
async def clear_ota_gray_result_cache():
    """Empty the gray-result cache (call after editing ota_gray_rule)."""
    # NOTE(review): this handler performs no authentication itself — confirm
    # that access to it is restricted elsewhere.
    ota_gray_result_cache.clear()
    logger.info("灰度结果缓存已被手动清空")
    payload = {"status": "ok", "message": "灰度结果缓存已清空"}
    return JSONResponse(payload)


@app.get("/health")
async def health():
    """Report liveness together with cache and connection-pool counters."""
    credential_count = len(credential_cache)
    gray_result_count = len(ota_gray_result_cache)

    # Pool counters read as zero while no pool is available.
    if db_pool:
        pool_size = db_pool.size
        pool_free = db_pool.freesize
    else:
        pool_size = pool_free = 0

    report = {
        "status": "ok",
        "cached_credentials": credential_count,
        "cached_ota_gray_results": gray_result_count,
        "db_pool_size": pool_size,
        "db_pool_freesize": pool_free,
    }
    return JSONResponse(report)