"""Implement the fast in-memory forest navigation Gymnasium environment."""

# from __future__ import annotations
from typing import Any, Optional, Tuple
from pathlib import Path
import importlib
import sys

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from fastsim_forest_nav.dynamics.controls import (
    accel_limit_velocity as _accel_limit_velocity,
    approach_speed_cap as _approach_speed_cap,
    map_normalized_accel as _map_normalized_accel,
)
from fastsim_forest_nav.envs.params import SimParams
from fastsim_forest_nav.safety.geometry import (
    effective_drone_radius as _effective_drone_radius,
    protected_radius as _protected_radius,
    soft_clearance_margin as _soft_clearance_margin,
)


class ForestNavEnv(gym.Env):
    """Provide a continuous-control forest navigation environment for SB3."""

    def __init__(self, params: SimParams, render_mode: Optional[str] = None):
        """Initialize environment state, spaces, and cached geometry.

        Args:
            params: Simulation parameters; stored as ``self.p`` and read by the
                other methods of the environment.
            render_mode: Stored as ``self.render_mode``; not otherwise used here.

        Raises:
            ValueError: If ``params.action_mode`` is not ``"hybrid"``, or if
                ``params.lidar_num_beams`` is smaller than 1.
        """
        super().__init__()
        self.p = params
        self.render_mode = render_mode

        if self.p.action_mode != "hybrid":
            raise ValueError(
                f"Unsupported action_mode='{self.p.action_mode}'. "
                "Only action_mode='hybrid' is supported."
            )

        # The beam count sizes the observation vector and divides the 360 deg
        # sweep below; fail with a clear message instead of a ZeroDivisionError.
        if int(self.p.lidar_num_beams) < 1:
            raise ValueError(f"lidar_num_beams must be >= 1, got {self.p.lidar_num_beams}.")

        # Observation space, in the order produced by _pack_obs():
        #   [lidar ranges normalized to [0, 1] ...,
        #    cos(theta_goal), sin(theta_goal),
        #    goal distance normalized,
        #    forward speed normalized, yaw rate normalized,
        #    height error = clip((z_target - z) / z_scale, -1, 1)]
        # i.e. lidar_num_beams + 6 entries.
        # Possible future change: replace forward speed with normalized vx, vy
        # (would grow the Box by one entry).
        obs_dim = self.p.lidar_num_beams + 6
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32)

        # Action space: normalized acceleration commands in each axis.
        # a[0] -> forward accel, a[1] -> yaw accel, a[2] -> vertical accel
        # SAC in SB3 is built for continuous Box actions
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

        # Elapsed simulated time and number of steps in the current episode.
        self._t = 0.0
        self._step_count = 0

        # Cached lidar scan to avoid recomputation
        self._cached_lidar: np.ndarray | None = None
        self._cached_min_range: float = float("inf")

        # Vehicle state (placeholder values until reset() samples a pose).
        self.pos = np.zeros(3, dtype=np.float32)  # x, y, z
        self.yaw = np.float32(0.0)
        self.v = np.float32(0.0)  # forward speed along the body x-axis
        self.wz = np.float32(0.0)  # yaw rate
        self.vz = np.float32(0.0)  # vertical speed

        self.goal = np.zeros(3, dtype=np.float32)  # x, y, z
        self.z_target = np.float32(0.0)  # maintaining this height target for now

        self.trees = None  # list/array of cylinders (x, y, radius)
        self._prev_dist: Optional[float] = None  # last distance, can use to calculate delta
        self._world_half_extent = float(self.p.world_radius)
        self._last_worldgen_seed: Optional[int] = None
        self._last_worldgen_selection: dict[str, Any] | None = None
        self._episode_counter: int = 0

        # full 360 deg lidar beam geometry (relative to body x-axis)
        self._beam_angle_start = -np.pi
        self._beam_angle_step = (2.0 * np.pi) / float(self.p.lidar_num_beams)
        rel_angles = self._beam_angle_start + self._beam_angle_step * np.arange(
            self.p.lidar_num_beams,
            dtype=np.float64,
        )
        # Precomputed so each scan only needs one rotation by the current yaw.
        self._beam_rel_cos = np.cos(rel_angles)
        self._beam_rel_sin = np.sin(rel_angles)

        # spatial grid for fast nearest-tree queries
        self._tree_grid: Optional[dict] = None
        self._tree_grid_cell_size = float(self.p.lidar_range_max) / 2.0
        # Resolved lazily by _sample_forest() on first use.
        self._worldgen_generate_positions_fn = None

    @staticmethod
    def _project_root() -> Path:
        return Path(__file__).resolve().parents[4]

    def _worldgen_config_path(self) -> Path:
        return self._project_root() / self.p.worldgen_config_relpath

    def _effective_world_half_extent(self) -> float:
        return max(1e-3, float(self._world_half_extent) - float(self.p.boundary_margin))

    def _distance_to_world_boundary(self) -> float:
        half = self._effective_world_half_extent()
        dx = half - abs(float(self.pos[0]))
        dy = half - abs(float(self.pos[1]))
        return float(min(dx, dy))

    def _distance_to_world_boundary_along_motion(self, signed_speed: float) -> float:
        """Distance along signed body-x direction until hitting virtual boundary wall."""
        if abs(float(signed_speed)) <= 1e-9:
            return float("inf")

        half = self._effective_world_half_extent()
        x = float(self.pos[0])
        y = float(self.pos[1])

        motion_sign = 1.0 if float(signed_speed) >= 0.0 else -1.0
        dx = motion_sign * float(np.cos(float(self.yaw)))
        dy = motion_sign * float(np.sin(float(self.yaw)))

        candidates: list[float] = []
        eps = 1e-9

        if abs(dx) > eps:
            wall_x = half if dx > 0.0 else -half
            tx = (wall_x - x) / dx
            if tx >= 0.0:
                y_hit = y + tx * dy
                if -half <= y_hit <= half:
                    candidates.append(float(tx))

        if abs(dy) > eps:
            wall_y = half if dy > 0.0 else -half
            ty = (wall_y - y) / dy
            if ty >= 0.0:
                x_hit = x + ty * dx
                if -half <= x_hit <= half:
                    candidates.append(float(ty))

        if not candidates:
            return float("inf")
        return float(min(candidates))

    def _build_tree_grid(self):
        """Build spatial hash grid for fast nearest-tree queries."""
        if self.trees is None or len(self.trees) == 0:
            self._tree_grid = {}
            return

        cell_size = self._tree_grid_cell_size
        grid = {}
        for idx, tree in enumerate(self.trees):
            x, y, r = float(tree[0]), float(tree[1]), float(tree[2])
            # insert into all cells overlapped by this tree's bounding box
            min_cell_x = int(np.floor((x - r) / cell_size))
            max_cell_x = int(np.floor((x + r) / cell_size))
            min_cell_y = int(np.floor((y - r) / cell_size))
            max_cell_y = int(np.floor((y + r) / cell_size))
            for cx in range(min_cell_x, max_cell_x + 1):
                for cy in range(min_cell_y, max_cell_y + 1):
                    key = (cx, cy)
                    if key not in grid:
                        grid[key] = []
                    grid[key].append(idx)
        self._tree_grid = grid

    def _query_nearby_trees(self, x: float, y: float, radius: float) -> np.ndarray:
        """Return indices of trees within 'radius' of (x, y) using spatial grid."""
        if self._tree_grid is None or len(self._tree_grid) == 0 or self.trees is None:
            return np.array([], dtype=np.int64)

        cell_size = self._tree_grid_cell_size
        min_cx = int(np.floor((x - radius) / cell_size))
        max_cx = int(np.floor((x + radius) / cell_size))
        min_cy = int(np.floor((y - radius) / cell_size))
        max_cy = int(np.floor((y + radius) / cell_size))

        candidates = set()
        for cx in range(min_cx, max_cx + 1):
            for cy in range(min_cy, max_cy + 1):
                if (cx, cy) in self._tree_grid:
                    candidates.update(self._tree_grid[(cx, cy)])

        if not candidates:
            return np.array([], dtype=np.int64)

        candidate_indices = np.array(list(candidates), dtype=np.int64)
        tree_subset = self.trees[candidate_indices]
        dx = tree_subset[:, 0] - x
        dy = tree_subset[:, 1] - y
        dist = np.sqrt(dx * dx + dy * dy)
        reach = dist + tree_subset[:, 2]
        in_range = reach >= 0.0  # always true but kept for clarity
        return candidate_indices[in_range]

    def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None):
        """Reset state, (re)sample world assets, and return initial observation."""
        # gym seeding contract
        super().reset(seed=seed)
        self._t = 0.0
        self._step_count = 0
        self._episode_counter += 1

        # Sample world, start, goal
        resample_every = max(1, int(getattr(self.p, "worldgen_resample_every_n_episodes", 1)))
        should_resample_world = self.trees is None or (
            (self._episode_counter - 1) % resample_every == 0
        )

        start_xy_anchor: np.ndarray | None = None
        goal_xy_anchor: np.ndarray | None = None
        if should_resample_world:
            start_xy_anchor, goal_xy_anchor = self._sample_start_goal_anchors()
            exclusion_radius = float(max(0.0, self.p.start_goal_tree_exclusion_radius))
            exclusion_centers = np.stack([start_xy_anchor, goal_xy_anchor], axis=0)
            self.trees = self._sample_forest(
                exclusion_centers=exclusion_centers,
                exclusion_radius=exclusion_radius,
            )
            self._build_tree_grid()

        self.pos, self.yaw = self._sample_start_pose(preferred_xy=start_xy_anchor)
        self.goal = self._sample_goal_pose(preferred_xy=goal_xy_anchor)
        self.z_target = np.float32(self.p.default_z_target)

        # standing still
        self.v = np.float32(0.0)
        self.wz = np.float32(0.0)
        self.vz = np.float32(0.0)

        self._prev_dist = self._dist_to_goal()

        obs = self._get_obs()
        info = self._get_info(shield_active=0, collision=0, success=0, shield_delta=0.0)
        return obs, info

    def step(self, action: np.ndarray):
        """Advance one simulation step using acceleration-space actions.

        Pipeline: map the normalized action to accelerations, integrate them
        into velocity commands, clip to the velocity limits, pass the commands
        through the safety shield and the slew limiter, integrate the pose,
        scan the lidar, then compute reward and termination flags.

        Args:
            action: Array-like whose first three entries are the normalized
                forward, yaw, and vertical acceleration commands.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is
            set on collision or success; ``truncated`` is set once the elapsed
            time reaches ``p.episode_seconds``.
        """
        action = np.asarray(action, dtype=np.float32)

        # Normalized action -> acceleration per axis.
        # NOTE(review): presumably positive/negative inputs scale by the accel/decel
        # limits respectively -- confirm against dynamics.controls.
        accel_v = _map_normalized_accel(float(action[0]), self.p.accel_v_max, self.p.decel_v_max)
        accel_wz = _map_normalized_accel(float(action[1]), self.p.accel_wz_max, self.p.decel_wz_max)
        accel_vz = _map_normalized_accel(float(action[2]), self.p.accel_vz_max, self.p.decel_vz_max)

        # Velocity commands before any limiting (one explicit Euler step).
        pre_clip_v = float(self.v) + accel_v * float(self.p.dt)
        pre_clip_wz = float(self.wz) + accel_wz * float(self.p.dt)
        pre_clip_vz = float(self.vz) + accel_vz * float(self.p.dt)

        cmd_v = float(np.clip(pre_clip_v, -self.p.v_max, self.p.v_max))
        cmd_wz = float(np.clip(pre_clip_wz, -self.p.wz_max, self.p.wz_max))
        cmd_vz = float(np.clip(pre_clip_vz, -self.p.vz_max, self.p.vz_max))
        # Flag whether any axis was saturated by the velocity limits.
        accel_clipped = int(
            (cmd_v != pre_clip_v) or (cmd_wz != pre_clip_wz) or (cmd_vz != pre_clip_vz)
        )

        # safety shield clamps command (authoritative)
        safe_v, safe_wz, safe_vz, shield_active, shield_delta = self._apply_shield(
            cmd_v, cmd_wz, cmd_vz
        )

        # Slew-limit the shielded targets relative to the current velocities.
        # Each call returns the applied value and a flag that is OR-ed into
        # accel_clipped below.
        applied_v, slew_clipped_v = _accel_limit_velocity(
            desired=float(safe_v),
            current=float(self.v),
            accel_max=float(self.p.accel_v_max),
            dt=float(self.p.dt),
            decel_max=float(self.p.decel_v_max),
        )
        applied_wz, slew_clipped_wz = _accel_limit_velocity(
            desired=float(safe_wz),
            current=float(self.wz),
            accel_max=float(self.p.accel_wz_max),
            dt=float(self.p.dt),
            decel_max=float(self.p.decel_wz_max),
        )
        applied_vz, slew_clipped_vz = _accel_limit_velocity(
            desired=float(safe_vz),
            current=float(self.vz),
            accel_max=float(self.p.accel_vz_max),
            dt=float(self.p.dt),
            decel_max=float(self.p.decel_vz_max),
        )
        accel_clipped = int(
            bool(accel_clipped) or slew_clipped_v or slew_clipped_wz or slew_clipped_vz
        )

        self.v = np.float32(applied_v)
        self.wz = np.float32(applied_wz)
        self.vz = np.float32(applied_vz)

        # integrate simple kinematics (fastsim)
        self._integrate(np.float32(applied_v), np.float32(applied_wz), np.float32(applied_vz))

        # sensor update; the results are cached for _get_info()
        lidar = self._lidar_scan()
        self._cached_lidar = lidar
        tree_min_range = float(np.min(lidar))
        boundary_range = self._distance_to_world_boundary()
        # The world boundary counts as an obstacle for clearance purposes.
        min_range = float(min(tree_min_range, boundary_range))
        self._cached_min_range = min_range

        # reward and termination
        dist = self._dist_to_goal()
        prev_dist = self._prev_dist if self._prev_dist is not None else dist
        d_progress = prev_dist - dist  # > 0 when the goal got closer this step
        self._prev_dist = dist

        drone_radius = _effective_drone_radius(self.p)
        clearance = min_range - drone_radius

        collision = int(clearance < 0.0)
        success = int(dist < self.p.goal_tolerance)

        # Weights saved in the dataclass
        reward = 0.0
        reward += self.p.reward_progress_scale * d_progress  # encourage progress towards goal
        reward += self.p.reward_speed_scale * (applied_v / self.p.v_max)  # encourage faster speeds
        reward -= (
            self.p.reward_step_penalty
        )  # small penalty for each step to encourage faster completion

        # Penalize yawing in proportion to how fast the vehicle is moving.
        speed_norm = abs(float(applied_v)) / max(float(self.p.v_max), 1e-6)
        yaw_rate_norm = abs(float(applied_wz)) / max(float(self.p.wz_max), 1e-6)
        reward -= self.p.reward_yaw_rate_scale * speed_norm * yaw_rate_norm

        if accel_clipped:
            reward -= self.p.reward_accel_clip_penalty

        # Stall penalty: grows as progress falls below the threshold, with an
        # extra term (up to half the base penalty) when the speed is also low.
        if d_progress < self.p.progress_stall_threshold and not success:
            stall_ratio = (self.p.progress_stall_threshold - d_progress) / max(
                self.p.progress_stall_threshold,
                1e-6,
            )
            stall_penalty = self.p.reward_stall_penalty * float(np.clip(stall_ratio, 0.0, 1.0))

            if speed_norm < self.p.yaw_penalty_speed_gate:
                low_speed_ratio = (self.p.yaw_penalty_speed_gate - speed_norm) / max(
                    self.p.yaw_penalty_speed_gate,
                    1e-6,
                )
                stall_penalty += (
                    0.5 * self.p.reward_stall_penalty * float(np.clip(low_speed_ratio, 0.0, 1.0))
                )

            reward -= stall_penalty

        soft_margin = _soft_clearance_margin(self.p)
        if clearance < soft_margin and soft_margin > 1e-6:
            reward -= (
                self.p.reward_proximity_scale * (soft_margin - clearance) / soft_margin
            )  # penalty for getting too close to obstacles
        if shield_active:
            reward -= self.p.reward_shield_penalty
        if collision:
            reward -= self.p.reward_collision_penalty
        if success:
            reward += self.p.reward_success_bonus

        terminated = bool(collision or success)
        self._step_count += 1
        self._t += self.p.dt

        # Time-limit truncation; only penalized when the episode did not also terminate.
        truncated = bool(self._t >= self.p.episode_seconds)
        if truncated and not terminated:
            reward -= self.p.reward_truncation_penalty

        obs = self._pack_obs(
            lidar, dist, np.float32(applied_v), np.float32(applied_wz)
        )  # purposefully not including vz
        info = self._get_info(
            shield_active=shield_active,
            collision=collision,
            success=success,
            shield_delta=shield_delta,
            accel_clipped=accel_clipped,
            clearance=clearance,
            drone_radius=drone_radius,
        )
        return obs, float(reward), terminated, truncated, info

    def render(self):
        """Return a compact textual snapshot of the current environment state."""
        return f"t={self._t:2f} pos = {self.pos} yaw = {float(self.yaw):.2f} goal = {self.goal} v = {float(self.v):.2f}"

    def close(self):
        """Release environment resources if cleanup is required."""
        pass

    ### HELPER FUNCTIONS ###

    def _pack_obs(
        self, lidar: np.ndarray, dist: float, v: np.float32, wz: np.float32
    ) -> np.ndarray:
        # normalize lidar
        lidar_n = np.clip(lidar / self.p.lidar_range_max, 0.0, 1.0).astype(np.float32)

        # goal direction in body frame
        dx = float(self.goal[0] - self.pos[0])
        dy = float(self.goal[1] - self.pos[1])
        theta = np.arctan2(dy, dx) - float(self.yaw)
        c, s = np.cos(theta), np.sin(theta)

        dist_n = np.clip(dist / (2.0 * self._world_half_extent), 0.0, 1.0)
        v_n = np.clip(v / self.p.v_max, -1.0, 1.0)
        wz_n = np.clip(wz / self.p.wz_max, -1.0, 1.0)
        z_err = np.clip(
            (float(self.z_target) - float(self.pos[2])) / self.p.z_error_scale, -1.0, 1.0
        )

        tail = np.array([c, s, dist_n, v_n, wz_n, z_err], dtype=np.float32)
        obs = np.concatenate([lidar_n, tail], axis=0).astype(np.float32)
        return obs

    def _get_obs(self):
        lidar = self._lidar_scan()
        dist = self._dist_to_goal()
        obs = self._pack_obs(lidar, dist, self.v, self.wz)
        return obs

    def _get_info(self, **kwargs) -> dict[str, Any]:
        # Use cached values from step() to avoid recomputation
        min_range = (
            self._cached_min_range
            if self._cached_lidar is not None
            else float(np.min(self._lidar_scan()))
        )
        info = {
            "dist_to_goal": float(self._dist_to_goal()),
            "min_range": min_range,
            "tree_count": int(0 if self.trees is None else len(self.trees)),
            "worldgen_seed": int(self._last_worldgen_seed)
            if self._last_worldgen_seed is not None
            else None,
            "worldgen_layout": (
                str(self._last_worldgen_selection.get("layout_ref"))
                if isinstance(self._last_worldgen_selection, dict)
                and self._last_worldgen_selection.get("layout_ref")
                else None
            ),
            "worldgen_distribution_refs": (
                [
                    str(v)
                    for v in self._last_worldgen_selection.get("selected_distribution_refs", [])
                ]
                if isinstance(self._last_worldgen_selection, dict)
                else []
            ),
        }
        info.update(kwargs)
        info["is_success"] = bool(info.get("success", False))
        return info

    def _dist_to_goal(self) -> float:
        # for now just use euclidean distance in xy plane ignoring z. I can add that in later if needed but it might not be super helpful
        return float(np.linalg.norm(self.goal[:2] - self.pos[:2]))

    def _apply_shield(self, v: float, wz: float, vz: float):
        """Velocity-barrier safety shield.

        Computes safe velocity targets from clearance constraints against
        nearby trees, the virtual world boundary, and the vertical floor and
        ceiling limits. The targets are later passed through the
        acceleration/deceleration slew limits before being applied to the
        simulated state.

        Args:
            v: Commanded forward speed (signed, along the body x-axis).
            wz: Commanded yaw rate.
            vz: Commanded vertical speed.

        Returns:
            ``(safe_v, safe_wz, safe_vz, shield_active, shield_delta_norm)``:
            the three targets as ``np.float32``, an int flag that is 1 when any
            command was modified by a guard, and the summed absolute change
            divided by the summed absolute command (0.0 for a near-zero
            command).

        NOTE: Open question from the original author -- this might result in
        the drone just getting stuck.
        """
        safe_v = float(v)
        safe_wz = float(wz)
        safe_vz = float(vz)
        shield_active = 0
        # Fall back to the acceleration limit when no deceleration limit is set.
        v_decel_cap = (
            float(self.p.decel_v_max)
            if float(self.p.decel_v_max) > 0.0
            else float(self.p.accel_v_max)
        )
        # Depends on the parameters only, so compute it once for all guards.
        protected_radius = _protected_radius(self.p)

        # horizontal tree avoidance (query nearby trees only)
        if self.trees is not None and len(self.trees) > 0:
            vel_dir = np.array(
                [np.cos(float(self.yaw)), np.sin(float(self.yaw))],
                dtype=np.float64,
            )
            pos_xy = self.pos[:2].astype(np.float64)

            # dynamic lookahead: one-step motion + stopping distance + safety bubble.
            # Braking distance depends on |v| only, so it is included for
            # reverse motion as well; otherwise trees behind the vehicle would
            # enter the query too late to brake for.
            stopping_distance = (safe_v * safe_v) / max(2.0 * v_decel_cap, 1e-6)
            lookahead = (
                abs(safe_v) * self.p.dt
                + stopping_distance
                + protected_radius
                + self.p.shield_lookahead_margin
            )
            nearby_idx = self._query_nearby_trees(float(self.pos[0]), float(self.pos[1]), lookahead)

            for idx in nearby_idx:
                tree = self.trees[idx]
                t_xy = tree[:2].astype(np.float64)
                t_r = float(tree[2])

                delta = t_xy - pos_xy  # vector from UAV to tree centre
                d = float(np.linalg.norm(delta))
                if d < 1e-6:
                    # degenerate: UAV on top of tree centre -> full stop
                    safe_v = 0.0
                    shield_active = 1
                    continue

                gap = d - t_r - protected_radius  # remaining clearance
                c = float(np.dot(vel_dir, delta / d))  # cos(heading, tree dir)
                approach = safe_v * c  # > 0 when closing distance

                if gap <= 0.0:
                    # already inside safety bubble, block any further approach
                    if approach > 0.0:
                        safe_v = 0.0
                        shield_active = 1
                elif approach > 0.0:
                    # positive gap but approaching, cap approach speed with braking room
                    v_approach_cap = _approach_speed_cap(gap, v_decel_cap, self.p.dt)
                    # Convert the cap on the closing component into a cap on |v|.
                    v_max_safe = v_approach_cap / max(abs(c), 1e-6)
                    if abs(safe_v) > v_max_safe:
                        safe_v = float(np.sign(safe_v)) * v_max_safe
                        shield_active = 1

        # world boundary avoidance (virtual walls)
        if abs(safe_v) > 1e-6:
            dist_to_wall = self._distance_to_world_boundary_along_motion(safe_v)
            gap = dist_to_wall - protected_radius

            if gap <= 0.0:
                safe_v = 0.0
                shield_active = 1
            else:
                v_max_safe = _approach_speed_cap(gap, v_decel_cap, self.p.dt)
                if abs(safe_v) > v_max_safe:
                    safe_v = float(np.sign(safe_v)) * v_max_safe
                    shield_active = 1

        # vertical floor guard
        if float(self.pos[2]) <= self.p.shield_floor_z_min and safe_vz < 0.0:
            safe_vz = 0.0
            shield_active = 1

        # vertical ceiling guard (disabled when the ceiling parameter is <= 0)
        if (
            float(self.p.shield_ceiling_z_max) > 0.0
            and float(self.pos[2]) >= float(self.p.shield_ceiling_z_max)
            and safe_vz > 0.0
        ):
            safe_vz = 0.0
            shield_active = 1

        # Damp yaw while the shield is intervening and the vehicle is still moving.
        if shield_active and abs(safe_wz) > 0.0 and abs(safe_v) > 1e-3:
            damping = float(np.clip(self.p.shield_yaw_damping, 0.0, 1.0))
            safe_wz = safe_wz * (1.0 - damping)

        # normalised intervention magnitude
        cmd_norm = abs(v) + abs(wz) + abs(vz)
        delta_norm = abs(v - safe_v) + abs(wz - safe_wz) + abs(vz - safe_vz)
        shield_delta = delta_norm / cmd_norm if cmd_norm > 1e-6 else 0.0

        return (
            np.float32(safe_v),
            np.float32(safe_wz),
            np.float32(safe_vz),
            int(shield_active),
            float(shield_delta),
        )

    def _integrate(self, v: np.float32, wz: np.float32, vz: np.float32):
        # simple kinematics
        self.yaw = np.float32(float(self.yaw) + wz * self.p.dt)
        self.pos[0] = np.float32(float(self.pos[0]) + v * np.cos(float(self.yaw)) * self.p.dt)
        self.pos[1] = np.float32(float(self.pos[1]) + v * np.sin(float(self.yaw)) * self.p.dt)
        self.pos[2] = np.float32(float(self.pos[2]) + vz * self.p.dt)

    def _lidar_scan(self) -> np.ndarray:
        """Return per-beam ranges against circular tree cross-sections in the xy plane.

        Returns:
            A float32 array with ``p.lidar_num_beams`` entries, each in
            ``[0, p.lidar_range_max]``. Beams that hit nothing report the
            maximum range; all beams report zero when the vehicle centre lies
            inside a trunk.
        """
        # raycast against circular tree cross-sections in xy plane
        n_beams = self.p.lidar_num_beams
        max_range = float(self.p.lidar_range_max)
        ranges = np.full((n_beams,), max_range, dtype=np.float64)

        if self.trees is None or len(self.trees) == 0:
            return ranges.astype(np.float32)

        px = float(self.pos[0])
        py = float(self.pos[1])
        yaw = float(self.yaw)

        # Rotate the precomputed body-frame beam directions into the world frame.
        cos_yaw = float(np.cos(yaw))
        sin_yaw = float(np.sin(yaw))
        dir_x = cos_yaw * self._beam_rel_cos - sin_yaw * self._beam_rel_sin
        dir_y = sin_yaw * self._beam_rel_cos + cos_yaw * self._beam_rel_sin

        # query nearby trees only
        nearby_idx = self._query_nearby_trees(px, py, max_range)

        for idx in nearby_idx:
            tree = self.trees[idx]
            tx = float(tree[0])
            ty = float(tree[1])
            radius = float(tree[2])

            dx = tx - px
            dy = ty - py
            d2 = dx * dx + dy * dy
            r2 = radius * radius

            # If UAV center is inside a trunk, every beam collides at zero range
            if d2 <= r2:
                return np.zeros((n_beams,), dtype=np.float32)

            # Skip trees whose nearest surface point is beyond sensor range.
            center_dist = float(np.sqrt(d2))
            if center_dist - radius > max_range:
                continue

            # Bearing of the tree centre relative to the heading, wrapped to [-pi, pi].
            tree_rel = float(np.arctan2(dy, dx) - yaw)
            tree_rel = float(np.arctan2(np.sin(tree_rel), np.cos(tree_rel)))
            # Half of the angle subtended by the trunk as seen from the vehicle.
            half_span = float(np.arcsin(min(0.999999, radius / center_dist)))

            # Only beams whose index falls inside the subtended window are tested.
            center_float = (tree_rel - self._beam_angle_start) / self._beam_angle_step
            center_idx = int(np.round(center_float)) % n_beams
            half_beams = int(np.ceil(half_span / self._beam_angle_step))

            if half_beams >= n_beams // 2:
                # Window covers the full sweep; test every beam.
                candidate_idx = np.arange(n_beams, dtype=np.int64)
            else:
                offsets = np.arange(-half_beams, half_beams + 1, dtype=np.int64)
                candidate_idx = (center_idx + offsets) % n_beams

            cdx = dir_x[candidate_idx]
            cdy = dir_y[candidate_idx]

            # Ray-circle intersection from origin p + t * d, with |d|=1 and t>=0
            proj = dx * cdx + dy * cdy  # distance along the ray to the closest approach
            forward = proj > 0.0
            if not np.any(forward):
                continue

            perp2 = d2 - proj * proj  # squared perpendicular distance from centre to ray
            hit = forward & (perp2 <= r2)
            if not np.any(hit):
                continue

            proj_hit = proj[hit]
            perp2_hit = perp2[hit]
            chord = np.sqrt(np.maximum(0.0, r2 - perp2_hit))
            t_near = np.maximum(0.0, proj_hit - chord)  # entry point of the ray into the circle

            # Keep the closest hit per beam across all trees.
            hit_idx = candidate_idx[hit]
            ranges[hit_idx] = np.minimum(ranges[hit_idx], t_near)

        np.clip(ranges, 0.0, max_range, out=ranges)
        return ranges.astype(np.float32)

    def _sample_forest(
        self,
        exclusion_centers: Optional[np.ndarray] = None,
        exclusion_radius: float = 0.0,
    ):
        """Sample tree positions/radii from worldgen and apply optional exclusions."""
        # Try package import first; if this file is run directly, ensure project paths
        # are available so worldgen can still be imported.
        if self._worldgen_generate_positions_fn is None:
            module = None
            for module_name in (
                "worldgen.forest_worldgen.generate_world",
                "forest_worldgen.generate_world",
            ):
                try:
                    module = importlib.import_module(module_name)
                    break
                except ImportError:
                    continue

            if module is None:
                project_root = self._project_root()
                candidate_paths = (project_root, project_root / "worldgen")
                for path in candidate_paths:
                    path_str = str(path)
                    if path_str not in sys.path:
                        sys.path.insert(0, path_str)

                for module_name in (
                    "worldgen.forest_worldgen.generate_world",
                    "forest_worldgen.generate_world",
                ):
                    try:
                        module = importlib.import_module(module_name)
                        break
                    except ImportError:
                        continue

            if module is None:
                raise ImportError(
                    "Could not import world generator module. Tried "
                    "'worldgen.forest_worldgen.generate_world' and "
                    "'forest_worldgen.generate_world'."
                )

            self._worldgen_generate_positions_fn = getattr(module, "generate_positions_from_config")

        generate_positions_from_config = self._worldgen_generate_positions_fn

        config_path = self._worldgen_config_path()
        if not config_path.exists():
            raise FileNotFoundError(f"worldgen config not found: {config_path}")

        episode_seed = int(self.np_random.integers(0, np.iinfo(np.int32).max))
        episode_seed += int(self.p.worldgen_seed_offset)
        self._last_worldgen_seed = episode_seed

        positions_xy, world_config, _, selection_meta = generate_positions_from_config(
            str(config_path),
            seed=episode_seed,
            verbose=bool(self.p.worldgen_verbose),
            return_selection_meta=True,
        )
        self._last_worldgen_selection = selection_meta if isinstance(selection_meta, dict) else None

        area_size = float(world_config["generation"]["area_size"])
        self._world_half_extent = area_size / 2.0

        points_xy = np.asarray(positions_xy, dtype=np.float32)
        if points_xy.size == 0:
            return np.zeros((0, 3), dtype=np.float32)

        radii = self.np_random.normal(
            loc=self.p.tree_radius_mean,
            scale=max(self.p.tree_radius_std, 1e-6),
            size=(points_xy.shape[0],),
        ).astype(np.float32)
        radii = np.clip(radii, self.p.tree_radius_min, self.p.tree_radius_max)

        if exclusion_centers is not None and len(exclusion_centers) > 0 and exclusion_radius > 0.0:
            centers = np.asarray(exclusion_centers, dtype=np.float32)
            if centers.ndim == 1:
                centers = centers.reshape(1, 2)

            diff = points_xy[:, None, :] - centers[None, :, :]
            dist = np.linalg.norm(diff, axis=2)
            required = radii[:, None] + float(exclusion_radius)
            keep_mask = np.all(dist >= required, axis=1)

            points_xy = points_xy[keep_mask]
            radii = radii[keep_mask]

            if points_xy.size == 0:
                return np.zeros((0, 3), dtype=np.float32)

        return np.column_stack([points_xy, radii]).astype(np.float32)

    def _sample_start_goal_anchors(self) -> tuple[np.ndarray, np.ndarray]:
        """Sample stochastic start/goal anchor points from opposite boundary bands (tree-agnostic)."""
        half = float(self._effective_world_half_extent())
        band_inner = 0.55 * half

        def _uniform_between(a: float, b: float) -> float:
            lo = min(float(a), float(b))
            hi = max(float(a), float(b))
            return float(self.np_random.uniform(lo, hi))

        for _ in range(self.p.spawn_max_attempts):
            axis = int(self.np_random.integers(0, 2))
            start_side = -1.0 if bool(self.np_random.integers(0, 2)) else 1.0
            goal_side = -start_side

            if axis == 0:
                start_x = _uniform_between(start_side * band_inner, start_side * half)
                start_y = float(self.np_random.uniform(-half, half))
                goal_x = _uniform_between(goal_side * band_inner, goal_side * half)
                goal_y = float(self.np_random.uniform(-half, half))
            else:
                start_x = float(self.np_random.uniform(-half, half))
                start_y = _uniform_between(start_side * band_inner, start_side * half)
                goal_x = float(self.np_random.uniform(-half, half))
                goal_y = _uniform_between(goal_side * band_inner, goal_side * half)

            start_xy = np.array([start_x, start_y], dtype=np.float32)
            goal_xy = np.array([goal_x, goal_y], dtype=np.float32)

            if np.linalg.norm(goal_xy - start_xy) >= self.p.min_start_goal_distance:
                return start_xy, goal_xy

        fallback_start = np.array([band_inner, 0.0], dtype=np.float32)
        fallback_goal = np.array([-band_inner, 0.0], dtype=np.float32)
        return fallback_start, fallback_goal

    def _point_clear_of_trees(self, x: float, y: float, clearance: float) -> bool:
        if self.trees is None or len(self.trees) == 0:
            return True
        dxy = self.trees[:, :2] - np.array([x, y], dtype=np.float32)
        d = np.linalg.norm(dxy, axis=1)
        needed = self.trees[:, 2] + clearance
        return bool(np.all(d >= needed))

    def _sample_free_xy(
        self,
        clearance: float,
        x_bounds: Optional[Tuple[float, float]] = None,
        y_bounds: Optional[Tuple[float, float]] = None,
    ) -> np.ndarray:
        """Sample a tree-free XY point within optional bounds and clearance."""
        if x_bounds is None:
            half = self._effective_world_half_extent()
            x_bounds = (-half, half)
        if y_bounds is None:
            half = self._effective_world_half_extent()
            y_bounds = (-half, half)

        effective_half = self._effective_world_half_extent()
        x_lo = float(max(-effective_half, min(x_bounds[0], x_bounds[1])))
        x_hi = float(min(effective_half, max(x_bounds[0], x_bounds[1])))
        y_lo = float(max(-effective_half, min(y_bounds[0], y_bounds[1])))
        y_hi = float(min(effective_half, max(y_bounds[0], y_bounds[1])))

        if x_lo >= x_hi or y_lo >= y_hi:
            return np.array([0.0, 0.0], dtype=np.float32)

        for _ in range(self.p.spawn_max_attempts):
            x = float(self.np_random.uniform(x_lo, x_hi))
            y = float(self.np_random.uniform(y_lo, y_hi))
            if self._point_clear_of_trees(x, y, clearance):
                return np.array([x, y], dtype=np.float32)

        return np.array([0.0, 0.0], dtype=np.float32)

    def _sample_start_pose(self, preferred_xy: Optional[np.ndarray] = None):
        """Sample a collision-free start pose near boundary bands."""
        if preferred_xy is not None:
            preferred = np.asarray(preferred_xy, dtype=np.float32)
            if preferred.shape[0] >= 2:
                if self._point_clear_of_trees(
                    float(preferred[0]), float(preferred[1]), self.p.start_goal_clearance
                ):
                    yaw = np.float32(self.np_random.uniform(-np.pi, np.pi))
                    return np.array(
                        [preferred[0], preferred[1], self.p.default_z_target], dtype=np.float32
                    ), yaw

        half = float(self._effective_world_half_extent())
        band_inner = 0.55 * half

        for _ in range(self.p.spawn_max_attempts):
            axis = int(self.np_random.integers(0, 2))
            side = -1.0 if bool(self.np_random.integers(0, 2)) else 1.0

            if axis == 0:
                x_bounds = (side * band_inner, side * half)
                y_bounds = (-half, half)
            else:
                x_bounds = (-half, half)
                y_bounds = (side * band_inner, side * half)

            start_xy = self._sample_free_xy(
                clearance=self.p.start_goal_clearance,
                x_bounds=x_bounds,
                y_bounds=y_bounds,
            )
            if not np.allclose(start_xy, 0.0):
                yaw = np.float32(self.np_random.uniform(-np.pi, np.pi))
                return np.array(
                    [start_xy[0], start_xy[1], self.p.default_z_target], dtype=np.float32
                ), yaw

        start_xy = self._sample_free_xy(clearance=self.p.start_goal_clearance)
        yaw = np.float32(self.np_random.uniform(-np.pi, np.pi))
        return np.array([start_xy[0], start_xy[1], self.p.default_z_target], dtype=np.float32), yaw

    def _sample_goal_pose(self, preferred_xy: Optional[np.ndarray] = None):
        """Sample a collision-free goal pose opposite from the start side."""
        if preferred_xy is not None:
            preferred = np.asarray(preferred_xy, dtype=np.float32)
            if preferred.shape[0] >= 2:
                if self._point_clear_of_trees(
                    float(preferred[0]), float(preferred[1]), self.p.start_goal_clearance
                ):
                    return np.array(
                        [preferred[0], preferred[1], self.p.default_z_target], dtype=np.float32
                    )

        half = float(self._effective_world_half_extent())
        band_inner = 0.55 * half

        start_xy = self.pos[:2].astype(np.float32)
        if abs(float(start_xy[0])) >= abs(float(start_xy[1])):
            axis = 0
            start_side = 1.0 if float(start_xy[0]) >= 0.0 else -1.0
        else:
            axis = 1
            start_side = 1.0 if float(start_xy[1]) >= 0.0 else -1.0

        goal_side = -start_side

        for _ in range(self.p.spawn_max_attempts):
            if axis == 0:
                x_bounds = (goal_side * band_inner, goal_side * half)
                y_bounds = (-half, half)
            else:
                x_bounds = (-half, half)
                y_bounds = (goal_side * band_inner, goal_side * half)

            goal_xy = self._sample_free_xy(
                clearance=self.p.start_goal_clearance,
                x_bounds=x_bounds,
                y_bounds=y_bounds,
            )
            if np.allclose(goal_xy, 0.0):
                continue

            if np.linalg.norm(goal_xy - self.pos[:2]) >= self.p.min_start_goal_distance:
                return np.array([goal_xy[0], goal_xy[1], self.p.default_z_target], dtype=np.float32)

        fallback = np.array(
            [
                -np.sign(float(start_xy[0])) * band_inner if axis == 0 else 0.0,
                -np.sign(float(start_xy[1])) * band_inner if axis == 1 else 0.0,
            ],
            dtype=np.float32,
        )
        return np.array([fallback[0], fallback[1], self.p.default_z_target], dtype=np.float32)


if __name__ == "__main__":
    # Hand-written minimal parameters for a quick self-check only; for real
    # usage, load them from sac.yaml via utils.build_env_params.

    # Time step, sensing, velocity limits, and episode/world geometry.
    core_kwargs = dict(
        dt=0.1,
        lidar_num_beams=180,
        lidar_range_max=30.0,
        v_max=6.0,
        wz_max=2.5,
        vz_max=2.0,
        r_safe=0.6,
        episode_seconds=30.0,
        goal_tolerance=0.5,
        world_radius=20.0,
        collision_threshold=0.05,
        default_z_target=2.0,
        z_error_scale=5.0,
    )

    # Reward terms and their gating thresholds.
    reward_kwargs = dict(
        reward_progress_scale=2.0,
        reward_speed_scale=0.02,
        reward_step_penalty=0.02,
        reward_proximity_scale=0.2,
        reward_shield_penalty=0.02,
        reward_collision_penalty=20.0,
        reward_success_bonus=5.0,
        reward_truncation_penalty=1.0,
        reward_yaw_rate_scale=0.03,
        reward_stall_penalty=0.05,
        progress_stall_threshold=0.02,
        yaw_penalty_speed_gate=0.25,
    )

    # Safety-shield settings.
    shield_kwargs = dict(
        shield_floor_z_min=0.05,
        shield_yaw_damping=0.35,
        shield_lookahead_margin=1.0,
    )

    # World generation, tree sizes, and start/goal spawning.
    world_kwargs = dict(
        worldgen_config_relpath="configs/worldgen/worldgen_run.yaml",
        worldgen_seed_offset=0,
        tree_radius_mean=0.25,
        tree_radius_std=0.05,
        tree_radius_min=0.10,
        tree_radius_max=0.60,
        start_goal_clearance=1.0,
        min_start_goal_distance=8.0,
        spawn_max_attempts=500,
    )

    params = SimParams(**core_kwargs, **reward_kwargs, **shield_kwargs, **world_kwargs)
    env = ForestNavEnv(params)

    # Validate the environment against the Stable-Baselines3 checker.
    check_env(env, warn=True)
