"""
|
硬件在环 (HITL) 仿真器:生成 $GPRMI/$GPIMU 传感器数据,并与 STM32H7 通过 PythonLink 闭环。
|
"""
|
|
from __future__ import annotations
|
|
import dataclasses
|
import math
|
import threading
|
import time
|
from datetime import datetime, timedelta, timezone
|
from typing import Callable
|
|
import serial
|
|
from . import geo
|
from .dynamics import DifferentialDriveModel, DifferentialDriveState
|
from .protocols import (
|
PythonLinkDecoder,
|
PythonLinkFrame,
|
build_gpimu_sentence,
|
build_gprmi_sentence,
|
)
|
|
|
@dataclasses.dataclass
class HitlConfig:
    """Settings needed to run one HITL simulation session.

    Only the two serial ports and the origin GGA sentence must be supplied;
    every other setting falls back to the default declared below.
    """

    # Serial port wired to STM32 UART2: GPS/IMU sentences go to the board and
    # control frames come back on the same port.
    uart2_port: str
    # Serial port wired to STM32 UART5 (board log output). A falsy value means
    # no log port is opened.
    uart5_port: str | None
    # GGA sentence defining the local ENU origin; its time field also seeds
    # the simulated clock.
    origin_gga: str
    # Initial position in the local ENU frame as (east, north, up).
    initial_enu: tuple[float, float, float] = dataclasses.field(default=(0.0, 0.0, 0.0))
    # Initial heading in degrees, handed to the vehicle model's reset().
    initial_heading_deg: float = dataclasses.field(default=0.0)
    # Baud rate of the UART2 link.
    gps_baudrate: int = dataclasses.field(default=460800)
    # Baud rate of the UART5 log link.
    log_baudrate: int = dataclasses.field(default=921600)
    # Differential-drive track width handed to the vehicle model
    # (presumably metres -- confirm against the model).
    track_width: float = dataclasses.field(default=0.78)
    # Dual-antenna baseline reported in $GPRMI (passed as ``baseline_m``).
    baseline_distance: float = dataclasses.field(default=0.9)
    # Linear speed limit for the model (presumably m/s).
    max_linear_speed: float = dataclasses.field(default=2.0)
    # Angular speed limit in degrees (presumably per second); converted to
    # radians before reaching the model.
    max_angular_speed_deg: float = dataclasses.field(default=140.0)
|
|
|
class SerialEndpoint:
    """Serial-port wrapper shared by several threads.

    Writes and open/close transitions are serialised by ``_lock`` so whole
    sentences from different writer threads never interleave. Reads are not
    locked, so they never wait behind writes. Every I/O method captures the
    current handle in a local first, so a concurrent ``close()`` cannot turn
    an in-flight call into an ``AttributeError`` on ``None``. I/O errors are
    treated as best-effort: reads return ``b""`` and writes are dropped.
    """

    port: str | None
    baudrate: int
    timeout: float
    _serial: serial.Serial | None
    _lock: threading.Lock

    def __init__(self, port: str | None, baudrate: int, timeout: float = 0.0):
        """Store the port settings; the port itself is opened lazily by open().

        Args:
            port: Device name; a falsy value makes every operation a no-op.
            baudrate: Line speed.
            timeout: Read timeout passed to pyserial (0.0 = non-blocking reads).
        """
        self.port = port
        self.baudrate = baudrate
        self.timeout = timeout
        self._serial = None
        self._lock = threading.Lock()

    def open(self):
        """Open the port and flush both buffers.

        This is a no-op when no port is configured or the port is already open.
        Raises whatever ``serial.Serial`` raises when the device cannot be opened.
        """
        if not self.port:
            return
        with self._lock:
            if self._serial is not None and self._serial.is_open:
                return
            # write_timeout=0 asks pyserial for non-blocking writes, so a
            # stalled link can never block the periodic sender threads.
            self._serial = serial.Serial(
                self.port,
                self.baudrate,
                timeout=self.timeout,
                write_timeout=0,
            )
            self._serial.reset_input_buffer()
            self._serial.reset_output_buffer()

    def close(self):
        """Close the port if open; safe to call repeatedly or concurrently with I/O."""
        # Detach under the lock so no write can be mid-flight on this handle.
        with self._lock:
            port, self._serial = self._serial, None
        if port is not None:
            port.close()

    def write(self, data: bytes):
        """Write ``data`` as one locked operation; dropped if the port is closed or errors.

        NOTE: with ``write_timeout=0`` a partial write is not retried here.
        """
        if not data:
            return
        with self._lock:
            port = self._serial
            if port is None:
                return
            try:
                port.write(data)
            except serial.SerialException:
                # Best-effort, like read(): a transient port error must not
                # kill the calling sender thread.
                pass

    def read(self, size: int = 1) -> bytes:
        """Read up to ``size`` bytes; returns b"" when closed or on a serial error."""
        port = self._serial
        if port is None:
            return b""
        try:
            return port.read(size)
        except serial.SerialException:
            return b""

    def readline(self) -> bytes:
        """Read one line (bounded by the read timeout); returns b"" when closed or on error."""
        port = self._serial
        if port is None:
            return b""
        try:
            return port.readline()
        except serial.SerialException:
            return b""
|
|
|
class HitlSimulator:
    """Bridge between the PC-side simulation and the STM32H7 control board.

    ``start()`` launches these worker threads:

    * ``hitl-phys``  -- steps the vehicle model with the latest control targets
      (sleeping ~5 ms between steps) and advances the simulated clock.
    * ``hitl-gprmi`` -- writes a $GPRMI sentence to UART2 every 0.1 s (10 Hz).
    * ``hitl-gpimu`` -- writes a $GPIMU sentence to UART2 every 0.01 s (100 Hz).
    * ``hitl-ctrl``  -- reads UART2 and feeds the bytes to the PythonLink decoder.
    * ``hitl-log``   -- forwards text lines from the UART5 log port to ``on_log``.

    ``on_log(text)`` runs on the log thread. ``on_control(forward, turn)`` runs
    wherever the decoder invokes ``_handle_control_frame`` (presumably inside
    ``PythonLinkDecoder.feed``, i.e. on the ctrl thread -- confirm).
    """

    config: HitlConfig
    origin: geo.Origin
    model: DifferentialDriveModel
    # Guards model, _latest_state, the control targets and _sim_time.
    _state_lock: threading.Lock
    _latest_state: DifferentialDriveState
    _target_linear: float
    _target_angular: float
    _sim_time: datetime
    uart2: SerialEndpoint
    log_uart: SerialEndpoint
    _decoder: PythonLinkDecoder
    _running: threading.Event
    _threads: list[threading.Thread]
    on_control: Callable[[float, float], None] | None
    on_log: Callable[[str], None] | None

    def __init__(self, config: HitlConfig):
        """Build the vehicle model, the serial endpoints and the control decoder.

        No port is opened and no thread is started until start() is called.
        """
        self.config = config
        self.origin = geo.parse_origin(config.origin_gga)
        self.model = DifferentialDriveModel(
            track_width=config.track_width,
            max_linear_speed=config.max_linear_speed,
            max_angular_speed=math.radians(config.max_angular_speed_deg),
        )
        self.model.reset(
            east=config.initial_enu[0],
            north=config.initial_enu[1],
            up=config.initial_enu[2],
            heading_deg=config.initial_heading_deg,
        )
        self._state_lock = threading.Lock()
        self._latest_state = self.model.state.copy()
        self._target_linear = 0.0
        self._target_angular = 0.0
        # Simulated UTC clock, seeded from the time field of the origin GGA.
        self._sim_time = _initial_timestamp_from_gga(config.origin_gga)

        # UART2 reads are non-blocking (timeout=0); the log port blocks up to 0.1 s per readline.
        self.uart2 = SerialEndpoint(config.uart2_port, config.gps_baudrate, timeout=0.0)
        self.log_uart = SerialEndpoint(config.uart5_port, config.log_baudrate, timeout=0.1)

        self._decoder = PythonLinkDecoder(self._handle_control_frame)
        self._running = threading.Event()
        self._threads = []

        self.on_control = None
        self.on_log = None

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def start(self):
        """Open both serial ports and launch the worker threads.

        This is a no-op when already running. If opening a port raises, any
        port that was already opened is closed again before the error
        propagates, so a failed start leaks no handle.
        """
        if self._running.is_set():
            return
        try:
            self.uart2.open()
            self.log_uart.open()
        except Exception:
            self.uart2.close()
            self.log_uart.close()
            raise
        self._running.set()
        self._threads = [
            threading.Thread(target=self._loop_physics, name="hitl-phys", daemon=True),
            threading.Thread(target=self._loop_gprmi, name="hitl-gprmi", daemon=True),
            threading.Thread(target=self._loop_gpimu, name="hitl-gpimu", daemon=True),
            threading.Thread(target=self._loop_control, name="hitl-ctrl", daemon=True),
            threading.Thread(target=self._loop_log, name="hitl-log", daemon=True),
        ]
        for worker in self._threads:
            worker.start()

    def stop(self):
        """Signal the workers to exit, join them (1 s each) and close both ports.

        This is safe to call from a worker thread, e.g. inside an on_log or
        on_control callback. The calling thread is not joined, because joining
        it would raise RuntimeError; it exits on its own once it sees the
        cleared flag.
        """
        if not self._running.is_set():
            return
        self._running.clear()
        current = threading.current_thread()
        for worker in self._threads:
            if worker is not current:
                worker.join(timeout=1.0)
        self._threads.clear()
        self.uart2.close()
        self.log_uart.close()

    # ------------------------------------------------------------------ #
    # Threads
    # ------------------------------------------------------------------ #

    def _loop_physics(self):
        """Step the model with measured wall-clock dt until stopped."""
        last = time.perf_counter()
        while self._running.is_set():
            now = time.perf_counter()
            # Floor dt so the model never receives a zero or negative step.
            dt = max(now - last, 1e-4)
            last = now
            with self._state_lock:
                state = self.model.step(self._target_linear, self._target_angular, dt).copy()
                self._latest_state = state
                # Advance the clock under the same lock so _snapshot() always
                # pairs a state with the timestamp it was produced at.
                self._sim_time += timedelta(seconds=dt)
            time.sleep(0.005)

    def _loop_gprmi(self):
        """Send a $GPRMI sentence on UART2 at 10 Hz."""
        self._loop_periodic(self._build_gprmi_sentence, 0.1)

    def _loop_gpimu(self):
        """Send a $GPIMU sentence on UART2 at 100 Hz."""
        self._loop_periodic(self._build_gpimu_sentence, 0.01)

    def _loop_periodic(self, build: Callable[[], bytes | None], period: float):
        """Write ``build()`` to UART2 once per ``period`` seconds until stopped; empty results are skipped."""
        while self._running.is_set():
            start = time.perf_counter()
            sentence = build()
            if sentence:
                self.uart2.write(sentence)
            self._sleep_remaining(start, period)

    def _loop_control(self):
        """Feed bytes received on UART2 to the PythonLink decoder until stopped."""
        while self._running.is_set():
            data = self.uart2.read(128)
            if data:
                self._decoder.feed(data)
            else:
                # UART2 reads are non-blocking; back off briefly when idle.
                time.sleep(0.002)

    def _loop_log(self):
        """Forward each non-empty, stripped UART5 line to on_log until stopped.

        When no log port is configured there is nothing to read, so the thread
        simply exits.
        """
        if not self.config.uart5_port:
            return
        while self._running.is_set():
            line = self.log_uart.readline()
            if not line:
                time.sleep(0.01)
                continue
            # Undecodable bytes become U+FFFD instead of killing the thread.
            text = line.decode("utf-8", errors="replace").strip()
            if text and self.on_log:
                self.on_log(text)

    # ------------------------------------------------------------------ #
    # Sentence construction
    # ------------------------------------------------------------------ #

    def _build_gprmi_sentence(self) -> bytes | None:
        """Encode the latest pose/velocity/attitude snapshot as a $GPRMI sentence."""
        state, timestamp = self._snapshot()
        lat, lon, alt = geo.enu_to_lla(state.east, state.north, state.up, self.origin)
        # The model's heading is converted to the navigation convention
        # expected in the sentence (see geo.heading_math_to_nav).
        heading_nav = geo.heading_math_to_nav(state.heading)
        return build_gprmi_sentence(
            timestamp=timestamp,
            lat_deg=lat,
            lon_deg=lon,
            alt_m=alt,
            east_vel=state.east_velocity,
            north_vel=state.north_velocity,
            up_vel=state.up_velocity,
            heading_deg=heading_nav,
            pitch_deg=state.pitch_deg,
            roll_deg=state.roll_deg,
            baseline_m=self.config.baseline_distance,
        )

    def _build_gpimu_sentence(self) -> bytes | None:
        """Encode the latest accelerometer/gyro/temperature snapshot as a $GPIMU sentence."""
        state, timestamp = self._snapshot()
        return build_gpimu_sentence(
            timestamp=timestamp,
            accel_g=state.body_accel_g,
            gyro_deg_s=state.gyro_deg_s,
            temperature_c=state.temperature_c,
        )

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _snapshot(self) -> tuple[DifferentialDriveState, datetime]:
        """Return a private copy of the latest state together with its simulated timestamp."""
        with self._state_lock:
            return self._latest_state.copy(), self._sim_time

    def _handle_control_frame(self, frame: PythonLinkFrame):
        """Store a decoded control frame as the model targets and notify on_control."""
        with self._state_lock:
            self._target_linear = frame.forward
            self._target_angular = frame.turn
        # Invoke the callback outside the lock so it may call back into the
        # simulator (e.g. _snapshot()) without deadlocking on the non-reentrant lock.
        callback = self.on_control
        if callback:
            callback(frame.forward, frame.turn)

    @staticmethod
    def _sleep_remaining(start: float, period: float):
        """Sleep for whatever is left of ``period`` since ``start`` (a perf_counter value)."""
        remaining = period - (time.perf_counter() - start)
        if remaining > 0:
            time.sleep(remaining)
|
|
|
def _initial_timestamp_from_gga(gga: str) -> datetime:
|
parts = (gga or "").split(",")
|
if len(parts) > 1 and parts[1]:
|
time_str = parts[1]
|
try:
|
hh = int(time_str[0:2])
|
mm = int(time_str[2:4])
|
ss = int(time_str[4:6])
|
frac = float("0." + time_str.split(".")[1]) if "." in time_str else 0.0
|
today = datetime.now(timezone.utc).date()
|
base = datetime(today.year, today.month, today.day, tzinfo=timezone.utc)
|
return base.replace(hour=hh, minute=mm, second=ss, microsecond=int(frac * 1_000_000))
|
except (ValueError, IndexError):
|
pass
|
return datetime.now(timezone.utc)
|