yincheng.zhong
2025-11-24 346cc7d685283df529aadbcf9c156de040ce44f9
python/hitl/simulator.py
@@ -5,10 +5,14 @@
from __future__ import annotations
import dataclasses
import heapq
import itertools
import math
import queue
import threading
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable
import serial
@@ -16,12 +20,100 @@
from . import geo
from .dynamics import DifferentialDriveModel, DifferentialDriveState
from .protocols import (
    ControlStatus,
    PoseStatus,
    PythonLinkDecoder,
    PythonLinkFrame,
    StackStatus,
    StateStatus,
    build_gpimu_sentence,
    build_gprmi_sentence,
    decode_control_status,
    decode_pose_status,
    decode_stack_status,
    decode_state_status,
    parse_ascii_message,
)
# Start of GPS time: 1980-01-06 00:00:00 UTC (timezone-aware).
GPS_EPOCH = datetime(1980, 1, 6, tzinfo=timezone.utc)
class RunLogger:
    """Asynchronous run logger that writes entries in timestamp order.

    ``log()`` only enqueues. A daemon worker thread moves entries into a
    min-heap and writes each one after it has been buffered for at least
    ``flush_delay_ms``, so lines produced by different threads can be
    re-ordered before they reach the file.

    Sort order:
    - Entries that carry a GPS time come first, ordered by GPS time.
    - Entries without one fall back to the host wall-clock time and sort
      after every GPS-stamped entry.
    - Ties are broken by ``source_rank``, then by submission order.

    Call ``close()`` to stop the worker and write out everything still pending.
    """

    def __init__(self, path: Path | str, flush_delay_ms: float = 500.0):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self.path.open("w", encoding="utf-8")
        self._queue: queue.Queue = queue.Queue()
        # Heap entries: (sort key, submission seq, monotonic arrival, prefix, content).
        # The sort key is (priority, timestamp_s, source_rank) with priority
        # 0 = GPS time and 1 = host wall-clock time.
        self._buffer: list[
            tuple[tuple[int, float, int], int, float, str, str]
        ] = []
        self._stop = threading.Event()
        self._counter = itertools.count()
        self._flush_delay = flush_delay_ms / 1000.0
        self._worker = threading.Thread(target=self._loop, name="hitl-runlog", daemon=True)
        self._worker.start()

    def log(
        self,
        prefix: str,
        content: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Queue one line ``prefix: content`` for writing.

        Args:
            prefix: Label written before the colon.
            content: Line payload.
            gps_time_s: Timestamp in seconds used for ordering. When ``None``,
                the host wall-clock time (``time.time()``) is used instead and
                the entry sorts after all GPS-stamped entries.
            source_rank: Tie-breaker among entries with equal timestamps
                (lower sorts first).
        """
        if gps_time_s is None:
            key = (1, float(time.time()), source_rank)
        else:
            key = (0, float(gps_time_s), source_rank)
        # The flush delay is measured on the monotonic clock, so wall-clock
        # adjustments cannot hold entries back or release them early.
        self._queue.put((key, next(self._counter), time.monotonic(), prefix, content))

    def close(self):
        """Stop the worker, write every pending entry and close the file.

        This also writes entries that were queued after the worker's final
        pass. Calling it more than once is harmless.
        """
        self._stop.set()
        self._worker.join()
        if self._file.closed:
            return
        try:
            # force=True also drains anything still sitting in the queue.
            self._flush_buffer(force=True)
        finally:
            self._file.close()

    def _loop(self):
        """Worker thread body: pull queued entries into the heap and flush ripe ones."""
        while not self._stop.is_set() or not self._queue.empty() or self._buffer:
            try:
                item = self._queue.get(timeout=0.05)
            except queue.Empty:
                pass
            else:
                heapq.heappush(self._buffer, item)
                # Pull in the whole backlog so a single flush sees everything
                # queued so far (keeps shutdown output in heap order).
                self._drain_queue()
            self._flush_buffer()

    def _drain_queue(self):
        """Move every currently queued entry into the heap without blocking."""
        while True:
            try:
                item = self._queue.get_nowait()
            except queue.Empty:
                return
            heapq.heappush(self._buffer, item)

    def _flush_buffer(self, force: bool = False):
        """Write buffered entries in heap order.

        Normally this stops at the first (smallest-key) entry that has been
        buffered for less than the flush delay. When ``force`` is set or the
        logger is stopping, the queue is drained into the heap first and
        every entry is written.
        """
        draining = force or self._stop.is_set()
        if draining:
            self._drain_queue()
        now = time.monotonic()
        wrote = False
        while self._buffer:
            arrival = self._buffer[0][2]
            if not draining and (now - arrival) < self._flush_delay:
                break
            key, _seq, _arrival, prefix, content = heapq.heappop(self._buffer)
            self._write_line(key, prefix, content)
            wrote = True
        if wrote:
            # One flush per batch instead of one per line.
            self._file.flush()

    def _write_line(
        self, key: tuple[int, float, int], prefix: str, content: str
    ):
        """Format one entry as ``[GPS=...s r=N] prefix: content`` (or ``HOST=``) and write it."""
        priority, timestamp_s, source_rank = key
        if priority == 0:
            ts_label = f"GPS={timestamp_s:09.3f}s"
        else:
            ts_label = f"HOST={timestamp_s:012.3f}s"
        line = f"[{ts_label} r={source_rank}] {prefix}: {content}\n"
        self._file.write(line)
@dataclasses.dataclass
class HitlConfig:
@@ -32,7 +124,7 @@
    origin_gga: str
    initial_enu: tuple[float, float, float] = (0.0, 0.0, 0.0)
    initial_heading_deg: float = 0.0
    gps_baudrate: int = 460800
    gps_baudrate: int = 115200
    log_baudrate: int = 921600
    track_width: float = 0.78
    baseline_distance: float = 0.9
@@ -118,8 +210,12 @@
    _threads: list[threading.Thread]
    on_control: Callable[[float, float], None] | None
    on_log: Callable[[str], None] | None
    on_control_status: Callable[[ControlStatus], None] | None
    on_pose_status: Callable[[PoseStatus], None] | None
    on_state_status: Callable[[StateStatus], None] | None
    on_stack_status: Callable[[StackStatus], None] | None
    def __init__(self, config: HitlConfig):
    def __init__(self, config: HitlConfig, run_logger: RunLogger | None = None):
        self.config = config
        self.origin = geo.parse_origin(config.origin_gga)
        self.model = DifferentialDriveModel(
@@ -145,9 +241,15 @@
        self._decoder = PythonLinkDecoder(self._handle_control_frame)
        self._running = threading.Event()
        self._threads = []
        self.run_logger = run_logger
        self.on_control = None
        self.on_log = None
        self.on_control_status = None
        self.on_pose_status = None
        self.on_state_status = None
        self.on_stack_status = None
        self.on_ascii = None
    # ------------------------------------------------------------------ #
    # 生命周期
@@ -197,24 +299,29 @@
        period = 0.1  # 10 Hz
        while self._running.is_set():
            start = time.perf_counter()
            sentence = self._build_gprmi_sentence()
            sentence, gps_time_s = self._build_gprmi_sentence()
            if sentence:
                self.uart2.write(sentence)
                self._log_ascii("PY->STM32 UART2 GPFMI", sentence, gps_time_s=gps_time_s, source_rank=0)
            self._sleep_remaining(start, period)
    def _loop_gpimu(self):
        period = 0.01  # 100 Hz
        while self._running.is_set():
            start = time.perf_counter()
            sentence = self._build_gpimu_sentence()
            sentence, gps_time_s = self._build_gpimu_sentence()
            if sentence:
                self.uart2.write(sentence)
                self._log_ascii("PY->STM32 UART2 GPIMU", sentence, gps_time_s=gps_time_s, source_rank=0)
            self._sleep_remaining(start, period)
    def _loop_control(self):
        while self._running.is_set():
            data = self.uart2.read(128)
            if data:
                self._log_binary(
                    "STM32->PY UART2 CTRL_RAW", data, source_rank=1
                )
                self._decoder.feed(data)
            else:
                time.sleep(0.002)
@@ -230,37 +337,87 @@
                time.sleep(0.01)
                continue
            text = line.decode("utf-8", errors="replace").strip()
            if text and self.on_log:
            if not text:
                continue
            handled = False
            msg = parse_ascii_message(text)
            if msg:
                ctrl = decode_control_status(msg)
                if ctrl and self.on_control_status:
                    self.on_control_status(ctrl)
                    self._log_text(
                        "STM32 UART5 CTRL",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(ctrl.timestamp_ms),
                        source_rank=1,
                    )
                    self._apply_ascii_control(ctrl)
                    handled = True
                pose = decode_pose_status(msg)
                if pose and self.on_pose_status:
                    self.on_pose_status(pose)
                    self._log_text(
                        "STM32 UART5 POSE",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(pose.timestamp_ms),
                        source_rank=1,
                    )
                    handled = True
                state = decode_state_status(msg)
                if state and self.on_state_status:
                    self.on_state_status(state)
                    self._log_text(
                        "STM32 UART5 STATE",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(state.timestamp_ms),
                        source_rank=1,
                    )
                    handled = True
                stack = decode_stack_status(msg)
                if stack and self.on_stack_status:
                    self.on_stack_status(stack)
                    handled = True
            if not handled and self.on_log:
                self.on_log(text)
            if not handled:
                self._log_text("STM32 UART5", text, source_rank=1)
    # ------------------------------------------------------------------ #
    # 构造帧
    # ------------------------------------------------------------------ #
    def _build_gprmi_sentence(self) -> bytes | None:
    def _build_gprmi_sentence(self) -> tuple[bytes | None, float]:
        state, timestamp = self._snapshot()
        lat, lon, alt = geo.enu_to_lla(state.east, state.north, state.up, self.origin)
        heading_nav = geo.heading_math_to_nav(state.heading)
        return build_gprmi_sentence(
            timestamp=timestamp,
            lat_deg=lat,
            lon_deg=lon,
            alt_m=alt,
            east_vel=state.east_velocity,
            north_vel=state.north_velocity,
            up_vel=state.up_velocity,
            heading_deg=heading_nav,
            pitch_deg=state.pitch_deg,
            roll_deg=state.roll_deg,
            baseline_m=self.config.baseline_distance,
        gps_time_s = _seconds_of_day(timestamp)
        return (
            build_gprmi_sentence(
                timestamp=timestamp,
                lat_deg=lat,
                lon_deg=lon,
                alt_m=alt,
                east_vel=state.east_velocity,
                north_vel=state.north_velocity,
                up_vel=state.up_velocity,
                heading_deg=heading_nav,
                pitch_deg=state.pitch_deg,
                roll_deg=state.roll_deg,
                baseline_m=self.config.baseline_distance,
            ),
            gps_time_s,
        )
    def _build_gpimu_sentence(self) -> bytes | None:
    def _build_gpimu_sentence(self) -> tuple[bytes | None, float]:
        state, timestamp = self._snapshot()
        return build_gpimu_sentence(
            timestamp=timestamp,
            accel_g=state.body_accel_g,
            gyro_deg_s=state.gyro_deg_s,
            temperature_c=state.temperature_c,
        gps_time_s = _seconds_of_day(timestamp)
        return (
            build_gpimu_sentence(
                timestamp=timestamp,
                accel_g=state.body_accel_g,
                gyro_deg_s=state.gyro_deg_s,
                temperature_c=state.temperature_c,
            ),
            gps_time_s,
        )
    # ------------------------------------------------------------------ #
@@ -276,6 +433,11 @@
            self._target_angular = frame.turn
        if self.on_control:
            self.on_control(frame.forward, frame.turn)
        if self.run_logger:
            self.run_logger.log(
                "STM32->PY CTRL_FRAME",
                f"forward={frame.forward:.3f} turn={frame.turn:.3f} pwm={frame.steering_pwm}/{frame.throttle_pwm}",
            )
    @staticmethod
    def _sleep_remaining(start: float, period: float):
@@ -284,6 +446,108 @@
        if remaining > 0:
            time.sleep(remaining)
    # ------------------------------------------------------------------ #
    # 外部控制
    # ------------------------------------------------------------------ #
    def update_origin(self, origin_gga: str):
        if not origin_gga:
            return
        try:
            new_origin = geo.parse_origin(origin_gga)
        except ValueError:
            return
        self.config.origin_gga = origin_gga
        self.origin = new_origin
        self._sim_time = _initial_timestamp_from_gga(origin_gga)
    def reset_state(self, east: float, north: float, up: float, heading_deg: float):
        with self._state_lock:
            self.model.reset(east=east, north=north, up=up, heading_deg=heading_deg)
            self._latest_state = self.model.state.copy()
            self._target_linear = 0.0
            self._target_angular = 0.0
    # ------------------------------------------------------------------ #
    # 日志工具
    # ------------------------------------------------------------------ #
    def _apply_ascii_control(self, ctrl: ControlStatus):
        with self._state_lock:
            self._target_linear = ctrl.forward_mps
            self._target_angular = ctrl.turn_rate
        if self.on_control:
            self.on_control(ctrl.forward_mps, ctrl.turn_rate)
    def _log_ascii(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger or not payload:
            return
        text = payload.decode("utf-8", errors="replace").strip()
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)
    def _log_binary(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger or not payload:
            return
        self.run_logger.log(
            prefix, payload.hex(), gps_time_s=gps_time_s, source_rank=source_rank
        )
    def _log_text(
        self,
        prefix: str,
        text: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger:
            return
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)
def _ensure_utc_datetime(dt: datetime) -> datetime:
    """Return *dt* as a timezone-aware UTC datetime.

    Naive values are assumed to already be in UTC and are only tagged with
    ``timezone.utc``. Aware values are converted with ``astimezone``.
    """
    return dt.astimezone(timezone.utc) if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
def _seconds_of_day(dt: datetime) -> float:
    """Return the seconds elapsed since UTC midnight for *dt*, including microseconds.

    Naive datetimes are treated as UTC (via ``_ensure_utc_datetime``).
    """
    utc = _ensure_utc_datetime(dt)
    whole_seconds = utc.hour * 3600 + utc.minute * 60 + utc.second
    return whole_seconds + utc.microsecond / 1_000_000.0
def _ascii_timestamp_to_seconds(raw: float | int | None) -> float | None:
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    text = f"{int(abs(value)):09d}"
    hh = int(text[0:2])
    mm = int(text[2:4])
    ss = int(text[4:6])
    ms = int(text[6:9])
    if 0 <= hh < 24 and 0 <= mm < 60 and 0 <= ss < 60:
        return hh * 3600 + mm * 60 + ss + ms / 1000.0
    return value / 1000.0
def _initial_timestamp_from_gga(gga: str) -> datetime:
    parts = (gga or "").split(",")