from __future__ import annotations
import io
import base64
import zipfile
from typing import List, Tuple
import cv2
from flask import request
import numpy as np
import svgwrite
from dataclasses import dataclass

@dataclass
class GCodeConfig:
    """Settings consumed by ``polylines_to_gcode``.

    Covers pixel-to-mm scaling and placement, feed rates, and how the pen is
    raised/lowered: either a Z-axis move (``use_z_lift=True``) or a servo
    driven with M03/M05 commands (``use_z_lift=False``).
    """

    # ---- Scale / placement ----
    # Millimetres per source pixel; the offsets (mm) are added after scaling.
    # (keep your A4 fit values here)
    mm_per_px: float = 0.1
    offset_x_mm: float = 0.0
    offset_y_mm: float = 0.0

    # Feed rates emitted as F words: drawing moves and travel/Z moves.
    feed_draw: int = 1800
    feed_travel: int = 3000

    # ---- Pen lift mode ----
    # False => servo pen (M03/M05, see below); True => lift the pen on Z.
    use_z_lift: bool = False
    z_up: float = 5.0     # Z height (mm, G21) for pen up when use_z_lift is True
    z_down: float = 0.0   # Z height (mm, G21) for pen down when use_z_lift is True

    # ---- Servo behavior ----
    # Down = M03 S{servo_down_s}
    servo_down_s: int = 100          # -> M03 S100 for pen down
    # Up = either M05 (if True) or M03 S{servo_up_s} (if False)
    servo_up_uses_M5: bool = True    # True => M05; False => M03 S{servo_up_s}
    servo_up_s: int = 0              # -> M03 S0 for pen up (when using M3S-up)
    # Emitted as "G4 P<value>" after every servo toggle. Grbl reads P as
    # seconds; some firmwares (e.g. Marlin) read milliseconds -- verify.
    servo_dwell_s: float = 0.70

    # True => zero-padded M-codes (M03/M05); False => M3/M5.
    # Some firmwares require the leading zero.
    pad_zero_in_m_codes: bool = True

def polylines_to_gcode(polylines, width_px: int, height_px: int, cfg: GCodeConfig) -> str:
    """Render pixel-space polylines as a pen-plotter G-code program.

    For every non-empty polyline the program emits:
      * a rapid move (G0) to its first point,
      * pen down,
      * an ``F`` word for the drawing feed and G1 moves through the rest,
      * pen up.

    The pen starts raised. Pixel coordinates are multiplied by
    ``cfg.mm_per_px``, Y is flipped against ``height_px``, and
    ``cfg.offset_x_mm`` / ``cfg.offset_y_mm`` are added. ``width_px`` is
    accepted for symmetry but not used.

    Returns the program as a newline-joined string (no trailing newline).
    """
    program = ["; Pen plot from image2lines", "G90", "G21"]
    emit = program.append
    dwell_cmd = f"G4 P{float(cfg.servo_dwell_s):.3f}"

    def m_code(number):
        template = "M{:02d}" if cfg.pad_zero_in_m_codes else "M{}"
        return template.format(number)

    def to_machine(xp, yp):
        # Image rows grow downward, machine Y grows upward: flip, scale, offset.
        return (
            xp * cfg.mm_per_px + cfg.offset_x_mm,
            (height_px - yp) * cfg.mm_per_px + cfg.offset_y_mm,
        )

    def raise_pen():
        if cfg.use_z_lift:
            emit(f"G0 Z{cfg.z_up:.3f} F{cfg.feed_travel}")
            return
        if cfg.servo_up_uses_M5:
            emit(m_code(5))
        else:
            emit(f"{m_code(3)} S{int(cfg.servo_up_s)}")
        emit(dwell_cmd)

    def lower_pen():
        if cfg.use_z_lift:
            emit(f"G1 Z{cfg.z_down:.3f} F{cfg.feed_travel}")
            return
        emit(f"{m_code(3)} S{int(cfg.servo_down_s)}")
        emit(dwell_cmd)

    raise_pen()  # begin with the pen lifted
    for path in polylines:
        if len(path) == 0:
            continue
        start_x, start_y = to_machine(float(path[0][0]), float(path[0][1]))
        emit(f"G0 X{start_x:.3f} Y{start_y:.3f} F{cfg.feed_travel}")
        lower_pen()
        emit(f"F{cfg.feed_draw}")
        program.extend(
            "G1 X{:.3f} Y{:.3f}".format(*to_machine(float(px), float(py)))
            for px, py in path[1:]
        )
        raise_pen()

    emit("; done")
    return "\n".join(program)


def polylines_to_gcode_basic(polylines, width_px: int, height_px: int, mm_per_px=0.1) -> str:
    """Minimal Z-lift G-code writer (used as a fallback).

    Uses fixed travel/plunge feed F3000, drawing feed F1800, pen-up Z5 and
    pen-down Z0. Pixels are scaled by ``mm_per_px`` and Y is flipped against
    ``height_px``; no offsets are applied. ``width_px`` is unused.
    """
    program = ["; basic gcode", "G90", "G21", "G0 Z5.000"]
    for path in polylines:
        if len(path) == 0:
            continue
        first = path[0]
        x0 = float(first[0]) * mm_per_px
        y0 = (height_px - float(first[1])) * mm_per_px
        program += [f"G0 X{x0:.3f} Y{y0:.3f} F3000", "G1 Z0.000 F3000", "F1800"]
        for px, py in path[1:]:
            xm = float(px) * mm_per_px
            ym = (height_px - float(py)) * mm_per_px
            program.append(f"G1 X{xm:.3f} Y{ym:.3f}")
        program.append("G0 Z5.000 F3000")
    program.append("; done")
    return "\n".join(program)


# -------------------------------
# Helpers (unchanged logic)
# -------------------------------
def read_image_to_array(file_bytes: bytes) -> np.ndarray:
    """Decode encoded image bytes into a 3-channel BGR array.

    Any container format supported by ``cv2.imdecode`` is accepted; the
    image is always loaded as color (``cv2.IMREAD_COLOR``).

    Raises:
        ValueError: if OpenCV cannot decode the data.
    """
    buffer = np.frombuffer(file_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is not None:
        return decoded
    raise ValueError("Unable to read image.")

def resize_keep_aspect(img: np.ndarray, target_w: int) -> np.ndarray:
    """Downscale ``img`` to ``target_w`` pixels wide, preserving aspect ratio.

    Images no wider than ``target_w`` are returned as a copy (never upscaled).
    Degenerate images with a zero dimension are returned unchanged.
    Downscaling uses area interpolation.
    """
    h, w = img.shape[:2]
    if w == 0 or h == 0:
        return img
    if w <= target_w:
        return img.copy()
    scale = target_w / float(w)
    # A very wide, short image can round to a height of 0, which cv2.resize
    # rejects as an empty dsize; always keep at least one row.
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (target_w, new_h), interpolation=cv2.INTER_AREA)

def to_gray(img_bgr: np.ndarray) -> np.ndarray:
    """Return the single-channel grayscale version of a BGR image."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    return gray

def canny_edges(g: np.ndarray, low: int, high: int) -> np.ndarray:
    """Canny edge mask of grayscale ``g`` with hysteresis thresholds ``low``/``high``."""
    thresholds = {"threshold1": low, "threshold2": high}
    return cv2.Canny(g, **thresholds)

def adaptive_thresh(g: np.ndarray, block: int, C: int) -> np.ndarray:
    """Inverted Gaussian adaptive threshold of grayscale ``g``.

    Pixels darker than their local Gaussian-weighted mean minus ``C`` become
    255, all others 0. ``block`` is forced odd (lowest bit set) and to at
    least 3.
    """
    block_size = int(block) | 1
    if block_size < 3:
        block_size = 3
    return cv2.adaptiveThreshold(
        g,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block_size,
        C,
    )

def sobel_mag(g: np.ndarray) -> np.ndarray:
    """Binary edge mask from the 3x3 Sobel gradient magnitude.

    The magnitude is rescaled so the strongest gradient maps to ~255
    (relative to this image), then values above 50 become 255, the rest 0.
    """
    gx, gy = (cv2.Sobel(g, cv2.CV_32F, dx, dy, ksize=3) for dx, dy in ((1, 0), (0, 1)))
    magnitude = cv2.magnitude(gx, gy)
    peak = magnitude.max() + 1e-6  # epsilon avoids 0/0 on flat images
    scaled = np.clip(magnitude / peak * 255, 0, 255).astype(np.uint8)
    return np.where(scaled > 50, 255, 0).astype(np.uint8)

def scharr_mag(g: np.ndarray) -> np.ndarray:
    """Binary edge mask from the Scharr gradient magnitude.

    Same normalization as the Sobel variant: the magnitude is rescaled to
    ~0..255 relative to its maximum and thresholded above 50.
    """
    gx = cv2.Scharr(g, cv2.CV_32F, 1, 0)
    gy = cv2.Scharr(g, cv2.CV_32F, 0, 1)
    magnitude = cv2.magnitude(gx, gy)
    scaled = (magnitude / (magnitude.max() + 1e-6) * 255).clip(0, 255).astype(np.uint8)
    strong = scaled > 50
    return strong.astype(np.uint8) * 255

def laplacian_edges(g: np.ndarray) -> np.ndarray:
    """Binary edge mask: absolute 3x3 Laplacian response thresholded above 40."""
    response = cv2.convertScaleAbs(cv2.Laplacian(g, cv2.CV_16S, ksize=3))
    _, mask = cv2.threshold(response, 40, 255, cv2.THRESH_BINARY)
    return mask

def dog_edges(g: np.ndarray, sigma: float, k: float, scale: float) -> np.ndarray:
    """Difference-of-Gaussians edge mask.

    Steps:
      1. Blur with ``sigma`` and ``k * sigma`` (each floored at 0.1).
      2. Subtract the two blurs and multiply by ``scale``.
      3. Min-max normalize to 0..255 and threshold above 80.

    NOTE(review): because of the min-max normalization, a positive ``scale``
    has (up to rounding and the 1e-6 epsilon) no effect on the output. Only
    its sign matters: negative inverts, and 0 yields an empty mask.
    """
    sigma_a = max(0.1, float(sigma))
    sigma_b = max(0.1, float(k) * sigma_a)
    fine = cv2.GaussianBlur(g, (0, 0), sigma_a).astype(np.float32)
    coarse = cv2.GaussianBlur(g, (0, 0), sigma_b).astype(np.float32)
    diff = (fine - coarse) * float(scale)
    lo, hi = diff.min(), diff.max()
    normalized = np.clip((diff - lo) / (hi - lo + 1e-6) * 255, 0, 255).astype(np.uint8)
    _, mask = cv2.threshold(normalized, 80, 255, cv2.THRESH_BINARY)
    return mask

def xdog_edges(g: np.ndarray, sigma: float, k: float, phi: float, eps: float, p: float) -> np.ndarray:
    """Extended difference-of-Gaussians (XDoG) edge mask.

    Steps:
      1. On 0..1 intensities compute G(sigma) - p * G(k * sigma), with each
         sigma floored at 0.1, and min-max normalize the result.
      2. Apply ``tanh(phi * (x - eps))`` and renormalize to 0..1.
      3. Invert to 0..255 and threshold above 128, so strong responses
         become 0 and the rest 255.
    """
    s_small = max(0.1, float(sigma))
    s_large = max(0.1, float(k) * s_small)
    blur_small = cv2.GaussianBlur(g, (0, 0), s_small).astype(np.float32) / 255.0
    blur_large = cv2.GaussianBlur(g, (0, 0), s_large).astype(np.float32) / 255.0
    response = blur_small - float(p) * blur_large
    response = (response - response.min()) / (response.max() - response.min() + 1e-6)
    soft = np.tanh(float(phi) * (response - float(eps)))
    soft = (soft - soft.min()) / (soft.max() - soft.min() + 1e-6)
    inverted = np.clip((1.0 - soft) * 255, 0, 255).astype(np.uint8)
    _, mask = cv2.threshold(inverted, 128, 255, cv2.THRESH_BINARY)
    return mask

def edge_detect(
    img_gray: np.ndarray,
    method: str,
    canny_low: int, canny_high: int,
    th_block: int, th_C: int,
    blur_sigma: float,
    dog_sigma: float, dog_k: float, dog_scale: float,
    xdog_sigma: float, xdog_k: float, xdog_phi: float, xdog_eps: float, xdog_p: float
) -> np.ndarray:
    """Produce a binary edge mask from a grayscale image.

    Steps:
      1. If ``blur_sigma`` > 0, pre-blur with a Gaussian of that sigma.
      2. Dispatch on ``method``: "threshold", "sobel", "scharr", "laplacian",
         "dog" or "xdog". Any other value uses Canny.
      3. Apply one morphological close with a 3x3 elliptical kernel, which
         bridges small gaps in the mask.
    """
    source = img_gray
    if blur_sigma and blur_sigma > 0:
        source = cv2.GaussianBlur(img_gray, (0, 0), float(blur_sigma))

    detectors = {
        "threshold": lambda im: adaptive_thresh(im, th_block, th_C),
        "sobel": sobel_mag,
        "scharr": scharr_mag,
        "laplacian": laplacian_edges,
        "dog": lambda im: dog_edges(im, dog_sigma, dog_k, dog_scale),
        "xdog": lambda im: xdog_edges(im, xdog_sigma, xdog_k, xdog_phi, xdog_eps, xdog_p),
    }

    def run_canny(im):
        return canny_edges(im, canny_low, canny_high)

    mask = detectors.get(method, run_canny)(source)

    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)

def contours_to_polylines(edges: np.ndarray, min_len: float, epsilon: float) -> List[np.ndarray]:
    """Trace contours in a binary mask and return them as float32 (N, 2) point arrays.

    A contour is dropped if it has fewer than 2 points, or if its open arc
    length (measured on the raw contour) is not >= ``min_len``. When
    ``epsilon`` > 0, each kept contour is simplified with ``approxPolyDP``
    as an open curve. A simplification that would leave fewer than 2 points
    is discarded in favor of the raw points.
    """
    found = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
    contours = found[0] if len(found) == 2 else found[1]
    simplify = bool(epsilon and epsilon > 0)

    result = []
    for contour in contours:
        if len(contour) < 2:
            continue
        # Length test first: cheap rejection before any simplification work.
        if not cv2.arcLength(contour, False) >= min_len:
            continue
        points = contour[:, 0, :].astype(np.float32)
        if simplify:
            reduced = cv2.approxPolyDP(points, epsilon, False)
            if len(reduced) >= 2:
                points = reduced.reshape(-1, 2)
        result.append(points)
    return result

def order_paths_nearest_neighbor(paths: List[np.ndarray]) -> List[np.ndarray]:
    """Greedily reorder paths to reduce pen-up travel between them.

    Algorithm:
      * Start with the longest path (by summed segment length).
      * Repeatedly append the unused path whose start or end point is
        nearest to the current end point. Reverse it when its end is the
        nearer one.
      * On ties, prefer the lower index and, for the same path, the
        un-reversed orientation.

    Distances to all candidates are computed with one vectorized NumPy
    expression per step. The previous version made two Python-level
    ``np.linalg.norm`` calls per candidate per step.

    Returns ``paths`` itself when it is empty, otherwise a new list. Paths
    kept in their original direction are the same objects; reversed paths
    are copies.
    """
    if not paths:
        return paths

    count = len(paths)
    starts = np.array([np.asarray(p[0], dtype=np.float64).ravel() for p in paths])
    ends = np.array([np.asarray(p[-1], dtype=np.float64).ravel() for p in paths])
    lengths = [np.sum(np.linalg.norm(np.diff(p, axis=0), axis=1)) for p in paths]

    first = int(np.argmax(lengths))
    used = np.zeros(count, dtype=bool)
    used[first] = True
    ordered = [paths[first]]
    current_end = ends[first]

    for _ in range(count - 1):
        # Column 0: distance to each path's start (take it forward);
        # column 1: distance to its end (take it reversed).
        dists = np.stack(
            (
                np.linalg.norm(starts - current_end, axis=1),
                np.linalg.norm(ends - current_end, axis=1),
            ),
            axis=1,
        )
        dists[used] = np.inf
        # NaN is never "closer" than anything (strict-< semantics).
        dists[np.isnan(dists)] = np.inf
        # Row-major argmin order is (p0 fwd, p0 rev, p1 fwd, ...); argmin
        # returns the first minimum, which gives the tie-breaking above.
        flat = int(np.argmin(dists))
        if not np.isfinite(dists.flat[flat]):
            break
        j, flip = divmod(flat, 2)
        if flip:
            chosen = paths[j][::-1].copy()
            current_end = starts[j]  # reversed path ends at the original start
        else:
            chosen = paths[j]
            current_end = ends[j]
        ordered.append(chosen)
        used[j] = True
    return ordered

# --- SVG writers --------------------------------------------------------------
def write_svg_basic(polylines, width: int, height: int, stroke_width: float) -> bytes:
    """Serialize polylines to a minimal SVG document by hand (no svgwrite needed).

    Polylines with fewer than two points are skipped. Every coordinate is
    written via ``float()``. Returns UTF-8 encoded bytes (elements separated
    by newlines).
    """
    header = (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}px" '
        f'height="{height}px" viewBox="0 0 {width} {height}">'
    )
    attrs = (
        f'fill="none" stroke="black" stroke-width="{stroke_width}" '
        f'stroke-linecap="round" stroke-linejoin="round"/>'
    )

    def coords(pts):
        return " ".join(f"{float(x)},{float(y)}" for x, y in pts)

    elements = [
        f'<polyline points="{coords(pts)}" {attrs}'
        for pts in polylines
        if len(pts) >= 2
    ]
    document = "\n".join([header, *elements, "</svg>"])
    return document.encode("utf-8")

def write_svg(polylines, width: int, height: int, stroke_width: float) -> bytes:
    """Serialize polylines to SVG bytes with svgwrite (SVG Tiny profile).

    Polylines with fewer than two points are skipped. If anything goes wrong
    (svgwrite unavailable, validation error, ...), the hand-written
    ``write_svg_basic`` writer is used instead.
    """
    try:
        import svgwrite
        drawing = svgwrite.Drawing(size=(f"{width}px", f"{height}px"), profile='tiny')
        drawing.viewbox(0, 0, width, height)
        drawable = (pts for pts in polylines if len(pts) >= 2)
        for pts in drawable:
            drawing.add(drawing.polyline(
                points=[(float(x), float(y)) for x, y in pts],
                fill="none",
                stroke="black",
                stroke_width=stroke_width,
                stroke_linecap="round",
                stroke_linejoin="round",
            ))
        svg_text = drawing.tostring()
        return svg_text.encode("utf-8")
    except Exception:
        # Fallback: dependency-free writer.
        return write_svg_basic(polylines, width, height, stroke_width)



def draw_polylines_preview(polylines: List[np.ndarray], size: Tuple[int,int]) -> np.ndarray:
    """Rasterize polylines as open, 1-px anti-aliased black strokes on white.

    ``size`` is (width, height) in pixels; the result is an (height, width, 3)
    uint8 BGR canvas. Point coordinates are rounded to the nearest pixel.
    """
    width, height = size
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    int_paths = (np.round(p).astype(np.int32).reshape(-1, 1, 2) for p in polylines)
    for path in int_paths:
        cv2.polylines(
            canvas, [path], isClosed=False, color=(0, 0, 0),
            thickness=1, lineType=cv2.LINE_AA,
        )
    return canvas

def to_b64_png(img_bgr: np.ndarray) -> str:
    """PNG-encode an image and return it as an ASCII base64 string.

    Raises:
        ValueError: if ``cv2.imencode`` reports that PNG encoding failed.
    """
    ok, buf = cv2.imencode(".png", img_bgr)
    if not ok:
        # Don't silently return an empty/invalid payload to the client.
        raise ValueError("Unable to encode image as PNG.")
    return base64.b64encode(buf.tobytes()).decode("ascii")

# -------------------------------
# High-level pipelines
# -------------------------------
def generate_preview(raw_bytes: bytes, params: dict) -> dict:
    """Run image -> preprocess -> edges -> polylines and return preview images.

    ``params`` is a mapping of (typically string) settings. Numeric entries
    are coerced with ``float`` (and ``int`` for integer settings); missing
    keys use the defaults below.

    Returns a dict with:
      * ``orig_b64``: resized input, base64 PNG.
      * ``proc_b64``: preprocessed input, base64 PNG.
      * ``sketch_b64``: rendered polylines, base64 PNG.
      * ``edges_b64``: binary edge mask, base64 PNG.
      * ``meta``: ``{"contours": <polyline count>}``.
    """
    get = params.get

    def as_int(key, default):
        return int(float(get(key, default)))

    def as_float(key, default):
        return float(get(key, default))

    # Edge-detection settings
    method = get("edge_method", "canny")
    target_w = as_int("target_w", 1600)
    blur = as_float("blur", 1.0)
    canny_low = as_int("canny_low", 50)
    canny_high = as_int("canny_high", 150)
    th_block = as_int("th_block", 31)
    th_C = as_int("th_C", 5)
    min_len = as_float("min_len", 20)
    epsilon = as_float("epsilon", 2.0)
    dog_sigma = as_float("dog_sigma", 1.0)
    dog_k = as_float("dog_k", 1.6)
    dog_scale = as_float("dog_scale", 2.0)
    xdog_sigma = as_float("xdog_sigma", 0.9)
    xdog_k = as_float("xdog_k", 1.6)
    xdog_phi = as_float("xdog_phi", 10.0)
    xdog_eps = as_float("xdog_eps", 0.0)
    xdog_p = as_float("xdog_p", 1.0)

    # Preprocessing settings (applied before grayscale conversion)
    pre_kwargs = dict(
        method=get("preprocess", "none"),
        bilat_iter=as_int("pre_bilat_iter", 2),
        bilat_d=as_int("pre_bilat_d", 9),
        bilat_sigma_color=as_float("pre_bilat_sc", 75.0),
        bilat_sigma_space=as_float("pre_bilat_ss", 9.0),
        quant_k=as_int("pre_quant_k", 8),
        median_k=as_int("pre_median_k", 0),
        stylize_sigma_s=as_int("pre_st_sigma_s", 60),
        stylize_sigma_r=as_float("pre_st_sigma_r", 0.07),
    )

    img = resize_keep_aspect(read_image_to_array(raw_bytes), target_w)
    pre_img = preprocess_for_edges(img_bgr=img, **pre_kwargs)

    edges = edge_detect(
        to_gray(pre_img), method, canny_low, canny_high, th_block, th_C, blur,
        dog_sigma, dog_k, dog_scale, xdog_sigma, xdog_k, xdog_phi, xdog_eps, xdog_p,
    )
    polylines = order_paths_nearest_neighbor(
        contours_to_polylines(edges, min_len=min_len, epsilon=epsilon)
    )

    height, width = edges.shape[:2]
    preview = draw_polylines_preview(polylines, (width, height))
    return {
        "orig_b64": to_b64_png(img),
        "proc_b64": to_b64_png(pre_img),
        "sketch_b64": to_b64_png(preview),
        "edges_b64": base64.b64encode(cv2.imencode(".png", edges)[1].tobytes()).decode("ascii"),
        "meta": {"contours": len(polylines)},
    }


def _polylines_bbox_mm(polylines, height_px: int, mm_per_px: float):
    """Bounding box of ``polylines`` in mm after the image->machine Y flip.

    Uses x_mm = x * mm_per_px and y_mm = (height_px - y) * mm_per_px, i.e.
    the G-code mapping before offsets are added. Returns
    (min_x, min_y, max_x, max_y), or None when the minimum corner is not
    finite (e.g. there are no points).
    """
    import math

    min_x = min_y = math.inf
    max_x = max_y = -math.inf
    for pts in polylines:
        if len(pts) < 1:
            continue
        xs = pts[:, 0] * mm_per_px
        ys = (height_px - pts[:, 1]) * mm_per_px  # Y flip: image -> CNC
        min_x = min(min_x, float(xs.min()))
        max_x = max(max_x, float(xs.max()))
        min_y = min(min_y, float(ys.min()))
        max_y = max(max_y, float(ys.max()))
    if not (math.isfinite(min_x) and math.isfinite(min_y)):
        return None
    return min_x, min_y, max_x, max_y


def build_zip(raw_bytes: bytes, params: dict) -> bytes:
    """Run the full pipeline and package the results into a ZIP archive.

    Settings are read from ``params``. It uses the same keys/defaults as
    ``generate_preview``, plus ``stroke``, ``max_w_mm``, ``max_h_mm``,
    ``origin_x_mm`` and ``origin_y_mm``. For backward compatibility, a key
    missing from ``params`` is looked up in the active Flask request form,
    if a request context exists.

    The drawing is scaled (aspect preserved) to fit within
    ``max_w_mm`` x ``max_h_mm``. It is then shifted so the bottom-left of
    its bounding box sits at (``origin_x_mm``, ``origin_y_mm``).

    Returns the archive bytes, containing ``line_sketch.svg``,
    ``edges_preview.png``, ``plot.gcode`` and ``README.txt``.
    """
    from flask import has_request_context

    p = dict(params)  # copy; the caller's mapping is never touched

    def get(key, default):
        # Caller-supplied params win. Fall back to the request form (where
        # some keys used to be read from) only inside a request context.
        if key in p:
            return p[key]
        if has_request_context():
            return request.form.get(key, default)
        return default

    def as_int(key, default):
        return int(float(get(key, default)))

    def as_float(key, default):
        return float(get(key, default))

    stroke = as_float("stroke", 0.5)

    # Edge-detection settings (same keys/defaults as the preview)
    method = get("edge_method", "canny")
    target_w = as_int("target_w", 1600)
    blur = as_float("blur", 1.0)
    canny_low = as_int("canny_low", 50)
    canny_high = as_int("canny_high", 150)
    th_block = as_int("th_block", 31)
    th_C = as_int("th_C", 5)
    min_len = as_float("min_len", 20)
    epsilon = as_float("epsilon", 2.0)
    dog_sigma = as_float("dog_sigma", 1.0)
    dog_k = as_float("dog_k", 1.6)
    dog_scale = as_float("dog_scale", 2.0)
    xdog_sigma = as_float("xdog_sigma", 0.9)
    xdog_k = as_float("xdog_k", 1.6)
    xdog_phi = as_float("xdog_phi", 10.0)
    xdog_eps = as_float("xdog_eps", 0.0)
    xdog_p = as_float("xdog_p", 1.0)

    img = resize_keep_aspect(read_image_to_array(raw_bytes), target_w)

    # Same preprocessing as the preview, read from the same params
    pre_img = preprocess_for_edges(
        img_bgr=img,
        method=get("preprocess", "none"),
        bilat_iter=as_int("pre_bilat_iter", 2),
        bilat_d=as_int("pre_bilat_d", 9),
        bilat_sigma_color=as_float("pre_bilat_sc", 75.0),
        bilat_sigma_space=as_float("pre_bilat_ss", 9.0),
        quant_k=as_int("pre_quant_k", 8),
        median_k=as_int("pre_median_k", 0),
        stylize_sigma_s=as_int("pre_st_sigma_s", 60),
        stylize_sigma_r=as_float("pre_st_sigma_r", 0.07),
    )

    edges = edge_detect(
        to_gray(pre_img), method, canny_low, canny_high, th_block, th_C, blur,
        dog_sigma, dog_k, dog_scale, xdog_sigma, xdog_k, xdog_phi, xdog_eps, xdog_p,
    )
    polylines = order_paths_nearest_neighbor(
        contours_to_polylines(edges, min_len=min_len, epsilon=epsilon)
    )
    h, w = edges.shape[:2]

    # --- Size & origin anchor -------------------------------------------------
    # Largest uniform scale keeping the image within max_w_mm x max_h_mm.
    max_w_mm = as_float("max_w_mm", 50.0)
    max_h_mm = as_float("max_h_mm", 50.0)
    mm_per_px = min(max_w_mm / float(w), max_h_mm / float(h))

    # Machine position (mm) for the bottom-left corner of the drawing's bbox.
    origin_x_mm = as_float("origin_x_mm", 1.0)
    origin_y_mm = as_float("origin_y_mm", 1.0)

    bbox = _polylines_bbox_mm(polylines, h, mm_per_px)
    if bbox is None:
        off_x, off_y = origin_x_mm, origin_y_mm
        final_w_mm = final_h_mm = 0.0
    else:
        min_x, min_y, max_x, max_y = bbox
        off_x = origin_x_mm - min_x
        off_y = origin_y_mm - min_y
        final_w_mm = max_x - min_x
        final_h_mm = max_y - min_y

    svg_bytes = write_svg(polylines, width=w, height=h, stroke_width=stroke)

    # G-code: always produce a file. If the configured writer fails, fall back
    # to the basic Z-lift writer, keeping the computed scale (it cannot apply
    # offsets).
    used_fallback = False
    try:
        gcfg = GCodeConfig(
            mm_per_px=mm_per_px,
            offset_x_mm=off_x,
            offset_y_mm=off_y,
            feed_draw=1800,
            feed_travel=3000,
            use_z_lift=False,
            servo_down_s=100,          # M03 S100
            servo_up_uses_M5=True,     # or False for M3 S0
            servo_up_s=0,
            servo_dwell_s=0.70,
            pad_zero_in_m_codes=True,
        )
        gcode_text = polylines_to_gcode(polylines, width_px=w, height_px=h, cfg=gcfg)
    except Exception:
        used_fallback = True
        gcode_text = polylines_to_gcode_basic(
            polylines, width_px=w, height_px=h, mm_per_px=mm_per_px
        )
    gcode = gcode_text.encode("utf-8")

    edges_png = cv2.imencode(".png", edges)[1].tobytes()

    readme = (
        "line_sketch.svg: Vector polylines for Inkscape/pen plotting.\n"
        "edges_preview.png: Binary edge mask used to extract contours.\n"
        "plot.gcode: G-code untuk pen plotter (ubah skala & Z/servo ikut mesin).\n"
        f"Scale: {mm_per_px:.4f} mm/px -> drawing approx. "
        f"{final_w_mm:.1f} x {final_h_mm:.1f} mm, "
        f"bottom-left at ({origin_x_mm:.1f}, {origin_y_mm:.1f}) mm.\n"
        "Tip: adjust max_w_mm/max_h_mm and origin_x_mm/origin_y_mm to fit your paper/bed.\n"
    )
    if used_fallback:
        readme += "Note: basic Z-lift G-code fallback was used (no origin offset applied).\n"

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("line_sketch.svg", svg_bytes)
        zf.writestr("edges_preview.png", edges_png)
        zf.writestr("plot.gcode", gcode)  # always written
        zf.writestr("README.txt", readme)
    return buf.getvalue()

# ---------- Cartoon / Preprocess helpers ----------
def bilateral_smooth(img_bgr: np.ndarray, iterations: int, d: int, sigma_color: float, sigma_space: float) -> np.ndarray:
    """Apply ``cv2.bilateralFilter`` repeatedly for edge-preserving smoothing.

    At least one pass is always applied, even if ``iterations`` < 1. The
    input array is copied first, so it is not modified.
    """
    diameter = int(d)
    passes = max(1, int(iterations))
    result = img_bgr.copy()
    sc, ss = float(sigma_color), float(sigma_space)
    for _ in range(passes):
        result = cv2.bilateralFilter(result, d=diameter, sigmaColor=sc, sigmaSpace=ss)
    return result

def color_quantize_kmeans(img_bgr: np.ndarray, k: int, attempts: int = 1) -> np.ndarray:
    """Reduce a 3-channel image to ``k`` colors with k-means on BGR values.

    ``k`` is clamped to 2..32. The k-means run stops after 20 iterations or
    an epsilon of 1.0, using k-means++ seeding. Each pixel is replaced by its
    cluster center, converted to uint8.
    """
    samples = img_bgr.reshape((-1, 3)).astype(np.float32)
    clusters = int(max(2, min(32, k)))
    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
    _compactness, labels, centers = cv2.kmeans(
        samples, clusters, None, stop_rule, attempts, cv2.KMEANS_PP_CENTERS
    )
    palette = np.uint8(centers)
    return palette[labels.ravel()].reshape(img_bgr.shape)

# ---------- Paper fit (page with margins; defaults to a 150 x 150 mm page) ----------
def compute_fit_mm(width_px: int, height_px: int,
                   page_w_mm: float = 150.0, page_h_mm: float = 150.0,
                   margin_mm: float = 5.0, mm_per_px_requested: float | None = None,
                   center: bool = True) -> tuple[float, float, float]:
    """Choose a scale and placement for a ``width_px`` x ``height_px`` image on a page.

    The usable area is the page minus ``margin_mm`` on every side, floored
    at 1e-6 mm. The scale is the largest that fits that area; a truthy
    ``mm_per_px_requested`` is used instead when it is smaller. With
    ``center`` the drawing is centered in the usable area; otherwise it sits
    at the margin corner.

    Returns (mm_per_px, offset_x_mm, offset_y_mm).
    """
    usable_w = max(1e-6, page_w_mm - 2 * margin_mm)
    usable_h = max(1e-6, page_h_mm - 2 * margin_mm)
    fit = min(usable_w / float(width_px), usable_h / float(height_px))
    scale = min(mm_per_px_requested, fit) if mm_per_px_requested else fit
    if not center:
        return scale, margin_mm, margin_mm
    slack_w = usable_w - width_px * scale
    slack_h = usable_h - height_px * scale
    return scale, margin_mm + slack_w / 2.0, margin_mm + slack_h / 2.0


def preprocess_for_edges(
    img_bgr: np.ndarray,
    method: str = "none",
    bilat_iter: int = 2,
    bilat_d: int = 9,
    bilat_sigma_color: float = 75.0,
    bilat_sigma_space: float = 9.0,
    quant_k: int = 8,
    median_k: int = 0,
    stylize_sigma_s: int = 60,
    stylize_sigma_r: float = 0.07,
) -> np.ndarray:
    """Return a preprocessed BGR image to feed into grayscale + edge detection.

    ``method`` (case-insensitive; None means "none"):
      - 'none'      : passthrough (also used for unknown values)
      - 'bilateral' : repeated bilateral (edge-preserving) smoothing
      - 'quantize'  : bilateral smoothing followed by k-means color quantization
      - 'stylize'   : ``cv2.stylization``; falls back to 'quantize' if it raises

    Afterwards, if ``median_k`` > 0, a median blur with the next odd kernel
    size (``median_k | 1``) is applied.
    """
    mode = (method or "none").lower()

    def smooth():
        return bilateral_smooth(img_bgr, bilat_iter, bilat_d, bilat_sigma_color, bilat_sigma_space)

    def smooth_and_quantize():
        return color_quantize_kmeans(smooth(), quant_k)

    def stylize():
        # cv2.stylization may be unavailable or fail in some OpenCV builds;
        # degrade gracefully to the quantize pipeline.
        try:
            return cv2.stylization(img_bgr, sigma_s=int(stylize_sigma_s), sigma_r=float(stylize_sigma_r))
        except Exception:
            return smooth_and_quantize()

    handlers = {
        "bilateral": smooth,
        "quantize": smooth_and_quantize,
        "stylize": stylize,
    }
    handler = handlers.get(mode)
    result = img_bgr if handler is None else handler()

    kernel = int(median_k)
    if kernel > 0:
        result = cv2.medianBlur(result, kernel | 1)  # aperture must be odd
    return result
