"""
|
硬件在环 (HITL) 仿真器:生成 IM23A fmin/fmim 传感器帧,并与 STM32H7 通过 PythonLink 闭环。
|
"""
|
|
from __future__ import annotations
|
|
import dataclasses
|
import heapq
|
import itertools
|
import math
|
import queue
|
import threading
|
import time
|
from datetime import datetime, timedelta, timezone
|
from pathlib import Path
|
from typing import Callable
|
|
import serial
|
|
from . import geo
|
from .dynamics import DifferentialDriveModel, DifferentialDriveState
|
from .protocols import (
|
ControlStatus,
|
PoseStatus,
|
PythonLinkDecoder,
|
PythonLinkFrame,
|
StackStatus,
|
StateStatus,
|
build_im23a_imu_frame,
|
build_im23a_nav_frame,
|
decode_control_status,
|
decode_pose_status,
|
decode_stack_status,
|
decode_state_status,
|
parse_ascii_message,
|
)
|
|
# Reference epoch of GPS time: 1980-01-06 00:00:00 UTC (timezone-aware).
GPS_EPOCH = datetime(year=1980, month=1, day=6, tzinfo=timezone.utc)
|
|
|
class RunLogger:
    """Asynchronous run logger that writes lines ordered by timestamp.

    Entries are queued by :meth:`log` and written by a background thread.
    Entries that carry a GPS time sort before entries without one; the latter
    fall back to host wall-clock time. Pending entries live in a priority heap
    keyed by ``(priority, time, source_rank)``. The smallest entry is written
    only once it has been buffered for at least ``flush_delay_ms``, which gives
    slightly late entries a chance to be placed in order. Everything still
    pending is written, fully sorted, by :meth:`close`.

    Can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(self, path: Path | str, flush_delay_ms: float = 500.0):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self.path.open("w", encoding="utf-8")
        self._queue: queue.Queue = queue.Queue()
        # Heap entries: ((priority, key_time, source_rank), seq, arrival, prefix, content).
        # ``seq`` is unique, so tuple comparison never reaches the strings.
        self._buffer: list[
            tuple[tuple[int, float, int], int, float, str, str]
        ] = []
        self._stop = threading.Event()
        self._counter = itertools.count()
        self._flush_delay = flush_delay_ms / 1000.0
        self._worker = threading.Thread(target=self._loop, name="hitl-runlog", daemon=True)
        self._worker.start()

    def __enter__(self) -> "RunLogger":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def log(
        self,
        prefix: str,
        content: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Queue one line for writing.

        Args:
            prefix: Label written before the content (e.g. the channel name).
            content: Line payload.
            gps_time_s: Sort key in seconds. When None, the host time of the
                call is used and the entry sorts after all GPS-timed entries.
            source_rank: Tie-breaker for entries with equal timestamps.
        """
        arrival = time.time()
        priority = 0 if gps_time_s is not None else 1
        key_time = float(gps_time_s if gps_time_s is not None else arrival)
        self._queue.put(
            (
                (priority, key_time, source_rank),
                next(self._counter),
                arrival,
                prefix,
                content,
            )
        )

    def close(self):
        """Stop the worker, write every pending entry in order, and close the file.

        Safe to call more than once.
        """
        if self._file.closed:
            return
        self._stop.set()
        self._worker.join()
        # The worker has exited: pull whatever is still queued so the final
        # forced flush sees every entry and writes them in sorted order.
        self._drain_queue()
        self._flush_buffer(force=True)
        self._file.close()

    def _loop(self):
        """Worker body: move queued entries into the heap and flush aged ones."""
        while not self._stop.is_set():
            self._drain_queue(timeout=0.05)
            self._flush_buffer()

    def _drain_queue(self, timeout: float = 0.0):
        """Move every queued entry into the heap.

        Blocks for up to *timeout* seconds waiting for the first entry when
        *timeout* > 0, then takes the rest without blocking.
        """
        block = timeout > 0
        while True:
            try:
                item = self._queue.get(block=block, timeout=timeout if block else None)
            except queue.Empty:
                return
            heapq.heappush(self._buffer, item)
            block = False

    def _flush_buffer(self, force: bool = False):
        """Write heap entries in key order.

        Without *force*, stop at the first (smallest-key) entry that has been
        buffered for less than the flush delay.
        """
        now = time.time()
        while self._buffer:
            arrival = self._buffer[0][2]
            if not force and (now - arrival) < self._flush_delay:
                break
            key, _seq, _arrival, prefix, content = heapq.heappop(self._buffer)
            self._write_line(key, prefix, content)

    def _write_line(
        self, key: tuple[int, float, int], prefix: str, content: str
    ):
        """Format one entry and write it, flushing so the file survives a crash."""
        priority, timestamp_s, source_rank = key
        if priority == 0:
            ts_label = f"GPS={timestamp_s:09.3f}s"
        else:
            ts_label = f"HOST={timestamp_s:012.3f}s"
        line = f"[{ts_label} r={source_rank}] {prefix}: {content}\n"
        self._file.write(line)
        self._file.flush()
|
|
|
|
@dataclasses.dataclass
class HitlConfig:
    """Settings required to run one HITL simulation session."""

    # Serial port wired to STM32 UART2: GPS/IMU frames go in, control frames come back.
    uart2_port: str
    # Serial port wired to STM32 UART5 (ASCII log output); falsy disables log capture.
    uart5_port: str | None
    # GGA sentence that defines the ENU origin and the initial simulated clock.
    origin_gga: str
    # Initial (east, north, up) position handed to the model's reset().
    initial_enu: tuple[float, float, float] = (0.0, 0.0, 0.0)
    # Initial heading in degrees, handed to the model's reset().
    initial_heading_deg: float = 0.0
    # Baud rate of the UART2 link.
    gps_baudrate: int = 115_200
    # Baud rate of the UART5 log link.
    log_baudrate: int = 921_600
    # Differential-drive model parameters.
    track_width: float = 0.78
    # Not referenced in this file.
    baseline_distance: float = 0.9
    max_linear_speed: float = 2.0
    # Converted to radians before being passed to the model.
    max_angular_speed_deg: float = 140.0
    # Written (masked to 8 bits) into the status flags of every nav frame.
    position_quality: int = 4
|
|
|
class SerialEndpoint:
    """Thread-safe wrapper around an optional ``serial.Serial`` port.

    When ``port`` is falsy the endpoint never opens. Writes are then dropped
    and reads return ``b""``. I/O errors are treated as best-effort: a failed
    write drops the data and a failed read returns ``b""``, so the worker
    threads using the endpoint keep running.
    """

    port: str | None
    baudrate: int
    timeout: float
    _serial: serial.Serial | None
    _lock: threading.Lock

    def __init__(self, port: str | None, baudrate: int, timeout: float = 0.0):
        self.port = port
        self.baudrate = baudrate
        self.timeout = timeout
        self._serial = None
        # Serializes writers against each other and against open()/close().
        self._lock = threading.Lock()

    def open(self):
        """Open the port (no-op when no port is configured or it is already open)."""
        if not self.port:
            return
        with self._lock:
            if self._serial and self._serial.is_open:
                return
            ser = serial.Serial(
                self.port,
                self.baudrate,
                timeout=self.timeout,
                write_timeout=0,
            )
            try:
                ser.reset_input_buffer()
                ser.reset_output_buffer()
            except Exception:
                # Don't leak the OS handle if the initial reset fails.
                ser.close()
                raise
            self._serial = ser

    def close(self):
        """Close the port if open. Safe to call repeatedly."""
        with self._lock:
            ser, self._serial = self._serial, None
        if ser is not None:
            ser.close()

    def write(self, data: bytes):
        """Write *data*; silently dropped when closed or on a serial error."""
        if not data:
            return
        with self._lock:
            ser = self._serial
            if ser is None:
                return
            try:
                ser.write(data)
            except serial.SerialException:
                # With write_timeout=0 pyserial raises SerialTimeoutException
                # (a SerialException) when the OS buffer is full. Drop the frame
                # rather than kill the sending thread, mirroring read().
                pass

    def read(self, size: int = 1) -> bytes:
        """Read up to *size* bytes; ``b""`` when closed, timed out, or on error."""
        # Take a local reference: close() may reset self._serial concurrently.
        ser = self._serial
        if ser is None:
            return b""
        try:
            return ser.read(size)
        except serial.SerialException:
            return b""

    def readline(self) -> bytes:
        """Read one line; ``b""`` when closed, timed out, or on error."""
        ser = self._serial
        if ser is None:
            return b""
        try:
            return ser.readline()
        except serial.SerialException:
            return b""
|
|
|
class HitlSimulator:
    """Bridge between the PC-side simulation and the STM32H7 controller board.

    Worker threads started by :meth:`start`:

    * ``hitl-phys``: integrates the differential-drive model with the latest
      control targets and advances the simulated clock.
    * ``hitl-nav`` / ``hitl-imu``: send IM23A navigation (10 Hz) and IMU
      (100 Hz) frames to the board on UART2.
    * ``hitl-ctrl``: reads PythonLink control frames back from UART2.
    * ``hitl-log``: reads and dispatches the board's ASCII log lines on UART5.

    The model, the control targets, the origin and the simulated clock are
    updated under ``_state_lock``. The ``on_*`` callbacks run on the worker
    threads, not on the caller's thread.
    """

    config: HitlConfig
    origin: geo.Origin
    model: DifferentialDriveModel
    _state_lock: threading.Lock
    _latest_state: DifferentialDriveState
    _target_linear: float
    _target_angular: float
    _sim_time: datetime
    uart2: SerialEndpoint
    log_uart: SerialEndpoint
    _decoder: PythonLinkDecoder
    _running: threading.Event
    _threads: list[threading.Thread]
    run_logger: RunLogger | None
    on_control: Callable[[float, float], None] | None
    on_log: Callable[[str], None] | None
    on_control_status: Callable[[ControlStatus], None] | None
    on_pose_status: Callable[[PoseStatus], None] | None
    on_state_status: Callable[[StateStatus], None] | None
    on_stack_status: Callable[[StackStatus], None] | None
    # Not invoked anywhere in this file; kept for external users.
    on_ascii: Callable[..., None] | None

    def __init__(self, config: HitlConfig, run_logger: RunLogger | None = None):
        self.config = config
        self.origin = geo.parse_origin(config.origin_gga)
        self.model = DifferentialDriveModel(
            track_width=config.track_width,
            max_linear_speed=config.max_linear_speed,
            max_angular_speed=math.radians(config.max_angular_speed_deg),
        )
        self.model.reset(
            east=config.initial_enu[0],
            north=config.initial_enu[1],
            up=config.initial_enu[2],
            heading_deg=config.initial_heading_deg,
        )
        self._state_lock = threading.Lock()
        self._latest_state = self.model.state.copy()
        self._target_linear = 0.0
        self._target_angular = 0.0
        self._sim_time = _initial_timestamp_from_gga(config.origin_gga)

        self.uart2 = SerialEndpoint(config.uart2_port, config.gps_baudrate, timeout=0.0)
        self.log_uart = SerialEndpoint(config.uart5_port, config.log_baudrate, timeout=0.1)

        self._decoder = PythonLinkDecoder(self._handle_control_frame)
        self._running = threading.Event()
        self._threads = []
        self.run_logger = run_logger

        self.on_control = None
        self.on_log = None
        self.on_control_status = None
        self.on_pose_status = None
        self.on_state_status = None
        self.on_stack_status = None
        self.on_ascii = None

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #
    def start(self):
        """Open the serial ports and launch the worker threads (no-op if running)."""
        if self._running.is_set():
            return
        self.uart2.open()
        try:
            self.log_uart.open()
        except Exception:
            # Don't leave UART2 open when UART5 cannot be opened.
            self.uart2.close()
            raise
        self._running.set()
        self._threads = [
            threading.Thread(target=self._loop_physics, name="hitl-phys", daemon=True),
            threading.Thread(target=self._loop_nav, name="hitl-nav", daemon=True),
            threading.Thread(target=self._loop_imu, name="hitl-imu", daemon=True),
            threading.Thread(target=self._loop_control, name="hitl-ctrl", daemon=True),
            threading.Thread(target=self._loop_log, name="hitl-log", daemon=True),
        ]
        for t in self._threads:
            t.start()

    def stop(self):
        """Signal the workers to exit, wait up to 1 s for each, and close the ports."""
        if not self._running.is_set():
            return
        self._running.clear()
        for t in self._threads:
            t.join(timeout=1.0)
        self._threads.clear()
        self.uart2.close()
        self.log_uart.close()

    # ------------------------------------------------------------------ #
    # Worker threads
    # ------------------------------------------------------------------ #
    def _loop_physics(self):
        """Integrate the model with wall-clock dt and advance the simulated clock."""
        last = time.perf_counter()
        while self._running.is_set():
            now = time.perf_counter()
            dt = max(now - last, 1e-4)
            last = now
            with self._state_lock:
                # The targets are set either by _handle_control_frame (binary
                # PythonLink frames; per the original design notes protocols.py
                # may derive these velocities from PWM values) or by
                # _apply_ascii_control (UART5 CTRL status). Integrating them here
                # closes the loop: STM32 command -> model -> sensor frames -> STM32.
                state = self.model.step(self._target_linear, self._target_angular, dt).copy()
                self._latest_state = state
                self._sim_time += timedelta(seconds=dt)
            time.sleep(0.005)

    def _loop_nav(self):
        """Send navigation frames at 10 Hz."""
        self._run_periodic(0.1, self._build_nav_frame)

    def _loop_imu(self):
        """Send IMU frames at 100 Hz."""
        self._run_periodic(0.01, self._build_imu_frame)

    def _run_periodic(
        self, period: float, build: Callable[[], tuple[bytes | None, float]]
    ):
        """Build a frame with *build* and write it to UART2 every *period* seconds."""
        while self._running.is_set():
            start = time.perf_counter()
            # NOTE(review): the GPS time returned alongside the frame is not
            # used; the outgoing frames are not written to the run log.
            frame, _gps_time_s = build()
            if frame:
                self.uart2.write(frame)
            self._sleep_remaining(start, period)

    def _loop_control(self):
        """Feed raw UART2 bytes from the board into the PythonLink decoder."""
        while self._running.is_set():
            data = self.uart2.read(128)
            if data:
                self._log_binary(
                    "STM32->PY UART2 CTRL_RAW", data, source_rank=1
                )
                self._decoder.feed(data)
            else:
                time.sleep(0.002)

    def _loop_log(self):
        """Read ASCII lines from the UART5 log port and dispatch each one."""
        if not self.config.uart5_port:
            # No log port configured: just idle until stop().
            while self._running.is_set():
                time.sleep(0.5)
            return
        while self._running.is_set():
            line = self.log_uart.readline()
            if not line:
                time.sleep(0.01)
                continue
            text = line.decode("utf-8", errors="replace").strip()
            if text:
                self._handle_log_line(text)

    def _handle_log_line(self, text: str):
        """Decode one UART5 line, notify callbacks, log it, and apply CTRL targets.

        Decoding, logging and applying control do not depend on whether a
        callback is registered. The callbacks are optional observers only.
        Lines that no decoder recognises go to ``on_log`` and the generic log.
        """
        handled = False
        msg = parse_ascii_message(text)
        if msg:
            ctrl = decode_control_status(msg)
            if ctrl:
                if self.on_control_status:
                    self.on_control_status(ctrl)
                self._log_text(
                    "STM32 UART5 CTRL",
                    text,
                    gps_time_s=_ascii_timestamp_to_seconds(ctrl.timestamp_ms),
                    source_rank=1,
                )
                self._apply_ascii_control(ctrl)
                handled = True
            pose = decode_pose_status(msg)
            if pose:
                if self.on_pose_status:
                    self.on_pose_status(pose)
                self._log_text(
                    "STM32 UART5 POSE",
                    text,
                    gps_time_s=_ascii_timestamp_to_seconds(pose.timestamp_ms),
                    source_rank=1,
                )
                handled = True
            state = decode_state_status(msg)
            if state:
                if self.on_state_status:
                    self.on_state_status(state)
                self._log_text(
                    "STM32 UART5 STATE",
                    text,
                    gps_time_s=_ascii_timestamp_to_seconds(state.timestamp_ms),
                    source_rank=1,
                )
                handled = True
            stack = decode_stack_status(msg)
            if stack:
                if self.on_stack_status:
                    self.on_stack_status(stack)
                self._log_text("STM32 UART5 STACK", text, source_rank=1)
                handled = True
        if not handled:
            if self.on_log:
                self.on_log(text)
            self._log_text("STM32 UART5", text, source_rank=1)

    # ------------------------------------------------------------------ #
    # Frame construction
    # ------------------------------------------------------------------ #
    def _build_nav_frame(self) -> tuple[bytes | None, float]:
        """Build an IM23A navigation frame from the latest model snapshot.

        Returns ``(frame, gps_time_s)``, where *gps_time_s* is the seconds of
        the UTC day of the simulated clock.
        """
        state, timestamp = self._snapshot()
        lat, lon, alt = geo.enu_to_lla(state.east, state.north, state.up, self.origin)
        heading_nav_deg = geo.heading_math_to_nav(state.heading)
        # Per the original notes, heading_nav_deg is a compass heading in
        # [0, 360) (0 = north, 90 = east). Fold it into radians with 0 -> 0,
        # 90 -> pi/2, 180 -> pi, 270 -> -pi/2.
        if heading_nav_deg <= 180.0:
            heading_nav_rad = math.radians(heading_nav_deg)
        else:
            heading_nav_rad = math.radians(heading_nav_deg - 360.0)
        gps_time_s = _seconds_of_day(timestamp)
        status_flags = int(self.config.position_quality) & 0xFF
        frame = build_im23a_nav_frame(
            timestamp=timestamp,
            lat_deg=lat,
            lon_deg=lon,
            alt_m=alt,
            east_vel=state.east_velocity,
            north_vel=state.north_velocity,
            up_vel=state.up_velocity,
            heading_rad=heading_nav_rad,
            pitch_deg=state.pitch_deg,
            roll_deg=state.roll_deg,
            accel_bias=(0.0, 0.0, 0.0),
            gyro_bias=(0.0, 0.0, 0.0),
            temperature_c=state.temperature_c,
            status_flags=status_flags,
        )
        return frame, gps_time_s

    def _build_imu_frame(self) -> tuple[bytes | None, float]:
        """Build an IM23A IMU frame; returns ``(frame, seconds-of-UTC-day)``."""
        state, timestamp = self._snapshot()
        gps_time_s = _seconds_of_day(timestamp)
        frame = build_im23a_imu_frame(
            timestamp=timestamp,
            accel_g=state.body_accel_g,
            gyro_deg_s=state.gyro_deg_s,
        )
        return frame, gps_time_s

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _snapshot(self) -> tuple[DifferentialDriveState, datetime]:
        """Return a copy of the latest model state and the simulated time, atomically."""
        with self._state_lock:
            return self._latest_state.copy(), self._sim_time

    def _handle_control_frame(self, frame: PythonLinkFrame):
        """Decoder callback: adopt the frame's forward/turn as new control targets."""
        with self._state_lock:
            self._target_linear = frame.forward
            self._target_angular = frame.turn
        if self.on_control:
            self.on_control(frame.forward, frame.turn)
        # source_rank=1 matches every other STM32->PY log entry.
        self._log_text(
            "STM32->PY CTRL_FRAME",
            f"forward={frame.forward:.3f} turn={frame.turn:.3f} pwm={frame.steering_pwm}/{frame.throttle_pwm}",
            source_rank=1,
        )

    @staticmethod
    def _sleep_remaining(start: float, period: float):
        """Sleep for whatever is left of *period* since perf_counter value *start*."""
        elapsed = time.perf_counter() - start
        remaining = period - elapsed
        if remaining > 0:
            time.sleep(remaining)

    # ------------------------------------------------------------------ #
    # External control
    # ------------------------------------------------------------------ #
    def update_origin(self, origin_gga: str):
        """Switch to a new GGA origin and restart the simulated clock from its time.

        Empty or unparsable sentences (``geo.parse_origin`` raising
        ValueError) are ignored.
        """
        if not origin_gga:
            return
        try:
            new_origin = geo.parse_origin(origin_gga)
        except ValueError:
            return
        new_time = _initial_timestamp_from_gga(origin_gga)
        # The physics thread does a read-modify-write of _sim_time under this
        # lock; assigning without it could lose this update.
        with self._state_lock:
            self.config.origin_gga = origin_gga
            self.origin = new_origin
            self._sim_time = new_time

    def reset_state(self, east: float, north: float, up: float, heading_deg: float):
        """Reset the model pose and zero the control targets (the clock is kept)."""
        with self._state_lock:
            self.model.reset(east=east, north=north, up=up, heading_deg=heading_deg)
            self._latest_state = self.model.state.copy()
            self._target_linear = 0.0
            self._target_angular = 0.0

    # ------------------------------------------------------------------ #
    # Control and logging helpers
    # ------------------------------------------------------------------ #
    def _apply_ascii_control(self, ctrl: ControlStatus):
        """Adopt the forward/turn values of a UART5 CTRL status as control targets."""
        with self._state_lock:
            self._target_linear = ctrl.forward_mps
            self._target_angular = ctrl.turn_rate
        if self.on_control:
            self.on_control(ctrl.forward_mps, ctrl.turn_rate)

    def _log_ascii(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Log *payload* decoded as UTF-8 text (no-op without a logger or payload)."""
        if not self.run_logger or not payload:
            return
        text = payload.decode("utf-8", errors="replace").strip()
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)

    def _log_binary(
        self,
        prefix: str,
        payload: bytes,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Log *payload* as a hex string (no-op without a logger or payload)."""
        if not self.run_logger or not payload:
            return
        self.run_logger.log(
            prefix, payload.hex(), gps_time_s=gps_time_s, source_rank=source_rank
        )

    def _log_text(
        self,
        prefix: str,
        text: str,
        *,
        gps_time_s: float | None = None,
        source_rank: int = 0,
    ):
        """Log *text* unchanged (no-op without a logger)."""
        if not self.run_logger:
            return
        self.run_logger.log(prefix, text, gps_time_s=gps_time_s, source_rank=source_rank)
|
|
def _ensure_utc_datetime(dt: datetime) -> datetime:
    """Return *dt* expressed in UTC.

    A naive datetime is assumed to already be UTC and is only tagged with the
    UTC tzinfo. An aware datetime is converted.
    """
    is_naive = dt.tzinfo is None
    return dt.replace(tzinfo=timezone.utc) if is_naive else dt.astimezone(timezone.utc)
|
|
|
def _seconds_of_day(dt: datetime) -> float:
    """Return the seconds elapsed since 00:00 UTC on *dt*'s UTC day (microsecond precision)."""
    utc = _ensure_utc_datetime(dt)
    whole_seconds = utc.hour * 3600 + utc.minute * 60 + utc.second
    return whole_seconds + utc.microsecond / 1_000_000.0
|
|
|
def _ascii_timestamp_to_seconds(raw: float | int | None) -> float | None:
|
if raw is None:
|
return None
|
try:
|
value = float(raw)
|
except (TypeError, ValueError):
|
return None
|
text = f"{int(abs(value)):09d}"
|
hh = int(text[0:2])
|
mm = int(text[2:4])
|
ss = int(text[4:6])
|
ms = int(text[6:9])
|
if 0 <= hh < 24 and 0 <= mm < 60 and 0 <= ss < 60:
|
return hh * 3600 + mm * 60 + ss + ms / 1000.0
|
return value / 1000.0
|
|
|
def _initial_timestamp_from_gga(gga: str) -> datetime:
    """Return today's UTC date combined with the UTC time field of a GGA sentence.

    The time field is the second comma-separated field (``hhmmss[.fff...]``).
    Fractional digits beyond microseconds are truncated. If the sentence is
    empty, the field is missing or malformed, or the time is out of range,
    the current UTC time is returned instead.
    """
    parts = (gga or "").split(",")
    if len(parts) > 1 and parts[1]:
        time_str = parts[1]
        try:
            hh = int(time_str[0:2])
            mm = int(time_str[2:4])
            ss = int(time_str[4:6])
            frac_digits = time_str.split(".")[1] if "." in time_str else ""
            if frac_digits and not frac_digits.isdecimal():
                raise ValueError(f"bad fractional seconds: {frac_digits!r}")
            # Parse the fraction from its digits to avoid float rounding
            # (e.g. 0.57 * 1e6 truncating to 569999).
            microsecond = int((frac_digits + "000000")[:6])
            today = datetime.now(timezone.utc).date()
            return datetime(
                today.year, today.month, today.day,
                hh, mm, ss, microsecond,
                tzinfo=timezone.utc,
            )
        except (ValueError, IndexError):
            pass
    return datetime.now(timezone.utc)
|