""" 硬件在环 (HITL) 仿真器:生成 IM23A fmin/fmim 传感器帧,并与 STM32H7 通过 PythonLink 闭环。 """ from __future__ import annotations import dataclasses import heapq import itertools import math import queue import threading import time from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Callable import serial from . import geo from .dynamics import DifferentialDriveModel, DifferentialDriveState from .protocols import ( ControlStatus, PoseStatus, PythonLinkDecoder, PythonLinkFrame, StackStatus, StateStatus, build_im23a_imu_frame, build_im23a_nav_frame, decode_control_status, decode_pose_status, decode_stack_status, decode_state_status, parse_ascii_message, ) GPS_EPOCH = datetime(1980, 1, 6, tzinfo=timezone.utc) class RunLogger: """异步日志器:按 GPS 时间优先排序写入;无 GPS 时间则回退到主机时间。""" def __init__(self, path: Path | str, flush_delay_ms: float = 500.0): self.path = Path(path) self.path.parent.mkdir(parents=True, exist_ok=True) self._file = self.path.open("w", encoding="utf-8") self._queue: queue.Queue = queue.Queue() self._buffer: list[ tuple[tuple[int, float], int, float, str, str] ] = [] self._stop = threading.Event() self._counter = itertools.count() self._flush_delay = flush_delay_ms / 1000.0 self._worker = threading.Thread(target=self._loop, name="hitl-runlog", daemon=True) self._worker.start() def log( self, prefix: str, content: str, *, gps_time_s: float | None = None, source_rank: int = 0, ): arrival = time.time() priority = 0 if gps_time_s is not None else 1 key_time = float(gps_time_s if gps_time_s is not None else arrival) self._queue.put( ( (priority, key_time, source_rank), next(self._counter), arrival, prefix, content, ) ) def close(self): self._stop.set() self._worker.join() self._flush_buffer(force=True) if not self._file.closed: self._file.close() def _loop(self): while not self._stop.is_set() or not self._queue.empty() or self._buffer: try: item = self._queue.get(timeout=0.05) heapq.heappush(self._buffer, item) except queue.Empty: pass self._flush_buffer() def _flush_buffer(self, force: bool = False): now = time.time() while self._buffer: key, _seq, arrival, prefix, content = self._buffer[0] if not force and (now - arrival) < self._flush_delay and not self._stop.is_set(): break heapq.heappop(self._buffer) self._write_line(key, prefix, content) def _write_line( self, key: tuple[int, float, int], prefix: str, content: str ): priority, timestamp_s, source_rank = key if priority == 0: ts_label = f"GPS={timestamp_s:09.3f}s" else: ts_label = f"HOST={timestamp_s:012.3f}s" line = f"[{ts_label} r={source_rank}] {prefix}: {content}\n" self._file.write(line) self._file.flush() @dataclasses.dataclass class HitlConfig: """HITL 仿真运行必需的配置项。""" uart2_port: str # STM32 UART2 (GPS/IMU 输入 + 控制输出) uart5_port: str | None # STM32 UART5 日志输出 origin_gga: str initial_enu: tuple[float, float, float] = (0.0, 0.0, 0.0) initial_heading_deg: float = 0.0 gps_baudrate: int = 115200 log_baudrate: int = 921600 track_width: float = 0.78 baseline_distance: float = 0.9 max_linear_speed: float = 2.0 max_angular_speed_deg: float = 140.0 position_quality: int = 4 class SerialEndpoint: """线程安全的串口包装。""" port: str | None baudrate: int timeout: float _serial: serial.Serial | None _lock: threading.Lock def __init__(self, port: str | None, baudrate: int, timeout: float = 0.0): self.port = port self.baudrate = baudrate self.timeout = timeout self._serial = None self._lock = threading.Lock() def open(self): if not self.port: return if self._serial and self._serial.is_open: return self._serial = serial.Serial( self.port, self.baudrate, timeout=self.timeout, write_timeout=0, ) self._serial.reset_input_buffer() self._serial.reset_output_buffer() def close(self): if self._serial: try: self._serial.close() finally: self._serial = None def write(self, data: bytes): if not data or not self._serial: return with self._lock: _ = self._serial.write(data) def read(self, size: int = 1) -> bytes: if not self._serial: return b"" try: return self._serial.read(size) except serial.SerialException: return b"" def readline(self) -> bytes: if not self._serial: return b"" try: return self._serial.readline() except serial.SerialException: return b"" class HitlSimulator: """电脑仿真环境与 STM32H7 控制板之间的桥梁。""" config: HitlConfig origin: geo.Origin model: DifferentialDriveModel _state_lock: threading.Lock _latest_state: DifferentialDriveState _target_linear: float _target_angular: float _sim_time: datetime uart2: SerialEndpoint log_uart: SerialEndpoint _decoder: PythonLinkDecoder _running: threading.Event _threads: list[threading.Thread] on_control: Callable[[float, float], None] | None on_log: Callable[[str], None] | None on_control_status: Callable[[ControlStatus], None] | None on_pose_status: Callable[[PoseStatus], None] | None on_state_status: Callable[[StateStatus], None] | None on_stack_status: Callable[[StackStatus], None] | None def __init__(self, config: HitlConfig, run_logger: RunLogger | None = None): self.config = config self.origin = geo.parse_origin(config.origin_gga) self.model = DifferentialDriveModel( track_width=config.track_width, max_linear_speed=config.max_linear_speed, max_angular_speed=math.radians(config.max_angular_speed_deg), ) self.model.reset( east=config.initial_enu[0], north=config.initial_enu[1], up=config.initial_enu[2], heading_deg=config.initial_heading_deg, ) self._state_lock = threading.Lock() self._latest_state = self.model.state.copy() self._target_linear = 0.0 self._target_angular = 0.0 self._sim_time = _initial_timestamp_from_gga(config.origin_gga) self.uart2 = SerialEndpoint(config.uart2_port, config.gps_baudrate, timeout=0.0) self.log_uart = SerialEndpoint(config.uart5_port, config.log_baudrate, timeout=0.1) self._decoder = PythonLinkDecoder(self._handle_control_frame) self._running = threading.Event() self._threads = [] self.run_logger = run_logger self.on_control = None self.on_log = None self.on_control_status = None self.on_pose_status = None self.on_state_status = None self.on_stack_status = None self.on_ascii = None # ------------------------------------------------------------------ # # 生命周期 # ------------------------------------------------------------------ # def start(self): if self._running.is_set(): return self.uart2.open() self.log_uart.open() self._running.set() self._threads = [ threading.Thread(target=self._loop_physics, name="hitl-phys", daemon=True), threading.Thread(target=self._loop_nav, name="hitl-nav", daemon=True), threading.Thread(target=self._loop_imu, name="hitl-imu", daemon=True), threading.Thread(target=self._loop_control, name="hitl-ctrl", daemon=True), threading.Thread(target=self._loop_log, name="hitl-log", daemon=True), ] for t in self._threads: t.start() def stop(self): if not self._running.is_set(): return self._running.clear() for t in self._threads: t.join(timeout=1.0) self._threads.clear() self.uart2.close() self.log_uart.close() # ------------------------------------------------------------------ # # 线程 # ------------------------------------------------------------------ # def _loop_physics(self): last = time.perf_counter() while self._running.is_set(): now = time.perf_counter() dt = max(now - last, 1e-4) last = now with self._state_lock: # 注意:这里的 self._target_linear/angular 是从 STM32 发回来的 PWM 转换后的估算值 # 实际上是 STM32 认为的命令,或者如果我们只收 PWM 的话,它就是实际控制量 # STM32 发送的是 target_mps 和 target_turn,或者是 PWM # 在 protocols.py 中,如果我们收到了 PWM,就会转换为速度 # 如果收到的是 target_mps,那就是 target_mps # 现在的逻辑是:Simulator 收到 PythonLinkFrame,里面可能是 PWM 反算的 velocity # 这样就构成了闭环:STM32 计算 PWM -> Python 接收 PWM -> 反算速度 -> 物理模型积分 -> 传感器数据 -> STM32 state = self.model.step(self._target_linear, self._target_angular, dt).copy() self._latest_state = state self._sim_time += timedelta(seconds=dt) time.sleep(0.005) def _loop_nav(self): period = 0.1 # 10 Hz while self._running.is_set(): start = time.perf_counter() frame, gps_time_s = self._build_nav_frame() if frame: self.uart2.write(frame) self._sleep_remaining(start, period) def _loop_imu(self): period = 0.01 # 100 Hz while self._running.is_set(): start = time.perf_counter() frame, gps_time_s = self._build_imu_frame() if frame: self.uart2.write(frame) self._sleep_remaining(start, period) def _loop_control(self): while self._running.is_set(): data = self.uart2.read(128) if data: self._log_binary( "STM32->PY UART2 CTRL_RAW", data, source_rank=1 ) self._decoder.feed(data) else: time.sleep(0.002) def _loop_log(self): if not self.config.uart5_port: while self._running.is_set(): time.sleep(0.5) return while self._running.is_set(): line = self.log_uart.readline() if not line: time.sleep(0.01) continue text = line.decode("utf-8", errors="replace").strip() if not text: continue handled = False msg = parse_ascii_message(text) if msg: ctrl = decode_control_status(msg) if ctrl and self.on_control_status: self.on_control_status(ctrl) self._log_text( "STM32 UART5 CTRL", text, gps_time_s=_ascii_timestamp_to_seconds(ctrl.timestamp_ms), source_rank=1, ) self._apply_ascii_control(ctrl) handled = True pose = decode_pose_status(msg) if pose and self.on_pose_status: self.on_pose_status(pose) self._log_text( "STM32 UART5 POSE", text, gps_time_s=_ascii_timestamp_to_seconds(pose.timestamp_ms), source_rank=1, ) handled = True state = decode_state_status(msg) if state and self.on_state_status: self.on_state_status(state) self._log_text( "STM32 UART5 STATE", text, gps_time_s=_ascii_timestamp_to_seconds(state.timestamp_ms), source_rank=1, ) handled = True stack = decode_stack_status(msg) if stack and self.on_stack_status: self.on_stack_status(stack) handled = True if not handled and self.on_log: self.on_log(text) if not handled: self._log_text("STM32 UART5", text, source_rank=1) # ------------------------------------------------------------------ # # 构造帧 # ------------------------------------------------------------------ # def _build_nav_frame(self) -> tuple[bytes | None, float]: state, timestamp = self._snapshot() lat, lon, alt = geo.enu_to_lla(state.east, state.north, state.up, self.origin) heading_nav_deg = geo.heading_math_to_nav(state.heading) # 将导航坐标系角度(0-360度)转换为弧度(-π到π),0度正北对应0弧度 # 0°(正北)→0rad, 90°(正东)→π/2rad, 180°(正南)→πrad, 270°(正西)→-π/2rad heading_nav_rad = math.radians(heading_nav_deg) if heading_nav_deg <= 180.0 else math.radians(heading_nav_deg - 360.0) gps_time_s = _seconds_of_day(timestamp) status_flags = int(self.config.position_quality) & 0xFF frame = build_im23a_nav_frame( timestamp=timestamp, lat_deg=lat, lon_deg=lon, alt_m=alt, east_vel=state.east_velocity, north_vel=state.north_velocity, up_vel=state.up_velocity, heading_rad=heading_nav_rad, pitch_deg=state.pitch_deg, roll_deg=state.roll_deg, accel_bias=(0.0, 0.0, 0.0), gyro_bias=(0.0, 0.0, 0.0), temperature_c=state.temperature_c, status_flags=status_flags, ) return frame, gps_time_s def _build_imu_frame(self) -> tuple[bytes | None, float]: state, timestamp = self._snapshot() gps_time_s = _seconds_of_day(timestamp) frame = build_im23a_imu_frame( timestamp=timestamp, accel_g=state.body_accel_g, gyro_deg_s=state.gyro_deg_s, ) return frame, gps_time_s # ------------------------------------------------------------------ # # 工具 # ------------------------------------------------------------------ # def _snapshot(self) -> tuple[DifferentialDriveState, datetime]: with self._state_lock: return self._latest_state.copy(), self._sim_time def _handle_control_frame(self, frame: PythonLinkFrame): with self._state_lock: self._target_linear = frame.forward self._target_angular = frame.turn if self.on_control: self.on_control(frame.forward, frame.turn) if self.run_logger: self.run_logger.log( "STM32->PY CTRL_FRAME", f"forward={frame.forward:.3f} turn={frame.turn:.3f} pwm={frame.steering_pwm}/{frame.throttle_pwm}", ) @staticmethod def _sleep_remaining(start: float, period: float): elapsed = time.perf_counter() - start remaining = period - elapsed if remaining > 0: time.sleep(remaining) # ------------------------------------------------------------------ # # 外部控制 # ------------------------------------------------------------------ # def update_origin(self, origin_gga: str): if not origin_gga: return try: new_origin = geo.parse_origin(origin_gga) except ValueError: return self.config.origin_gga = origin_gga self.origin = new_origin self._sim_time = _initial_timestamp_from_gga(origin_gga) def reset_state(self, east: float, north: float, up: float, heading_deg: float): with self._state_lock: self.model.reset(east=east, north=north, up=up, heading_deg=heading_deg) self._latest_state = self.model.state.copy() self._target_linear = 0.0 self._target_angular = 0.0 # ------------------------------------------------------------------ # # 日志工具 # ------------------------------------------------------------------ # def _apply_ascii_control(self, ctrl: ControlStatus): with self._state_lock: self._target_linear = ctrl.forward_mps self._target_angular = ctrl.turn_rate if self.on_control: self.on_control(ctrl.forward_mps, ctrl.turn_rate) def _log_ascii( self, prefix: str, payload: bytes, *, gps_time_s: float | None = None, source_rank: int = 0, ): if not self.run_logger or not payload: return text = payload.decode("utf-8", errors="replace").strip() self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank) def _log_binary( self, prefix: str, payload: bytes, *, gps_time_s: float | None = None, source_rank: int = 0, ): if not self.run_logger or not payload: return self.run_logger.log( prefix, payload.hex(), gps_time_s=gps_time_s, source_rank=source_rank ) def _log_text( self, prefix: str, text: str, *, gps_time_s: float | None = None, source_rank: int = 0, ): if not self.run_logger: return self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank) def _ensure_utc_datetime(dt: datetime) -> datetime: if dt.tzinfo is None: return dt.replace(tzinfo=timezone.utc) return dt.astimezone(timezone.utc) def _seconds_of_day(dt: datetime) -> float: utc = _ensure_utc_datetime(dt) return ( utc.hour * 3600 + utc.minute * 60 + utc.second + utc.microsecond / 1_000_000.0 ) def _ascii_timestamp_to_seconds(raw: float | int | None) -> float | None: if raw is None: return None try: value = float(raw) except (TypeError, ValueError): return None text = f"{int(abs(value)):09d}" hh = int(text[0:2]) mm = int(text[2:4]) ss = int(text[4:6]) ms = int(text[6:9]) if 0 <= hh < 24 and 0 <= mm < 60 and 0 <= ss < 60: return hh * 3600 + mm * 60 + ss + ms / 1000.0 return value / 1000.0 def _initial_timestamp_from_gga(gga: str) -> datetime: parts = (gga or "").split(",") if len(parts) > 1 and parts[1]: time_str = parts[1] try: hh = int(time_str[0:2]) mm = int(time_str[2:4]) ss = int(time_str[4:6]) frac = float("0." + time_str.split(".")[1]) if "." in time_str else 0.0 today = datetime.now(timezone.utc).date() base = datetime(today.year, today.month, today.day, tzinfo=timezone.utc) return base.replace(hour=hh, minute=mm, second=ss, microsecond=int(frac * 1_000_000)) except (ValueError, IndexError): pass return datetime.now(timezone.utc)