yincheng.zhong
2 天以前 567085ead3f6adaabd884f16ab4b17c62e8f0403
python/hitl/simulator.py
@@ -1,14 +1,18 @@
"""
硬件在环 (HITL) 仿真器:生成 $GPRMI/$GPIMU 传感器数据,并与 STM32H7 通过 PythonLink 闭环。
硬件在环 (HITL) 仿真器:生成 IM23A fmin/fmim 传感器帧,并与 STM32H7 通过 PythonLink 闭环。
"""
from __future__ import annotations
import dataclasses
import heapq
import itertools
import math
import queue
import threading
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable
import serial
@@ -16,12 +20,100 @@
from . import geo
from .dynamics import DifferentialDriveModel, DifferentialDriveState
from .protocols import (
    ControlStatus,
    PoseStatus,
    PythonLinkDecoder,
    PythonLinkFrame,
    build_gpimu_sentence,
    build_gprmi_sentence,
    StackStatus,
    StateStatus,
    build_im23a_imu_frame,
    build_im23a_nav_frame,
    decode_control_status,
    decode_pose_status,
    decode_stack_status,
    decode_state_status,
    parse_ascii_message,
)
GPS_EPOCH = datetime(1980, 1, 6, tzinfo=timezone.utc)
class RunLogger:
    """Asynchronous run logger that writes entries in timestamp order.

    Entries that carry a GPS time are keyed by it and sort ahead of entries
    without one; those fall back to the host wall-clock time taken in
    ``log()``. A background thread keeps pending entries in a min-heap and
    only writes the heap's top entry once it has been buffered for at least
    ``flush_delay_ms``, so slightly late entries can still be slotted into
    order. ``close()`` (or leaving a ``with`` block) writes everything that is
    still pending, in sorted order, and closes the file.
    """

    def __init__(self, path: Path | str, flush_delay_ms: float = 500.0):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self.path.open("w", encoding="utf-8")
        self._queue: queue.Queue = queue.Queue()
        # Min-heap of (sort_key, seq, arrival_s, prefix, content) with
        # sort_key = (priority, timestamp_s, source_rank). The unique,
        # increasing seq breaks ties so prefix/content are never compared.
        self._buffer: list[
            tuple[tuple[int, float, int], int, float, str, str]
        ] = []
        self._stop = threading.Event()
        self._counter = itertools.count()
        self._flush_delay = flush_delay_ms / 1000.0
        self._worker = threading.Thread(target=self._loop, name="hitl-runlog", daemon=True)
        self._worker.start()

    def __enter__(self) -> "RunLogger":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def log(
        self,
        prefix: str,
        content: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Queue one line for writing.

        Args:
            prefix: Label written before the content (e.g. the data source).
            content: Line payload.
            gps_time_s: Sort time in seconds; when ``None`` the host time of
                this call is used and the entry sorts after all GPS-keyed ones.
            source_rank: Secondary tie-breaker among entries with equal time.
        """
        arrival = time.time()
        priority = 0 if gps_time_s is not None else 1
        key_time = float(gps_time_s if gps_time_s is not None else arrival)
        self._queue.put(
            (
                (priority, key_time, source_rank),
                next(self._counter),
                arrival,
                prefix,
                content,
            )
        )

    def close(self):
        """Stop the worker, write every pending entry and close the file.

        Safe to call more than once; entries logged after the first call are
        dropped.
        """
        self._stop.set()
        self._worker.join()
        if self._file.closed:
            return
        # Catch anything queued after the worker's final drain.
        self._drain_queue()
        self._flush_buffer(force=True)
        self._file.close()

    def _loop(self):
        """Worker thread: move queued entries into the heap and write due ones."""
        while True:
            # Sample the stop flag *before* draining: every log() that
            # happened before close() is then already in the queue, so the
            # final forced flush writes all of it in sorted order instead of
            # one-by-one in arrival order.
            stopping = self._stop.is_set()
            self._drain_queue(wait_s=0.0 if stopping else 0.05)
            self._flush_buffer(force=stopping)
            if stopping:
                return

    def _drain_queue(self, wait_s: float = 0.0):
        """Move every queued entry into the heap.

        Waits up to *wait_s* seconds for the first entry (0 means no wait),
        then takes whatever else is already queued without blocking.
        """
        try:
            if wait_s > 0:
                item = self._queue.get(timeout=wait_s)
            else:
                item = self._queue.get_nowait()
        except queue.Empty:
            return
        while True:
            heapq.heappush(self._buffer, item)
            try:
                item = self._queue.get_nowait()
            except queue.Empty:
                return

    def _flush_buffer(self, force: bool = False):
        """Write heap entries whose hold-back delay has elapsed (all if *force*).

        Stops at the first heap-top entry that is not yet due, so output stays
        in key order. The file is flushed once per call, not once per line.
        """
        now = time.time()
        wrote = False
        while self._buffer:
            arrival = self._buffer[0][2]
            if not force and (now - arrival) < self._flush_delay:
                break
            key, _seq, _arrival, prefix, content = heapq.heappop(self._buffer)
            self._write_line(key, prefix, content)
            wrote = True
        if wrote:
            self._file.flush()

    def _write_line(
        self, key: tuple[int, float, int], prefix: str, content: str
    ):
        """Format and write one entry (GPS-keyed as ``GPS=``, others as ``HOST=``)."""
        priority, timestamp_s, source_rank = key
        if priority == 0:
            ts_label = f"GPS={timestamp_s:09.3f}s"
        else:
            ts_label = f"HOST={timestamp_s:012.3f}s"
        self._file.write(f"[{ts_label} r={source_rank}] {prefix}: {content}\n")
@dataclasses.dataclass
class HitlConfig:
@@ -32,12 +124,13 @@
    origin_gga: str
    initial_enu: tuple[float, float, float] = (0.0, 0.0, 0.0)
    initial_heading_deg: float = 0.0
    gps_baudrate: int = 460800
    gps_baudrate: int = 115200
    log_baudrate: int = 921600
    track_width: float = 0.78
    baseline_distance: float = 0.9
    max_linear_speed: float = 2.0
    max_angular_speed_deg: float = 140.0
    position_quality: int = 4
class SerialEndpoint:
@@ -118,8 +211,12 @@
    _threads: list[threading.Thread]
    on_control: Callable[[float, float], None] | None
    on_log: Callable[[str], None] | None
    on_control_status: Callable[[ControlStatus], None] | None
    on_pose_status: Callable[[PoseStatus], None] | None
    on_state_status: Callable[[StateStatus], None] | None
    on_stack_status: Callable[[StackStatus], None] | None
    def __init__(self, config: HitlConfig):
    def __init__(self, config: HitlConfig, run_logger: RunLogger | None = None):
        self.config = config
        self.origin = geo.parse_origin(config.origin_gga)
        self.model = DifferentialDriveModel(
@@ -145,9 +242,15 @@
        self._decoder = PythonLinkDecoder(self._handle_control_frame)
        self._running = threading.Event()
        self._threads = []
        self.run_logger = run_logger
        self.on_control = None
        self.on_log = None
        self.on_control_status = None
        self.on_pose_status = None
        self.on_state_status = None
        self.on_stack_status = None
        self.on_ascii = None
    # ------------------------------------------------------------------ #
    # 生命周期
@@ -160,8 +263,8 @@
        self._running.set()
        self._threads = [
            threading.Thread(target=self._loop_physics, name="hitl-phys", daemon=True),
            threading.Thread(target=self._loop_gprmi, name="hitl-gprmi", daemon=True),
            threading.Thread(target=self._loop_gpimu, name="hitl-gpimu", daemon=True),
            threading.Thread(target=self._loop_nav, name="hitl-nav", daemon=True),
            threading.Thread(target=self._loop_imu, name="hitl-imu", daemon=True),
            threading.Thread(target=self._loop_control, name="hitl-ctrl", daemon=True),
            threading.Thread(target=self._loop_log, name="hitl-log", daemon=True),
        ]
@@ -188,33 +291,43 @@
            dt = max(now - last, 1e-4)
            last = now
            with self._state_lock:
                # 注意:这里的 self._target_linear/angular 是从 STM32 发回来的 PWM 转换后的估算值
                # 实际上是 STM32 认为的命令,或者如果我们只收 PWM 的话,它就是实际控制量
                # STM32 发送的是 target_mps 和 target_turn,或者是 PWM
                # 在 protocols.py 中,如果我们收到了 PWM,就会转换为速度
                # 如果收到的是 target_mps,那就是 target_mps
                # 现在的逻辑是:Simulator 收到 PythonLinkFrame,里面可能是 PWM 反算的 velocity
                # 这样就构成了闭环:STM32 计算 PWM -> Python 接收 PWM -> 反算速度 -> 物理模型积分 -> 传感器数据 -> STM32
                state = self.model.step(self._target_linear, self._target_angular, dt).copy()
                self._latest_state = state
                self._sim_time += timedelta(seconds=dt)
            time.sleep(0.005)
    def _loop_gprmi(self):
    def _loop_nav(self):
        period = 0.1  # 10 Hz
        while self._running.is_set():
            start = time.perf_counter()
            sentence = self._build_gprmi_sentence()
            if sentence:
                self.uart2.write(sentence)
            frame, gps_time_s = self._build_nav_frame()
            if frame:
                self.uart2.write(frame)
            self._sleep_remaining(start, period)
    def _loop_gpimu(self):
    def _loop_imu(self):
        period = 0.01  # 100 Hz
        while self._running.is_set():
            start = time.perf_counter()
            sentence = self._build_gpimu_sentence()
            if sentence:
                self.uart2.write(sentence)
            frame, gps_time_s = self._build_imu_frame()
            if frame:
                self.uart2.write(frame)
            self._sleep_remaining(start, period)
    def _loop_control(self):
        while self._running.is_set():
            data = self.uart2.read(128)
            if data:
                self._log_binary(
                    "STM32->PY UART2 CTRL_RAW", data, source_rank=1
                )
                self._decoder.feed(data)
            else:
                time.sleep(0.002)
@@ -230,17 +343,64 @@
                time.sleep(0.01)
                continue
            text = line.decode("utf-8", errors="replace").strip()
            if text and self.on_log:
            if not text:
                continue
            handled = False
            msg = parse_ascii_message(text)
            if msg:
                ctrl = decode_control_status(msg)
                if ctrl and self.on_control_status:
                    self.on_control_status(ctrl)
                    self._log_text(
                        "STM32 UART5 CTRL",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(ctrl.timestamp_ms),
                        source_rank=1,
                    )
                    self._apply_ascii_control(ctrl)
                    handled = True
                pose = decode_pose_status(msg)
                if pose and self.on_pose_status:
                    self.on_pose_status(pose)
                    self._log_text(
                        "STM32 UART5 POSE",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(pose.timestamp_ms),
                        source_rank=1,
                    )
                    handled = True
                state = decode_state_status(msg)
                if state and self.on_state_status:
                    self.on_state_status(state)
                    self._log_text(
                        "STM32 UART5 STATE",
                        text,
                        gps_time_s=_ascii_timestamp_to_seconds(state.timestamp_ms),
                        source_rank=1,
                    )
                    handled = True
                stack = decode_stack_status(msg)
                if stack and self.on_stack_status:
                    self.on_stack_status(stack)
                    handled = True
            if not handled and self.on_log:
                self.on_log(text)
            if not handled:
                self._log_text("STM32 UART5", text, source_rank=1)
    # ------------------------------------------------------------------ #
    # 构造帧
    # ------------------------------------------------------------------ #
    def _build_gprmi_sentence(self) -> bytes | None:
    def _build_nav_frame(self) -> tuple[bytes | None, float]:
        state, timestamp = self._snapshot()
        lat, lon, alt = geo.enu_to_lla(state.east, state.north, state.up, self.origin)
        heading_nav = geo.heading_math_to_nav(state.heading)
        return build_gprmi_sentence(
        heading_nav_deg = geo.heading_math_to_nav(state.heading)
        # 将导航坐标系角度(0-360度)转换为弧度(-π到π),0度正北对应0弧度
        # 0°(正北)→0rad, 90°(正东)→π/2rad, 180°(正南)→πrad, 270°(正西)→-π/2rad
        heading_nav_rad = math.radians(heading_nav_deg) if heading_nav_deg <= 180.0 else math.radians(heading_nav_deg - 360.0)
        gps_time_s = _seconds_of_day(timestamp)
        status_flags = int(self.config.position_quality) & 0xFF
        frame = build_im23a_nav_frame(
            timestamp=timestamp,
            lat_deg=lat,
            lon_deg=lon,
@@ -248,20 +408,25 @@
            east_vel=state.east_velocity,
            north_vel=state.north_velocity,
            up_vel=state.up_velocity,
            heading_deg=heading_nav,
            heading_rad=heading_nav_rad,
            pitch_deg=state.pitch_deg,
            roll_deg=state.roll_deg,
            baseline_m=self.config.baseline_distance,
            accel_bias=(0.0, 0.0, 0.0),
            gyro_bias=(0.0, 0.0, 0.0),
            temperature_c=state.temperature_c,
            status_flags=status_flags,
        )
        return frame, gps_time_s
    def _build_gpimu_sentence(self) -> bytes | None:
    def _build_imu_frame(self) -> tuple[bytes | None, float]:
        state, timestamp = self._snapshot()
        return build_gpimu_sentence(
        gps_time_s = _seconds_of_day(timestamp)
        frame = build_im23a_imu_frame(
            timestamp=timestamp,
            accel_g=state.body_accel_g,
            gyro_deg_s=state.gyro_deg_s,
            temperature_c=state.temperature_c,
        )
        return frame, gps_time_s
    # ------------------------------------------------------------------ #
    # 工具
@@ -276,6 +441,11 @@
            self._target_angular = frame.turn
        if self.on_control:
            self.on_control(frame.forward, frame.turn)
        if self.run_logger:
            self.run_logger.log(
                "STM32->PY CTRL_FRAME",
                f"forward={frame.forward:.3f} turn={frame.turn:.3f} pwm={frame.steering_pwm}/{frame.throttle_pwm}",
            )
    @staticmethod
    def _sleep_remaining(start: float, period: float):
@@ -284,6 +454,108 @@
        if remaining > 0:
            time.sleep(remaining)
    # ------------------------------------------------------------------ #
    # 外部控制
    # ------------------------------------------------------------------ #
    def update_origin(self, origin_gga: str):
        if not origin_gga:
            return
        try:
            new_origin = geo.parse_origin(origin_gga)
        except ValueError:
            return
        self.config.origin_gga = origin_gga
        self.origin = new_origin
        self._sim_time = _initial_timestamp_from_gga(origin_gga)
    def reset_state(self, east: float, north: float, up: float, heading_deg: float):
        with self._state_lock:
            self.model.reset(east=east, north=north, up=up, heading_deg=heading_deg)
            self._latest_state = self.model.state.copy()
            self._target_linear = 0.0
            self._target_angular = 0.0
    # ------------------------------------------------------------------ #
    # 日志工具
    # ------------------------------------------------------------------ #
    def _apply_ascii_control(self, ctrl: ControlStatus):
        with self._state_lock:
            self._target_linear = ctrl.forward_mps
            self._target_angular = ctrl.turn_rate
        if self.on_control:
            self.on_control(ctrl.forward_mps, ctrl.turn_rate)
    def _log_ascii(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger or not payload:
            return
        text = payload.decode("utf-8", errors="replace").strip()
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)
    def _log_binary(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger or not payload:
            return
        self.run_logger.log(
            prefix, payload.hex(), gps_time_s=gps_time_s, source_rank=source_rank
        )
    def _log_text(
        self,
        prefix: str,
        text: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        if not self.run_logger:
            return
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)
def _ensure_utc_datetime(dt: datetime) -> datetime:
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)
def _seconds_of_day(dt: datetime) -> float:
    """Return the seconds elapsed since 00:00 UTC on *dt*'s UTC date.

    Naive datetimes are treated as UTC (via ``_ensure_utc_datetime``);
    microseconds are kept as the fractional part.
    """
    utc = _ensure_utc_datetime(dt)
    whole_seconds = (utc.hour * 60 + utc.minute) * 60 + utc.second
    return whole_seconds + utc.microsecond / 1_000_000.0
def _ascii_timestamp_to_seconds(raw: float | int | None) -> float | None:
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    text = f"{int(abs(value)):09d}"
    hh = int(text[0:2])
    mm = int(text[2:4])
    ss = int(text[4:6])
    ms = int(text[6:9])
    if 0 <= hh < 24 and 0 <= mm < 60 and 0 <= ss < 60:
        return hh * 3600 + mm * 60 + ss + ms / 1000.0
    return value / 1000.0
def _initial_timestamp_from_gga(gga: str) -> datetime:
    parts = (gga or "").split(",")