#!/usr/bin/env python3
"""
Apply dtparam and dtoverlay directives from a Raspberry Pi config.txt to a base .dtb file,
producing a merged .dtb using the dtmerge utility.
"""

import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile


# Maps config.txt conditional-section names (lower-cased, without brackets)
# to the model identifiers that section applies to.  Model identifiers are
# the strings returned by model_from_dtb() or supplied via --model.
#   None -> the section applies to every model ([all])
#   ()   -> the section applies to no model ([none])
# A broad filter also covers its variants, e.g. [pi5] applies to the Pi 500
# and CM5, [pi4] to the Pi 400, CM4 and CM4S, and [pi0w] to the Zero 2 W.
_SECTION_MAP = {
    'all':   None,   # sentinel: applies to everything
    'none':  (),     # applies to nothing
    'pi5':   ('pi5', 'pi500', 'cm5'),
    'pi500': ('pi500',),
    'cm5':   ('cm5',),
    'pi4':   ('pi4', 'pi400', 'cm4', 'cm4s'),
    'pi400': ('pi400',),
    'cm4':   ('cm4',),
    'cm4s':  ('cm4s',),
    'pi3':   ('pi3', 'pi3+'),
    'pi3+':  ('pi3+',),
    'cm0':   ('cm0',),
    'pi2':   ('pi2',),
    'pi1':   ('pi1',),
    'pi0':   ('pi0', 'pi0w', 'pi02'),
    'pi0w':  ('pi0w', 'pi02'),
    'pi02':  ('pi02',),
}


def model_from_dtb(dtb_path: str) -> str | None:
    """Infer a Pi model identifier from a DTB filename."""
    name = os.path.basename(dtb_path).lower().removesuffix('.dtb')

    if 'bcm2712' in name:
        if 'cm5' in name:
            return 'cm5'
        if '500' in name:
            return 'pi500'
        return 'pi5'

    if 'bcm2711' in name:
        if 'cm4s' in name:
            return 'cm4s'
        if 'cm4' in name:
            return 'cm4'
        if '400' in name:
            return 'pi400'
        return 'pi4'

    if 'bcm2710' in name or 'bcm2837' in name:
        if '2-b' in name:
            return 'pi2'
        if 'cm0' in name:
            return 'cm0'
        if 'zero-2' in name:
            return 'pi02'
        if '3-b-plus' in name or '3-a-plus' in name:
            return 'pi3+'
        # There is no [cm3] filter
        return 'pi3'

    if 'bcm2709' in name or 'bcm2836' in name:
        # There is no [cm2] filter
        return 'pi2'

    if 'bcm2708' in name or 'bcm2835' in name:
        if 'zero-w' in name:
            return 'pi0w'
        if 'zero' in name:
            return 'pi0'
        # There is no [cm1] filter
        return 'pi1'

    return None


def section_applies(section: str, model: str | None) -> bool:
    """Decide whether a config.txt conditional section is active for *model*.

    *section* is the filter name from between the brackets, compared as-is
    against the lower-case keys of _SECTION_MAP.

    - Filters not present in the map are treated as inactive.
    - [all] is always active; [none] never is.
    - For any other known filter, an unknown model (None) counts as a match,
      so that no directives are silently dropped.
    - Otherwise *model* must be listed for that filter.
    """
    try:
        targets = _SECTION_MAP[section]
    except KeyError:
        # Unknown section (e.g. [board-type=…], [EDID=…]) — skip conservatively
        return False

    if targets is None:        # [all]
        return True
    if not targets:            # [none]
        return False
    return model is None or model in targets


def parse_config(config_path: str, model: str | None) -> list[tuple]:
    """
    Parse *config_path* and return a list of actions:

      ('base_params', ['p=v', ...])
          Apply params to the base DT (dtmerge … - p=v …).

      ('overlay', name, ['p=v', ...])
          Load overlay and apply all accumulated params (dtmerge … name.dtbo p=v …).
          dtmerge resolves each param against the overlay's __overrides__ first,
          then the base, so global params mixed into an overlay line are handled
          correctly.

    The firmware model followed here is:
      - dtparam lines are applied in the context of the most-recently-seen
        dtoverlay; each param is resolved against that overlay before the base.
      - A bare 'dtoverlay=' or 'dtoverlay=none' ends the current overlay context.
      - dtparam with no preceding dtoverlay is a plain base-only operation.
    """
    actions: list[tuple] = []
    active = True  # [all] is implicitly assumed before the first section header

    # Overlay accumulator: we delay emitting an overlay action until we know
    # all its params (which may arrive on subsequent dtparam lines).
    pending_overlay: str | None = None
    pending_params: list[str] = []

    def flush_overlay() -> None:
        nonlocal pending_overlay, pending_params
        if pending_overlay is not None:
            actions.append(('overlay', pending_overlay, pending_params))
            pending_overlay = None
            pending_params = []

    with open(config_path) as fh:
        for raw in fh:
            line = raw.strip()

            # Strip inline comments
            line = re.sub(r'\s*#.*$', '', line)
            if not line:
                continue

            # Section header
            m = re.fullmatch(r'\[([^\]]+)\]', line)
            if m:
                active = section_applies(m.group(1).lower(), model)
                continue

            if not active:
                continue

            # dtparam / dtoverlay
            m = re.match(r'^(dtparam|dtoverlay)\s*=\s*(.*)', line, re.IGNORECASE)
            if not m:
                continue

            kind = m.group(1).lower()
            rest = m.group(2).strip()

            # Split "name,p1=v1,p2=v2" into tokens.  Commas separate tokens;
            # quoted values with embedded commas are not supported by the
            # firmware either, so this matches real-world usage.
            tokens = [t.strip() for t in rest.split(',') if t.strip()]

            if kind == 'dtparam':
                if not tokens:
                    continue
                if pending_overlay is None:
                    # No active overlay: apply to base only.
                    actions.append(('base_params', tokens))
                else:
                    # Extend the current overlay's param list; dtmerge will
                    # resolve each param against the overlay then the base.
                    pending_params.extend(tokens)

            else:  # dtoverlay
                overlay_name = tokens[0] if tokens else ''

                if not overlay_name or overlay_name == 'none':
                    # Bare 'dtoverlay=' or 'dtoverlay=none': end current context.
                    flush_overlay()

                elif overlay_name == 'base':
                    print("  Warning: dtoverlay=base (reset) is not supported; ignoring")
                    flush_overlay()

                else:
                    # New overlay: flush any pending one first, then start accumulating.
                    flush_overlay()
                    pending_overlay = overlay_name
                    pending_params = list(tokens[1:])

    flush_overlay()
    return actions


def find_overlay(name: str, overlays_dir: str) -> str:
    """Locate the compiled overlay ``<name>.dtbo`` inside *overlays_dir*.

    Returns:
        The joined path, if it names an existing regular file.

    Raises:
        FileNotFoundError: when no such file exists.
    """
    candidate = os.path.join(overlays_dir, f"{name}.dtbo")
    if os.path.isfile(candidate):
        return candidate
    raise FileNotFoundError(f"Overlay not found: {candidate}")


def run_dtmerge(dtmerge_bin: str, base: str, output: str,
                overlay: str, params: list[str], debug: bool) -> None:
    """Invoke ``dtmerge [-d] base output overlay params...``.

    The child's stdout/stderr are captured as text.
    - On a non-zero exit, both streams are echoed to this process's
      stdout/stderr and RuntimeError is raised.
    - On success, captured stderr is forwarded only in *debug* mode.

    Raises:
        RuntimeError: if dtmerge exits with a non-zero status.
    """
    flags = ['-d'] if debug else []
    argv = [dtmerge_bin] + flags + [base, output, overlay] + params

    proc = subprocess.run(argv, capture_output=True, text=True)

    if proc.returncode == 0:
        if debug and proc.stderr:
            sys.stderr.write(proc.stderr)
        return

    # Failure: surface whatever dtmerge said before reporting the exit code.
    if proc.stdout:
        sys.stdout.write(proc.stdout)
    if proc.stderr:
        sys.stderr.write(proc.stderr)
    raise RuntimeError(f"dtmerge failed (exit {proc.returncode})")


def default_overlays_dir() -> str:
    """Pick the overlays directory to use when none is given explicitly.

    Returns the first existing directory out of /boot/firmware/overlays and
    /boot/overlays (checked in that order). Falls back to
    /boot/firmware/overlays when neither exists.
    """
    preferred = '/boot/firmware/overlays'
    return next(
        (d for d in (preferred, '/boot/overlays') if os.path.isdir(d)),
        preferred,
    )


def _describe_action(action: tuple) -> str:
    """Render a parse_config() action back into config.txt syntax for display."""
    if action[0] == 'base_params':
        return f"dtparam={','.join(action[1])}"
    _, name, params = action
    return f"dtoverlay={','.join([name] + params)}"


def dtapply(base_dtb: str, config_path: str, output_dtb: str,
            overlays_dir: str, dtmerge_bin: str,
            model_override: str | None, debug: bool, dry_run: bool) -> None:
    """
    Apply the dtparam/dtoverlay directives in *config_path* to *base_dtb*.

    The model used for section filtering is *model_override*, or else the
    model inferred from the DTB file name. Each parsed action is run through
    dtmerge in order, chaining intermediate files in a temporary directory.
    An action that fails, or whose overlay file is missing, is reported and
    skipped, and the chain continues from the last good result. The final
    result is copied to *output_dtb*.

    With *dry_run*, the directives are parsed and printed but nothing is
    written: dtmerge is not run and no output file is created, even when
    there are no directives.
    """
    model = model_override or model_from_dtb(base_dtb)
    if model:
        print(f"Detected model: {model}")
    else:
        print("Warning: could not detect model from DTB name; all sections will be included")

    directives = parse_config(config_path, model)
    if not directives:
        if dry_run:
            # A dry run must not write anything, including the unchanged copy.
            print("No dtparam/dtoverlay directives found — base DTB would be copied unchanged")
            return
        print("No dtparam/dtoverlay directives found — copying base DTB unchanged")
        shutil.copy2(base_dtb, output_dtb)
        return

    print(f"Found {len(directives)} action(s) to apply:")
    for action in directives:
        print(f"  {_describe_action(action)}")

    if dry_run:
        return

    with tempfile.TemporaryDirectory(prefix='dtmerge_') as tmpdir:
        # Chain the merges: each successful step reads 'current' and writes
        # a fresh intermediate file, which then becomes 'current'. A failed
        # or skipped step leaves 'current' unchanged.
        current = base_dtb

        for step, action in enumerate(directives, start=1):
            next_dtb = os.path.join(tmpdir, f"step{step:03d}.dtb")

            if action[0] == 'base_params':
                _, params = action
                if debug:
                    print(f"  dtmerge: base params {params}")
                try:
                    # '-' in the overlay slot: apply params to the base DT only.
                    run_dtmerge(dtmerge_bin, current, next_dtb, '-', params, debug)
                except RuntimeError:
                    print(f"  Warning: dtparam={','.join(params)} failed; skipping")
                    continue

            else:  # overlay
                _, overlay_name, params = action
                try:
                    overlay_path = find_overlay(overlay_name, overlays_dir)
                except FileNotFoundError as exc:
                    print(f"  Warning: {exc}; skipping")
                    continue
                if debug:
                    print(f"  dtmerge: dtoverlay={overlay_name} params={params}")
                try:
                    run_dtmerge(dtmerge_bin, current, next_dtb, overlay_path, params, debug)
                except RuntimeError:
                    print(f"  Warning: dtoverlay={overlay_name} failed; skipping")
                    continue

            current = next_dtb

        # Copy out before the temporary directory (and 'current') is removed.
        shutil.copy2(current, output_dtb)

    print(f"Written: {output_dtb}")


def main() -> None:
    """Command-line entry point: parse arguments, validate inputs, run dtapply()."""
    parser = argparse.ArgumentParser(
        description="Apply config.txt dtparam/dtoverlay directives to a base .dtb file "
                    "using the dtmerge utility."
    )
    parser.add_argument('dtb',    help="Base DTB file (e.g. bcm2711-rpi-4-b.dtb)")
    parser.add_argument('config', help="Raspberry Pi config.txt file")
    parser.add_argument('-o', '--output',   default='output.dtb',
                        help="Output DTB file (default: output.dtb)")
    parser.add_argument('--overlays-dir',   default=None,
                        help="Directory containing .dtbo overlay files "
                             "(default: auto-detect /boot/firmware/overlays or /boot/overlays)")
    parser.add_argument('--dtmerge',        default='dtmerge',
                        help="Path to the dtmerge binary (default: dtmerge)")
    parser.add_argument('--model',          default=None,
                        help="Override Pi model identifier for section filtering "
                             "(e.g. pi4, cm4, pi5)")
    parser.add_argument('-d', '--debug',    action='store_true',
                        help="Pass -d to dtmerge and print extra information")
    parser.add_argument('-n', '--dry-run',  action='store_true',
                        help="Parse and print directives but do not run dtmerge")
    args = parser.parse_args()

    if not os.path.isfile(args.dtb):
        sys.exit(f"Error: base DTB not found: {args.dtb}")
    if not os.path.isfile(args.config):
        sys.exit(f"Error: config.txt not found: {args.config}")

    overlays_dir = args.overlays_dir or default_overlays_dir()

    # A dry run only parses config.txt, so it needs neither the overlays
    # directory nor a working dtmerge.
    if not args.dry_run:
        if not os.path.isdir(overlays_dir):
            sys.exit(f"Error: overlays directory not found: {overlays_dir}")
        # shutil.which() looks up bare names on PATH, as subprocess does, and
        # verifies that explicit paths are executable. A plain isfile() check
        # would accept files that subprocess then fails to run.
        if shutil.which(args.dtmerge) is None:
            sys.exit(f"Error: dtmerge binary not found or not executable: {args.dtmerge}")

    dtapply(
        base_dtb=args.dtb,
        config_path=args.config,
        output_dtb=args.output,
        overlays_dir=overlays_dir,
        dtmerge_bin=args.dtmerge,
        # Section filters use lower-case model identifiers.
        model_override=args.model.lower() if args.model else None,
        debug=args.debug,
        dry_run=args.dry_run,
    )


if __name__ == '__main__':
    main()
