commit e3be8861b87b27abeae9dfa4bec961335698ebb7 Author: YikaiFu-cart Date: Wed May 6 10:47:14 2026 +0800 initial diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a94acc6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Python 缓存文件 +__pycache__/ +*.py[cod] +*.class +*.pyc +*.cache + +# IDE 配置文件(可选,若不需要提交 IDE 配置) +.idea/ + +*.csv + +ui_settings.json + +# VSCode 配置 +.vscode/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..48547b2 --- /dev/null +++ b/README.md @@ -0,0 +1,205 @@ +# TomatoPick 番茄采摘系统 + +TomatoPick 是一个面向番茄自动采摘实验的 Python 项目,集成了 Tkinter 参数界面、Intel RealSense RGB-D 相机、YOLO pose 目标检测、AUBO 机械臂控制和 AGV 移动平台控制。 + +当前仓库是运行端工程,核心代码集中在 `main.py`、`control.py` 和 `control_core.py`,辅助脚本位于 `tools/`。 + +## 项目结构 + +```text +TomatoPick/ +├── main.py # Tkinter 图形界面入口 +├── control.py # UI 与核心逻辑之间的配置转发/兼容入口 +├── control_core.py # 相机、YOLO、机械臂、AGV 与采摘状态机 +├── ui_settings.json # UI 参数缓存,运行后自动保存/更新 +├── tools/ +│ ├── realsense_record_video.py # RealSense 录制辅助脚本 +│ ├── video_to_rgb_frames.py # 视频抽帧辅助脚本 +│ ├── aubo_joint_position.py # AUBO 关节位置辅助脚本 +│ └── 1.png # UI 相机区域背景图 +└── README.md +``` + +## 功能概览 + +- 通过图形界面配置机械臂 IP、AGV IP、运行时长、YOLO 模型路径、放置位置等参数。 +- 支持完整采摘流程:AGV 巡检、视觉检测、停车、精确定位、机械臂抓取、放置、回 Home、AGV 继续前进。 +- 支持“视觉测试”模式,只启动 RealSense 和 YOLO 推理,不启动机械臂和 AGV,适合先调试模型与相机画面。 +- 将运行日志、终端输出和相机检测画面实时显示在 UI 中。 +- 通过 `ui_settings.json` 保存上一次 UI 参数,方便下次启动恢复。 + +## 核心文件说明 + +### `main.py` + +图形界面入口,直接运行即可打开参数配置界面。 + +主要职责: + +- 构建 Tkinter 主界面,包括参数区、运行日志、相机画面和终端输出。 +- 读取和保存 `ui_settings.json`。 +- 选择 YOLO `.pt` 模型文件和相机背景图。 +- 启动完整采摘流程,内部调用 `control.main()`。 +- 启动视觉测试流程,只初始化 RealSense 与 YOLO。 +- 接收 `control_core.py` 推送的相机帧并叠加检测信息。 + +### `control.py` + +兼容入口和配置桥接层。UI 或外部脚本会先修改这里的模块级参数,然后由 `_sync_config()` 同步到 `control_core.configure()`。 + +常用参数包括: + +| 参数 | 说明 | +| --- | --- | +| `ROBOT_IP` / `ROBOT_PORT` | AUBO 机械臂连接地址 | +| `AGV_IP` / `AGV_PORT` | AGV 连接地址 | +| `AGV_SPEED_FORWARD` | AGV 前进速度 | +| `TOTAL_DURATION` | 
系统总运行时长,单位秒 | +| `AGV_STOP_TIMEOUT` | AGV 停车后等待采摘的超时时间 | +| `YOLO_MODEL_PATH` | YOLO pose 模型路径,默认 `best.pt` | +| `YOLO_DETECT_CONF` | YOLO 推理置信度阈值 | +| `PICK_CONFIDENCE_THRESHOLD` | 进入采摘逻辑的目标置信度阈值 | +| `HOME_JOINTS` | 机械臂 Home 关节位 | +| `place_positions` | 番茄放置位 | +| `SCISSORS_ENABLED` | 是否启用末端剪刀/夹爪动作 | + +### `control_core.py` + +核心业务逻辑,负责设备连接、视觉检测、坐标转换、机械臂动作、AGV 控制和资源释放。 + +主要模块: + +- `PlacementManager`:管理放置位。 +- `DetectedTomato` / `TomatoCandidate`:保存检测目标、关键点、深度和坐标信息。 +- `RobotArmController`:封装机械臂回 Home、移动到预抓取点、抓取点和放置点等动作。 +- `VisionController`:封装 RealSense 采集、YOLO pose 检测、关键点筛选、深度读取、坐标转换和抓取触发。 +- `AgvController`:封装 AGV 获取控制权、前进、停止和清理。 + +核心状态由 3 个 `threading.Event` 协调: + +| 状态 | 说明 | +| --- | --- | +| `running` | 系统总运行开关 | +| `has_tomato` | 视觉线程发现候选番茄后置位,通知 AGV 停车 | +| `picking_done` | 本轮抓取放置完成后置位,允许 AGV 继续前进 | + +简化流程: + +```text +连接机械臂 + -> 连接 AGV 并获取控制权 + -> 初始化 RealSense + -> 加载 YOLO 模型 + -> 启动视觉线程和 AGV 线程 + -> AGV 前进巡检 + -> 视觉发现成熟番茄 + -> AGV 停止 + -> 视觉计算采摘点和抓取角度 + -> 机械臂抓取并放置 + -> 机械臂返回 Home + -> AGV 继续前进 +``` + +## 环境依赖 + +建议在 Windows 环境下运行,并提前安装设备 SDK。 + +主要 Python 依赖: + +```text +opencv-python +numpy +Pillow +ultralytics +pyrealsense2 +pyaubo_sdk +pyaubo_agvc_sdk +``` + +`tkinter`、`threading`、`ctypes`、`signal` 等为 Python 标准库或随 Python 发行版提供。 + +可参考安装方式: + +```bash +pip install opencv-python numpy Pillow ultralytics pyrealsense2 +``` + +`pyaubo_sdk` 和 `pyaubo_agvc_sdk` 通常需要按设备厂商提供的 SDK 包安装。 + +## 启动方式 + +### 1. 启动图形界面 + +推荐使用 UI 启动: + +```bash +python main.py +``` + +界面中可执行: + +- `启动程序`:运行完整采摘流程,会连接机械臂、AGV、RealSense 并加载 YOLO。 +- `视觉测试`:只测试相机采集和 YOLO 检测,不启动机械臂与 AGV。 +- `停止程序`:停止当前任务并释放资源。 +- `保存参数`:将当前 UI 参数写入 `ui_settings.json`。 + +### 2. 
直接启动核心流程 + +如果不需要 UI,也可以直接运行: + +```bash +python control.py +``` + +这种方式会使用 `control.py` 顶部定义的默认参数,不读取 UI 输入。 + +## 模型文件 + +默认模型路径为: + +```text +best.pt +``` + +可以将模型文件放在项目根目录,也可以在 UI 中选择任意 `.pt` 文件。使用绝对路径时,请确认路径在当前机器上存在。 + +当前 `ui_settings.json` 中保存的模型路径和背景图路径可能来自其他电脑或旧目录,首次运行前建议在 UI 中重新选择。 + +## 参数保存与传递 + +UI 参数传递流程: + +```text +main.py + -> 保存到 ui_settings.json + -> 写入 control.py 模块变量 + -> control.py 调用 _sync_config() + -> control_core.configure() + -> control_core.main() +``` + +因此: + +- 从 UI 启动时,以界面中的参数为准。 +- 直接运行 `control.py` 时,以 `control.py` 中写死的默认参数为准。 +- 如果修改了 `control.py` 的默认值,需要重新启动程序才会生效。 + +## 运行前检查 + +完整流程会控制真实硬件,启动前请确认: + +- AUBO 机械臂、AGV、RealSense 相机均已连接,并能被对应 SDK 访问。 +- `ROBOT_IP`、`AGV_IP`、端口号与现场设备一致。 +- `HOME_JOINTS`、手眼标定矩阵 `R_tc` / `T_tc`、放置位 `place_positions` 已按现场设备校准。 +- YOLO 模型路径正确,且模型类别与关键点定义符合当前采摘逻辑。 +- 调试视觉效果时优先使用 `视觉测试`,确认画面和检测结果稳定后再运行完整采摘流程。 + +## 辅助脚本 + +`tools/` 目录提供了一些调试脚本: + +- `realsense_record_video.py`:录制 RealSense 视频。 +- `video_to_rgb_frames.py`:从视频中导出 RGB 图片帧。 +- `aubo_joint_position.py`:辅助读取/查看 AUBO 机械臂关节位置。 + +这些脚本主要用于采集数据、调试相机和校准机械臂位姿。 diff --git a/control.py b/control.py new file mode 100644 index 0000000..2b3ff7f --- /dev/null +++ b/control.py @@ -0,0 +1,243 @@ +import signal +from typing import Any, Sequence + +import control_core as _core + +""" +对外兼容入口。 +这个文件主要做两件事: +1. 暴露和旧版本一致的函数名,方便 UI 或外部脚本继续调用; +2. 
在真正进入核心逻辑前,把这里可修改的配置同步到 `control_core.py`。 +""" + + +# -------------------- 设备连接参数 -------------------- +ROBOT_IP = "192.168.192.100" +ROBOT_PORT = 30004 +M_PI = _core.M_PI + +AGV_IP = "192.168.192.100" +AGV_PORT = 30104 +AGV_SPEED_FORWARD = 0 +AGV_SPEED_STOP = 0.0 +AGV_RUN_DISTANCE = 5.0 + + +# -------------------- 运行时长与检测参数 -------------------- +TOTAL_DURATION = 300 +AGV_STOP_TIMEOUT = 10 +AGV_PICK_SETTLE_DELAY = 1.0 +YOLO_MODEL_PATH = "best.pt" +YOLO_DETECT_CONF = 0.5 +PICK_CONFIDENCE_THRESHOLD = 0.7 + + +# -------------------- 位姿与轨迹参数 -------------------- +# Home 是机械臂的初始/复位关节位。 +HOME_JOINTS = [-1.5247, 1.0899, 2.4671, 1.2761, 1.5113, 0.0001] + +# 视觉算出的目标点会再叠加这两个偏移: +# APPROACH_Y_OFFSET 用于先到预抓取点; +# LIFT_Z_OFFSET 用于修正真正执行抓取时的高度。 +APPROACH_Y_OFFSET = 0.15 +LIFT_Z_OFFSET = 0.035 +PICKUP_X_OFFSET = 0.0 +PICKUP_Y_OFFSET = -0.015 +PICK_ANGLE_ORIENTATION_INDEX = 4 +PICK_ANGLE_SIGN = 1.0 +MAX_PICK_ANGLE_DEG = 30.0 + +# 当前项目只使用一个放置位,仍保留列表结构以兼容 control_core。 +place_positions = [[-0.6283, 0.0354, -0.0889, 2.8798, 0.02, -1.4346]] + + +# -------------------- 夹爪与机械臂运动参数 -------------------- +GRIPPER_CLOSE_IO = 1 +GRIPPER_OPEN_IO = 0 +GRIPPER_ACTION_DELAY = 2.0 +SCISSORS_ENABLED = True + +# 这里沿用底层 SDK 的弧度制。 +ARM_SPEED = 100 * (M_PI / 180) +ARM_ACCEL = 100 * (M_PI / 180) + + +# -------------------- 参数同步清单 -------------------- +_CONFIG_NAMES = ( + "ROBOT_IP", + "ROBOT_PORT", + "AGV_IP", + "AGV_PORT", + "AGV_SPEED_FORWARD", + "AGV_SPEED_STOP", + "AGV_RUN_DISTANCE", + "TOTAL_DURATION", + "AGV_STOP_TIMEOUT", + "AGV_PICK_SETTLE_DELAY", + "YOLO_MODEL_PATH", + "YOLO_DETECT_CONF", + "PICK_CONFIDENCE_THRESHOLD", + "HOME_JOINTS", + "APPROACH_Y_OFFSET", + "LIFT_Z_OFFSET", + "PICKUP_X_OFFSET", + "PICKUP_Y_OFFSET", + "PICK_ANGLE_ORIENTATION_INDEX", + "PICK_ANGLE_SIGN", + "MAX_PICK_ANGLE_DEG", + "GRIPPER_CLOSE_IO", + "GRIPPER_OPEN_IO", + "GRIPPER_ACTION_DELAY", + "SCISSORS_ENABLED", + "ARM_SPEED", + "ARM_ACCEL", + "place_positions", +) + + +# -------------------- 运行状态别名 -------------------- 
def _sync_config() -> None:
    """Copy the tunable module-level settings of this file into ``control_core``.

    Every name listed in ``_CONFIG_NAMES`` is read from this module's globals
    and handed to ``_core.configure`` as one dict, so values changed by the UI
    or by an external script take effect before the core logic runs.

    Two entries are normalised on the way:

    * ``place_positions`` always becomes a list of per-position lists. A single
      flat pose is wrapped into a one-element list; an empty value stays an
      empty list.
    * ``HOME_JOINTS`` is copied into a fresh list.
    """
    config = {name: globals()[name] for name in _CONFIG_NAMES}

    positions = config["place_positions"]
    if not positions:
        # An empty setting must stay empty. Wrapping it would produce [[]],
        # a truthy list holding one zero-length pose, which defeats the
        # `or [DEFAULT_PLACE_POSITION]` fallback in control_core.
        normalized = []
    elif isinstance(positions[0], (list, tuple)):
        normalized = [list(position) for position in positions]
    else:
        # A single flat pose was given instead of a list of poses.
        normalized = [list(positions)]

    config["place_positions"] = normalized
    config["HOME_JOINTS"] = list(config["HOME_JOINTS"])
    _core.configure(config)
target_joints=target_joints, target_pose=target_pose) + + +def control_tool_io(robot_name: str, io_index: int, state: bool): + return _core.control_tool_io(robot_name, io_index, state) + + +def get_next_placement(): + _sync_config() + return _core.get_next_placement() + + +def get_placement_index(): + return _core.get_placement_index() + + +def create_placement_manager(): + _sync_config() + return _core.create_placement_manager() + + +def return_to_home(robot_name: str): + _sync_config() + return _core.return_to_home(robot_name) + + +def control_robot(robot_name: str, pose: Sequence[float], angle_rad: float): + _sync_config() + return _core.control_robot(robot_name, pose, angle_rad) + + +def init_camera(): + _sync_config() + return _core.init_camera() + + +def init_tomato_detector(model_path=None): + _sync_config() + return _core.init_tomato_detector(model_path) + + +def filter_ripe_tomatoes(results, confidence_threshold=None): + _sync_config() + return _core.filter_ripe_tomatoes(results, confidence_threshold) + + +def draw_tomato_tilt_line(color_image, x1, y1, x2, y2, pixel_x): + return _core.draw_tomato_tilt_line(color_image, x1, y1, x2, y2, pixel_x) + + +def fit_red_line(crop_image, x1, y1): + return _core.fit_red_line(crop_image, x1, y1) + + +def vision_detection_thread(pipeline, align, depth_intrinsics, model, robot_name: str): + _sync_config() + return _core.vision_detection_thread(pipeline, align, depth_intrinsics, model, robot_name) + + +def agv_control_thread(agv_client): + _sync_config() + return _core.agv_control_thread(agv_client) + + +def main(): + # control.py 本身只做兼容和配置转发,真正业务流程在 control_core.main()。 + _sync_config() + return _core.main() + + +def signal_handler(sig: int, frame: Any): + _sync_config() + return _core.signal_handler(sig, frame) + + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + main() diff --git a/control_core.py b/control_core.py new file mode 100644 index 
0000000..8110cb7 --- /dev/null +++ b/control_core.py @@ -0,0 +1,1775 @@ +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import math +import os +import sys +import threading +import time + +import cv2 +import numpy as np +import pyrealsense2 as rs +import pyaubo_sdk +from pyaubo_agvc_sdk import RpcClient +from ultralytics import YOLO + +""" +番茄采摘控制核心实现。 + +这个模块只负责内部逻辑拆分,`control.py` 作为兼容层保留给 UI 调用。 +""" + + +M_PI = math.pi + +# 可修改的设备、运行、位姿、夹爪参数统一放在 control.py 顶部。 +# control_core.py 只保留核心内部常量,避免出现第二套可调默认值。 + + +# -------------------- 外部配置缓存 -------------------- +_runtime_config: Dict[str, Any] = {} + + +PICK_ZONE_LEFT_RATIO = 1.0 / 3.0 +PICK_ZONE_RIGHT_RATIO = 2.0 / 3.0 +PICK_KEYPOINT_CONF_THRESHOLD = 0.8 # 关键点置信度低于该值时不参与采摘。 +PICK_CUTPOINT_KEYPOINT_INDEX = 0 # YOLO pose 第 0 个点为采摘点 cutpoint。 +PICK_ENDPOINT_KEYPOINT_INDEX = 1 # YOLO pose 第 1 个点为方向端点 endpoint。 +PICK_KEYPOINT_BOX_MARGIN_RATIO = 0.15 # 关键点允许略超出 bbox 的比例。 +PICK_CUTPOINT_MAX_REL_Y = 0.70 # 采摘点在 bbox 内的最大相对高度,防止误取下方点。 +PICK_KEYPOINT_MIN_DISTANCE_PX = 8.0 # cutpoint 与 endpoint 的最小像素距离。 +PICK_ENDPOINT_Y_TOLERANCE_PX = 4 # endpoint 允许比 cutpoint 略高的像素容差。 +PICK_DEPTH_WINDOW_RADIUS = 5 # 首次稳健取深度的窗口半径,5 表示 11x11。 +PICK_DEPTH_FALLBACK_WINDOW_RADIUS = 10 # 首次深度不足时的扩大窗口半径,10 表示 21x21。 +PICK_DEPTH_MIN_VALID_PIXELS = 5 # 深度窗口内至少需要的有效深度点数量。 +PICK_DEPTH_MIN_M = 0.25 # 有效深度下限,单位 m。 +PICK_DEPTH_MAX_M = 0.6 # 有效深度上限,单位 m。 +PICK_DEPTH_MEDIAN_TOLERANCE_M = 0.035 # 深度中位数滤波容差,单位 m。 + +# R_tc = np.array( +# [ +# [-0.9996, -0.0270, -0.0103], +# [0.0267, -0.7993, 0.0245], +# [-0.0109, 0.0242, 0.7996], +# ], +# dtype=float, +# ) +# T_tc = np.array([0.0, 0.0, 0.0], dtype=float) + +R_tc = np.array( + [ + [-0.9998, -0.0104, 0.0152], + [0.0099, -0.9995, -0.0316], + [0.0155, -0.0315, 0.9994], + ], + dtype=float, +) + +T_tc = np.array([-0.0026, 0.0388, -0.0632], dtype=float) + +POSITION_TOLERANCE = 0.01 +ORIENTATION_TOLERANCE = 0.05 +JOINT_TOLERANCE = 0.02 + + +# 
@dataclass
class PlacementManager:
    """Hand out placement poses one after another, wrapping around at the end.

    The list of poses is fetched through ``_cfg("place_positions")`` on every
    call, so a configuration change is picked up by the next request.
    """

    # Serialises every read and write of next_index.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # Index the next get_next_position() call starts from.
    next_index: int = 0

    def get_next_position(self) -> Tuple[List[float], int]:
        """Return a copy of the next pose together with the index it came from.

        Falls back to ``DEFAULT_PLACE_POSITION`` when the configured list is
        empty. The stored index is reduced modulo the current list length, so
        it stays valid when the list has shrunk since the previous call.
        """
        self.lock.acquire()
        try:
            candidates = _cfg("place_positions")
            if not candidates:
                candidates = [DEFAULT_PLACE_POSITION]
            count = len(candidates)
            chosen = self.next_index % count
            pose = list(candidates[chosen])
            self.next_index = (chosen + 1) % count
            return pose, chosen
        finally:
            self.lock.release()

    def get_current_index(self) -> int:
        """Return the index the next request will start from."""
        self.lock.acquire()
        try:
            return self.next_index
        finally:
            self.lock.release()

    def reset(self) -> None:
        """Restart the cycle from the first placement pose."""
        self.lock.acquire()
        try:
            self.next_index = 0
        finally:
            self.lock.release()
def remaining_seconds() -> int:
    """Return how many whole seconds of the configured run time are left.

    Before the run has started (``start_time`` is not positive) the full
    ``TOTAL_DURATION`` is returned unchanged. Afterwards the remaining time is
    truncated to an int and never drops below zero.
    """
    budget = _cfg("TOTAL_DURATION")
    if start_time > 0:
        left = int(budget - (time.time() - start_time))
        return left if left > 0 else 0
    return budget
"""执行逆解计算。""" + if not isinstance(pose, (list, tuple)) or len(pose) != 6: + print(f"逆解目标位姿格式错误,需要 6 维列表: {pose}") + return None, None + + try: + result_joints, status = robot_rpc_client.getRobotInterface(robot_name).getRobotAlgorithm().inverseKinematics( + list(reference_q), list(pose) + ) + result_joints = list(result_joints) + print("逆解输入位姿:", list(pose)) + print("逆解输出关节角:", result_joints) + if status != 0: + print(f"逆解失败,错误码: {status}") + return None, status + return result_joints, status + except Exception as exc: + print(f"逆解计算失败: {exc}") + return None, None + + +def exampleStartup() -> Optional[str]: + """机械臂上电、启动,并等待进入 Running。""" + try: + robot_names = robot_rpc_client.getRobotNames() + if not robot_names: + print("机械臂启动失败:未找到机器人名称") + return None + + robot_name = robot_names[0] + robot_manage = robot_rpc_client.getRobotInterface(robot_name).getRobotManage() + if robot_manage.poweron() != 0: + print("机械臂上电失败") + return None + print("机械臂上电请求已发送") + + if robot_manage.startup() != 0: + print("机械臂启动失败") + return None + print("机械臂启动请求已发送") + + while True: + robot_mode = robot_rpc_client.getRobotInterface(robot_name).getRobotState().getRobotModeType() + print(f"机械臂当前模式: {robot_mode.name}") + if robot_mode == pyaubo_sdk.RobotModeType.Running: + return robot_name + time.sleep(2) + except Exception as exc: + print(f"机械臂启动异常: {exc}") + return None + + +def quaternion_to_rotation_matrix(q: Sequence[float]) -> np.ndarray: + """四元数转旋转矩阵。""" + w, x, y, z = q + return np.array( + [ + [1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y], + [2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x], + [2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2], + ], + dtype=float, + ) + + +def euler_to_rotation_matrix(roll: float, pitch: float, yaw: float) -> np.ndarray: + """欧拉角转旋转矩阵,旋转顺序为 ZYX。""" + r_x = np.array( + [ + [1, 0, 0], + [0, math.cos(roll), -math.sin(roll)], + [0, math.sin(roll), math.cos(roll)], + ], + dtype=float, 
def check_tcp_pose(robot_name: str, target_pose: Sequence[float]) -> bool:
    """Check whether the current TCP pose has reached ``target_pose``.

    The current pose is read with ``get_robot_end_effector_pose``. Components
    0..2 are compared as a position and components 3..5 as orientation angles
    in radians; any further components are ignored.

    Each angular difference is wrapped into ``[-pi, pi)`` before the norm is
    taken. Without the wrap, two readings of the same orientation on opposite
    sides of the +/-pi boundary (for example 3.14 and -3.14) would be reported
    as an error of about 2*pi and the check would fail.

    NOTE(review): wrapping only removes the 2*pi ambiguity of each angle. Two
    different angle triples that describe the same rotation are still counted
    as different.

    Args:
        robot_name: Name of the robot to query.
        target_pose: Target pose with at least six components.

    Returns:
        True when the position error is below ``POSITION_TOLERANCE`` and the
        orientation error is below ``ORIENTATION_TOLERANCE``. False otherwise,
        including when the pose cannot be read or either pose has fewer than
        six components.
    """
    current_pose = get_robot_end_effector_pose(robot_name)
    if current_pose is None or len(current_pose) < 6 or len(target_pose) < 6:
        print("无法读取 TCP 位姿,或目标位姿格式错误")
        return False

    current = np.asarray(current_pose[:6], dtype=float)
    target = np.asarray(target_pose[:6], dtype=float)

    pos_error = float(np.linalg.norm(current[:3] - target[:3]))

    # Shortest signed angular distance per axis, in [-pi, pi).
    angle_delta = np.mod(current[3:] - target[3:] + math.pi, 2.0 * math.pi) - math.pi
    ori_error = float(np.linalg.norm(angle_delta))

    print(
        f"TCP 位姿误差: 位置={pos_error:.4f} m,姿态={ori_error:.4f} rad | "
        f"阈值: 位置={POSITION_TOLERANCE:.4f} m,姿态={ORIENTATION_TOLERANCE:.4f} rad"
    )
    return pos_error < POSITION_TOLERANCE and ori_error < ORIENTATION_TOLERANCE
def control_tool_io(robot_name: str, io_index: int, state: bool) -> None:
    """Set one digital output on the tool flange and log the value read back.

    Args:
        robot_name: Name of the robot whose IO controller is used.
        io_index: Index of the tool digital output.
        state: Value to write to the output.
    """
    tool_io = robot_rpc_client.getRobotInterface(robot_name).getIoControl()
    tool_io.setToolDigitalOutput(io_index, state)

    # Report what the controller reads back, not the value that was requested.
    if tool_io.getToolDigitalOutput(io_index):
        label = "ON"
    else:
        label = "OFF"
    print(f"TOOL_IO[{io_index}] -> {label}")
enabled: tool IO reset to neutral") + except Exception as exc: + print(f"Failed to reset scissors IO: {exc}") + + def close_scissors_for_pick(self) -> None: + if not _cfg("SCISSORS_ENABLED"): + self.keep_scissors_open() + return + control_tool_io(self.robot_name, _cfg("GRIPPER_OPEN_IO"), False) + control_tool_io(self.robot_name, _cfg("GRIPPER_CLOSE_IO"), True) + time.sleep(_cfg("GRIPPER_ACTION_DELAY")) + control_tool_io(self.robot_name, _cfg("GRIPPER_CLOSE_IO"), False) + + def open_scissors_after_place(self) -> None: + if not _cfg("SCISSORS_ENABLED"): + self.keep_scissors_open() + return + control_tool_io(self.robot_name, _cfg("GRIPPER_CLOSE_IO"), False) + control_tool_io(self.robot_name, _cfg("GRIPPER_OPEN_IO"), True) + time.sleep(_cfg("GRIPPER_ACTION_DELAY")) + control_tool_io(self.robot_name, _cfg("GRIPPER_OPEN_IO"), False) + + def _restart_runtime_machine(self) -> None: + """重置运行态,尽量把控制器拉回可接收运动指令的状态。""" + try: + self.runtime_machine.stop() + except Exception: + pass + time.sleep(MOTION_START_DELAY) + self.runtime_machine.start() + time.sleep(MOTION_START_DELAY) + + def move_joint( + self, + target_joints: Sequence[float], + *, + target_pose: Optional[Sequence[float]] = None, + description: str = "机械臂运动", + ) -> bool: + """按关节目标执行运动。""" + joints = list(target_joints) + last_error: Optional[Exception] = None + + for attempt in range(1, MOTION_COMMAND_RETRIES + 1): + try: + if attempt > 1: + print(f"{description}:重新下发运动指令(第 {attempt}/{MOTION_COMMAND_RETRIES} 次)") + + self._restart_runtime_machine() + self.interface.getMotionControl().moveJoint(joints, _cfg("ARM_SPEED"), _cfg("ARM_ACCEL"), 0, 0) + result = waitArrival(self.interface, target_joints=joints, target_pose=target_pose) + if result == 0: + print(f"{description}完成") + return True + + print(f"{description}失败,错误码: {result}") + except Exception as exc: + last_error = exc + print(f"{description}失败: {exc}") + finally: + try: + self.runtime_machine.stop() + except Exception: + pass + + 
time.sleep(MOTION_START_DELAY) + + if last_error is not None: + print(f"{description}最终失败: {last_error}") + return False + + def move_pose( + self, + target_pose: Sequence[float], + reference_joints: Sequence[float], + *, + description: str, + ) -> Optional[List[float]]: + """按位姿目标执行逆解和关节运动。""" + target_joints, _ = exampleInverseK(self.robot_name, target_pose, reference_joints) + if target_joints is None: + print(f"{description}失败:逆解无结果") + self.return_home(retries=1) + return None + + if not self.move_joint(target_joints, target_pose=target_pose, description=description): + return None + return target_joints + + def return_home(self, retries: int = HOME_RETRY_COUNT) -> bool: + """返回 Home 关节位。""" + current_joints = self.get_joint_positions() + if self.is_at_home(current_joints): + print("Robot arm is already at Home, skipping Home motion") + return True + + attempts = max(1, retries) + for attempt in range(1, attempts + 1): + if attempts == 1: + print("机械臂:返回 Home") + else: + print(f"机械臂:返回 Home(第 {attempt}/{attempts} 次)") + if self.move_joint(_cfg("HOME_JOINTS"), description="返回 Home"): + return True + if attempt < attempts: + time.sleep(HOME_RETRY_DELAY) + print("机械臂:多次尝试后仍未能返回 Home") + return False + + def execute_pick_and_place(self, pose: Sequence[float], angle_rad: float) -> bool: + """ + 执行抓取和放置动作。 + + 保持原有外部接口和基本动作顺序: + 1. 去抓取预备点 + 2. 去抓取点 + 3. 夹爪动作 + 4. 回抓取预备点 + 5. 回 Home + 6. 去放置点 + 7. 释放 + 8. 
再回 Home + """ + current_joints = self.get_joint_positions() + if current_joints is None: + print("获取关节位置失败,终止抓取并尝试返回 Home") + self.return_home() + return False + + target_pose = list(pose) + print(f"机械臂:开始处理目标位姿 {target_pose}") + print(f"机械臂:关键点连线角度 {math.degrees(angle_rad):.2f}°") + + approach_pose = target_pose.copy() + # 先走到抓取预备点,避免直接扎向果实。 + approach_pose[1] += _cfg("APPROACH_Y_OFFSET") + + pickup_pose = target_pose.copy() + # 再把末端抬到可抓取高度,和原代码保持一致使用 Z 偏移。 + pickup_pose[0] += _cfg("PICKUP_X_OFFSET") + pickup_pose[1] += _cfg("PICKUP_Y_OFFSET") + pickup_pose[2] += _cfg("LIFT_Z_OFFSET") + + angle_index = int(_cfg("PICK_ANGLE_ORIENTATION_INDEX")) + if 3 <= angle_index < len(pickup_pose): + signed_angle = float(_cfg("PICK_ANGLE_SIGN")) * angle_rad + pickup_pose[angle_index] += signed_angle + print( + "机械臂:末端姿态旋转 " + f"{math.degrees(signed_angle):.2f}° | 姿态索引: {angle_index}" + ) + else: + print(f"警告:PICK_ANGLE_ORIENTATION_INDEX={_cfg('PICK_ANGLE_ORIENTATION_INDEX')} 无效,跳过姿态旋转") + + approach_joints = self.move_pose(approach_pose, current_joints, description="移动到抓取预备点") + if approach_joints is None: + self.return_home() + return False + + pickup_joints = self.move_pose(pickup_pose, approach_joints, description="旋转末端并移动到抓取点") + if pickup_joints is None: + self.return_home() + return False + + self.close_scissors_for_pick() + + retreat_joints = self.move_pose(approach_pose, pickup_joints, description="采摘后返回预抓取点") + if retreat_joints is None: + self.return_home() + return False + + if not self.return_home(): + return False + + place_pose, place_index = self.placement_manager.get_next_position() + print(f"机械臂:移动到放置位 {place_index} -> {place_pose}") + + reference_joints = self.get_joint_positions() or _cfg("HOME_JOINTS") + place_joints = self.move_pose(place_pose, reference_joints, description=f"移动到放置位 {place_index}") + if place_joints is None: + self.return_home() + return False + + self.open_scissors_after_place() + + if not self.return_home(): + return False + + 
def filter_ripe_tomatoes(results: Any, confidence_threshold: Optional[float] = None) -> List[Any]:
    """Return the detection boxes that count as ripe tomatoes.

    A box is kept when its class id is 0 and its confidence is at least
    the threshold.

    Args:
        results: Object exposing an iterable ``boxes`` attribute; each box
            provides ``conf`` and ``cls``.
        confidence_threshold: Minimum confidence. When None, the value of
            the ``PICK_CONFIDENCE_THRESHOLD`` configuration entry is used.

    Returns:
        The kept boxes, in their original order.
    """
    if confidence_threshold is None:
        min_conf = _cfg("PICK_CONFIDENCE_THRESHOLD")
    else:
        min_conf = confidence_threshold

    def _is_ripe(candidate: Any) -> bool:
        # Convert the confidence first, then test class and threshold.
        score = float(candidate.conf)
        return int(candidate.cls) == 0 and score >= min_conf

    return list(filter(_is_ripe, results.boxes))
def fit_red_line(crop_image: np.ndarray, x1: int, y1: int) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Fit a straight line through the red pixels of a cropped region.

    Red is segmented in HSV using two hue bands (0-10 and 160-180), the
    external contours of the resulting mask are pooled, and
    ``cv2.fitLine`` (L2 distance) is run over all contour points.

    Args:
        crop_image: BGR crop of the target region.
        x1: X offset added to the returned points (crop origin in the
            full frame).
        y1: Y offset added to the returned points.

    Returns:
        Two ``(x, y)`` points where the fitted line meets the left/right
        crop edges (or the top/bottom edges for a vertical line), shifted
        by ``(x1, y1)``; None when the crop is empty, no red contour is
        found, or fewer than two contour points exist.
    """
    if crop_image.size == 0:
        return None

    lower_red1 = np.array([0, 120, 70])
    upper_red1 = np.array([10, 255, 255])
    lower_red2 = np.array([160, 120, 70])
    upper_red2 = np.array([180, 255, 255])

    hsv = cv2.cvtColor(crop_image, cv2.COLOR_BGR2HSV)
    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask = cv2.bitwise_or(mask1, mask2)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # Contours arrive as (N, 1, 2) arrays. Reshape to (M, 2) explicitly:
    # squeeze() would turn a single point into shape (2,), which slipped
    # past the "fewer than two points" guard and broke cv2.fitLine.
    all_points = np.vstack(contours).reshape(-1, 2)
    if all_points.shape[0] < 2:
        return None

    # fitLine returns a 4x1 array; flatten it before converting to floats
    # so no 1-element array is passed to float().
    fitted = np.ravel(cv2.fitLine(all_points, cv2.DIST_L2, 0, 0.01, 0.01))
    vx, vy, x0, y0 = [float(value) for value in fitted]
    if abs(vx) < 1e-6:
        # Vertical line: span the crop from top to bottom at x0.
        p1 = (x1 + int(x0), y1)
        p2 = (x1 + int(x0), y1 + crop_image.shape[0])
        return p1, p2

    left_y = int((-x0 * vy / vx) + y0)
    right_y = int((((crop_image.shape[1] - x0) * vy) / vx) + y0)
    p1 = (x1, y1 + left_y)
    p2 = (x1 + crop_image.shape[1], y1 + right_y)
    return p1, p2
def box_center_in_pick_zone(self, box: Any, frame_width: int) -> bool:
    """Tell whether the horizontal centre of a box lies inside the pick zone.

    Args:
        box: Detection box accepted by ``self.box_xyxy``.
        frame_width: Frame width handed to ``self.get_pick_zone_bounds``.

    Returns:
        True when the truncated centre x is within the zone bounds,
        both ends inclusive.
    """
    left_edge, _top, right_edge, _bottom = self.box_xyxy(box)
    mid_x = int((left_edge + right_edge) / 2)
    zone_min, zone_max = self.get_pick_zone_bounds(frame_width)
    if mid_x < zone_min:
        return False
    return mid_x <= zone_max
def validate_candidate_keypoints(
    self,
    box: Any,
    cutpoint: Tuple[int, int],
    end_point: Optional[Tuple[int, int]],
) -> bool:
    """Sanity-check a candidate's keypoints against its bounding box.

    Checks, in order: the cutpoint lies inside the box padded by a margin;
    the cutpoint is not lower than ``PICK_CUTPOINT_MAX_REL_Y`` of the box
    height; and, when an end point exists, that it lies inside the padded
    box, is at least ``PICK_KEYPOINT_MIN_DISTANCE_PX`` from the cutpoint,
    and is not above the cutpoint beyond ``PICK_ENDPOINT_Y_TOLERANCE_PX``.

    Args:
        box: Detection box accepted by ``self.box_xyxy``.
        cutpoint: Pixel position of the cut keypoint.
        end_point: Pixel position of the second keypoint, or None.

    Returns:
        True when every applicable check passes; each rejection is logged.
    """
    x_min, y_min, x_max, y_max = self.box_xyxy(box)
    width = max(1, x_max - x_min)
    height = max(1, y_max - y_min)
    pad = int(max(width, height) * PICK_KEYPOINT_BOX_MARGIN_RATIO)

    def _inside_padded_box(pt: Tuple[int, int]) -> bool:
        return x_min - pad <= pt[0] <= x_max + pad and y_min - pad <= pt[1] <= y_max + pad

    if not _inside_padded_box(cutpoint):
        print(f"Pick candidate skipped: semantic cutpoint outside bbox @ {cutpoint}")
        return False

    # Vertical position of the cutpoint as a fraction of the box height.
    rel_height = (cutpoint[1] - y_min) / float(height)
    if rel_height > PICK_CUTPOINT_MAX_REL_Y:
        print(
            "Pick candidate skipped: semantic cutpoint too low "
            f"(rel_y={rel_height:.2f}) @ {cutpoint}"
        )
        return False

    # Without a second keypoint there is nothing more to verify.
    if end_point is None:
        return True

    if not _inside_padded_box(end_point):
        print(f"Pick candidate skipped: endpoint outside bbox @ {end_point}")
        return False

    offset = np.array(end_point, dtype=float) - np.array(cutpoint, dtype=float)
    gap_px = float(np.linalg.norm(offset))
    if gap_px < PICK_KEYPOINT_MIN_DISTANCE_PX:
        print(f"Pick candidate skipped: keypoints too close ({gap_px:.1f}px)")
        return False

    if end_point[1] + PICK_ENDPOINT_Y_TOLERANCE_PX < cutpoint[1]:
        print(f"Pick candidate skipped: endpoint appears above cutpoint @ cut={cutpoint}, end={end_point}")
        return False

    return True
def build_pick_candidates(self, result: Any, frame_shape: Tuple[int, int, int]) -> List[TomatoCandidate]:
    """Collect the detections that qualify as pick candidates.

    A detection qualifies when it is class 0, meets the
    ``PICK_CONFIDENCE_THRESHOLD`` configuration value, has its box centre
    inside the pick zone, carries a cutpoint keypoint, and passes
    ``validate_candidate_keypoints``.

    Args:
        result: Single detection result; its ``boxes`` attribute is
            iterated when present.
        frame_shape: Frame shape; index 1 is used as the frame width.

    Returns:
        Qualifying candidates in detection order (empty when ``result``
        has no boxes).
    """
    width = frame_shape[1]
    accepted: List[TomatoCandidate] = []
    detections = getattr(result, "boxes", None)
    if detections is None:
        return accepted

    for idx, det in enumerate(detections):
        score = self.box_confidence(det)
        if int(det.cls) != 0:
            continue
        if score < _cfg("PICK_CONFIDENCE_THRESHOLD"):
            continue
        if not self.box_center_in_pick_zone(det, width):
            continue

        keypoints = self.extract_candidate_keypoints(result, idx, frame_shape)
        cutpoint, end_point = keypoints
        if cutpoint is None:
            print("Pick candidate skipped: missing semantic cutpoint keypoint")
            continue

        if self.validate_candidate_keypoints(det, cutpoint, end_point):
            accepted.append(
                TomatoCandidate(box=det, box_index=idx, cutpoint=cutpoint, end_point=end_point)
            )

    return accepted
def annotate_target(self, frame: np.ndarray, target: DetectedTomato) -> None:
    """Draw the pick target's key markers onto ``frame`` in place.

    The cutpoint gets a cross marker, a filled dot and an outer ring.
    When the target has a two-element ``end_point``, that point, the
    connecting line and an angle label (clamped to stay inside the frame)
    are drawn as well.

    Args:
        frame: Image to draw on.
        target: Target providing ``pixel_x``, ``pixel_y``, ``end_point``
            and ``angle_rad``.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cut_px = (target.pixel_x, target.pixel_y)

    cv2.drawMarker(frame, cut_px, (0, 255, 255), cv2.MARKER_CROSS, 18, 2)
    # Filled inner dot first, then the outer ring.
    for radius, color, thickness in ((6, (0, 255, 255), -1), (11, (0, 80, 255), 2)):
        cv2.circle(frame, cut_px, radius, color, thickness)

    if len(target.end_point) != 2:
        return

    tail_px = tuple(int(v) for v in target.end_point)
    cv2.circle(frame, tail_px, 5, (255, 215, 0), -1)
    cv2.line(frame, cut_px, tail_px, (255, 255, 0), 2)

    label = f"Angle: {math.degrees(target.angle_rad):.2f}deg"
    (label_w, _label_h), baseline = cv2.getTextSize(label, font, 0.5, 2)
    # Keep the label inside the frame with a 6 px border.
    label_x = max(6, min(tail_px[0] + 8, frame.shape[1] - label_w - 6))
    label_y = max(18, min(tail_px[1] + 18, frame.shape[0] - baseline - 6))
    cv2.putText(frame, label, (label_x, label_y), font, 0.5, (0, 255, 0), 2)
def get_robust_depth(
    self,
    depth_frame: Any,
    pixel_x: int,
    pixel_y: int,
    frame_shape: Tuple[int, int, int],
) -> Optional[float]:
    """Estimate a stable depth around a pixel from its neighbourhood.

    Two window radii are tried in order (primary, then fallback). The
    first window with at least ``PICK_DEPTH_MIN_VALID_PIXELS`` samples
    wins: its median is taken, samples further than
    ``PICK_DEPTH_MEDIAN_TOLERANCE_M`` from it are dropped, and the median
    of the remaining samples is returned if enough remain (otherwise the
    first median).

    Args:
        depth_frame: Object exposing ``get_distance(x, y)``.
        pixel_x: Pixel column of interest.
        pixel_y: Pixel row of interest.
        frame_shape: Frame shape forwarded to ``collect_depth_samples``.

    Returns:
        The neighbourhood depth, or None when neither window holds enough
        samples.
    """
    center_depth = float(depth_frame.get_distance(pixel_x, pixel_y))

    for radius in (PICK_DEPTH_WINDOW_RADIUS, PICK_DEPTH_FALLBACK_WINDOW_RADIUS):
        readings = self.collect_depth_samples(depth_frame, pixel_x, pixel_y, frame_shape, radius)
        if len(readings) >= PICK_DEPTH_MIN_VALID_PIXELS:
            values = np.array(readings, dtype=float)
            window_median = float(np.median(values))
            inliers = values[np.abs(values - window_median) <= PICK_DEPTH_MEDIAN_TOLERANCE_M]
            depth = (
                float(np.median(inliers))
                if len(inliers) >= PICK_DEPTH_MIN_VALID_PIXELS
                else window_median
            )

            # Only log when the centre reading was unusable or disagrees.
            center_in_range = PICK_DEPTH_MIN_M <= center_depth <= PICK_DEPTH_MAX_M
            if not center_in_range:
                print(
                    "采摘点中心深度无效,使用邻域稳健深度 "
                    f"{depth:.3f}m @ ({pixel_x}, {pixel_y}), radius={radius}"
                )
            elif abs(center_depth - depth) > PICK_DEPTH_MEDIAN_TOLERANCE_M:
                print(
                    "采摘点中心深度跳变,使用邻域稳健深度 "
                    f"{depth:.3f}m 替代 {center_depth:.3f}m @ ({pixel_x}, {pixel_y})"
                )
            return depth

    print(
        "采摘点邻域深度不足 "
        f"center={center_depth:.3f}m @ ({pixel_x}, {pixel_y})"
    )
    return None
def handle_pick_phase(self, raw_image: np.ndarray, depth_frame: Any) -> np.ndarray:
    """Run precise detection and one pick attempt while the AGV is stopped.

    Flow:
        1. When ``start_agv_stop_window()`` is truthy, publish a frozen
           preview, sleep for ``AGV_PICK_SETTLE_DELAY`` and return.
           (Presumably truthy only on the first call of a stop window -
           confirm against its definition.)
        2. When the stop has lasted longer than ``AGV_STOP_TIMEOUT``, end
           the pick cycle.
        3. Otherwise detect again, read the TCP pose, and try candidates
           in order until one yields a usable target; that target gets
           exactly one pick-and-place attempt.

    Args:
        raw_image: Current colour frame.
        depth_frame: Depth frame used to build the target.

    Returns:
        The frame to display (annotated where detection ran).
    """
    if start_agv_stop_window():
        # The image may still shake right after the AGV halts, so leave a
        # settle window before the precise detection and pick.
        pick_settle_delay = _cfg("AGV_PICK_SETTLE_DELAY")
        print(f"AGV 已停止,等待 {pick_settle_delay:.1f} 秒后开始精确检测...")
        frozen_image, frozen_targets = self.run_detection(raw_image)
        safe_ui_update(
            frozen_image,
            {
                "freeze": True,
                "stage": "pick_wait",
                "detections": self.serialize_detection_boxes(frozen_targets),
            },
        )
        time.sleep(pick_settle_delay)
        return frozen_image

    agv_stop_timeout = _cfg("AGV_STOP_TIMEOUT")
    if agv_stop_elapsed() > agv_stop_timeout:
        print(f"AGV 停止超时 {agv_stop_timeout} 秒,允许继续前进")
        clear_pick_cycle()
        timeout_image = raw_image.copy()
        self.annotate_pick_zone(timeout_image)
        return timeout_image

    annotated_image, ripe_tomatoes = self.run_detection(raw_image)
    if not ripe_tomatoes:
        print("视觉检测:停止窗口内未找到成熟番茄,允许 AGV 继续前进")
        clear_pick_cycle()
        return annotated_image

    tcp_pose = self.arm_controller.get_tcp_pose()
    if tcp_pose is None:
        print("获取末端位姿失败,跳过本次抓取并返回 Home")
        self.arm_controller.return_home()
        clear_pick_cycle()
        return annotated_image

    for index, candidate in enumerate(ripe_tomatoes, start=1):
        try:
            # Try each candidate in turn; the first usable target ends
            # this round whether or not the pick succeeds.
            target = self.build_detected_tomato(candidate, depth_frame, annotated_image, tcp_pose)
            if target is None:
                continue

            detections = self.serialize_detection_boxes([candidate])
            if detections:
                detections[0]["pick_xyz"] = [float(value) for value in target.point_cam[:3]]
                detections[0]["pick_xyz_frame"] = "d405_camera"
                detections[0]["angle_deg"] = float(math.degrees(target.angle_rad))

            safe_ui_update(
                annotated_image,
                {
                    "freeze": True,
                    "stage": "pick_target",
                    "detections": detections,
                },
            )
            self.log_target(index, target)
            if self.arm_controller.execute_pick_and_place(target.robot_pose, target.angle_rad):
                mark_pick_complete()
            else:
                clear_pick_cycle()
            break
        except Exception as exc:
            print(f"处理目标时出错: {exc}")
            self.arm_controller.return_home()
            # Do NOT end the pick cycle here. The loop goes on to the next
            # candidate, and clearing the cycle now would let the AGV
            # resume while the arm may still pick. If no candidate works,
            # the for-else below clears the cycle exactly once.
            continue
    else:
        print("视觉检测:候选目标均不可用,允许 AGV 继续前进")
        clear_pick_cycle()

    return annotated_image
@dataclass
class AgvController:
    """AGV subsystem: turns the shared pick-state flags into speed commands."""

    # RPC client used for every AGV command.
    agv_client: Any

    def reacquire_priority(self) -> None:
        """Release control priority, wait briefly, then request it again."""
        client = self.agv_client
        client.release_priority()
        time.sleep(0.5)
        client.set_priority("admin", _cfg("AGV_IP"))

    def send_speed(self, linear_speed: float, description: str) -> None:
        """Send a speed command and re-acquire priority if it is rejected.

        Args:
            linear_speed: First argument of ``set_control_speed``.
            description: Label used in the failure log line.
        """
        status = self.agv_client.set_control_speed(linear_speed, _cfg("AGV_SPEED_STOP"))
        if status == AGV_SUCCESS_CODE:
            return
        print(f"AGV 控制:{description}指令失败(码 {status}),重新获取控制权...")
        self.reacquire_priority()

    def stop_motion(self) -> None:
        """Command the configured stop speed for both speed arguments."""
        self.agv_client.set_control_speed(_cfg("AGV_SPEED_STOP"), _cfg("AGV_SPEED_STOP"))

    def shutdown(self) -> None:
        """Issue a final stop command when the control thread ends."""
        try:
            self.stop_motion()
        except Exception as exc:
            print(f"AGV 停止失败: {exc}")
        print("AGV 控制线程停止 | AGV 已停止")

    def run(self) -> None:
        """Main loop of the AGV control thread.

        Runs until ``running`` is cleared or the runtime expires. Each
        cycle sends "stop" while a tomato is flagged and picking is not
        done, otherwise "forward".
        """
        print("AGV 控制线程启动 | 等待视觉检测信号...")
        while running.is_set():
            try:
                if runtime_expired():
                    print("AGV 控制:总时长已到,停止 AGV...")
                    self.stop_motion()
                    running.clear()
                    break

                # This thread does no vision or picking itself; it only
                # reads the state flags to choose between stop and go.
                if has_tomato.is_set() and not picking_done.is_set():
                    speed_key, label = "AGV_SPEED_STOP", "停止"
                else:
                    speed_key, label = "AGV_SPEED_FORWARD", "前进"
                self.send_speed(_cfg(speed_key), label)
            except Exception as exc:
                print(f"AGV 控制线程出错: {exc}")
                time.sleep(0.5)
            else:
                time.sleep(0.1)

        self.shutdown()
def connect_robot_arm() -> RobotArmController:
    """Connect to the arm, log in, start it and return its controller.

    Returns:
        The arm controller obtained from ``get_arm_controller`` for the
        started robot.

    Raises:
        RuntimeError: When the connection, the login or the startup step
            does not succeed.
    """

    def _ensure(ok: bool, failure: str) -> None:
        # Abort the bring-up sequence at the first failed step.
        if not ok:
            raise RuntimeError(failure)

    robot_rpc_client.connect(_cfg("ROBOT_IP"), _cfg("ROBOT_PORT"))
    robot_rpc_client.setRequestTimeout(1000)
    _ensure(robot_rpc_client.hasConnected(), "机械臂连接失败")
    print("机械臂连接成功")

    # NOTE(review): the login credentials are hard-coded here.
    robot_rpc_client.login("rob1", "123456")
    _ensure(robot_rpc_client.hasLogined(), "机械臂登录失败")
    print("机械臂登录成功")

    robot_name = exampleStartup()
    _ensure(bool(robot_name), "机械臂启动失败")
    print(f"机械臂 {robot_name} 启动成功")

    return get_arm_controller(robot_name)
def release_agv(agv_client: Optional[Any]) -> None:
    """Release AGV control priority and disconnect, best effort.

    Both steps are always attempted; a failure in one is logged and does
    not prevent the other.

    Args:
        agv_client: AGV RPC client, or None when no connection was made
            (in which case nothing happens).
    """
    if agv_client is None:
        return

    teardown_steps = (
        ("release_priority", "释放 AGV 控制权失败"),
        ("disconnect", "断开 AGV 失败"),
    )
    for method_name, failure_prefix in teardown_steps:
        try:
            getattr(agv_client, method_name)()
        except Exception as err:
            print(f"{failure_prefix}: {err}")
初始化机械臂...") + arm_controller = connect_robot_arm() + if not _cfg("SCISSORS_ENABLED"): + arm_controller.keep_scissors_open() + else: + arm_controller.reset_scissors_outputs() + + print("\n1.1 机械臂回到初始位置...") + # 机械臂刚进入 Running 后,控制器有时还没完全准备好接收第一条运动指令。 + # 这里把“回 Home”作为尽力而为的预处理,失败时只告警,不阻塞整机启动。 + if arm_controller.is_at_home(): + print("Robot arm already at Home during startup, continuing initialization") + else: + time.sleep(STARTUP_HOME_DELAY) + if not arm_controller.return_home(): + print("警告:机械臂启动后返回 Home 失败,继续启动系统,后续流程中再尝试回 Home") + + print("\n2. 初始化 AGV...") + agv_client = connect_agv() + + print("\n3. 初始化视觉...") + pipeline, align, depth_intrinsics = init_camera() + model = init_tomato_detector() + + print("\n4. 启动核心线程...") + start_runtime_state() + _current_arm_controller = arm_controller + _current_agv_controller = AgvController(agv_client) + + # 主线程只负责保活和兜底清理; + # 真正并发执行的是视觉线程和 AGV 控制线程。 + vision_thread = threading.Thread( + target=vision_detection_thread, + args=(pipeline, align, depth_intrinsics, model, arm_controller.robot_name), + daemon=True, + name="VisionThread", + ) + agv_thread = threading.Thread( + target=agv_control_thread, + args=(agv_client,), + daemon=True, + name="AGVThread", + ) + + vision_thread.start() + agv_thread.start() + print(f"所有线程启动完成 | 总运行时长:{_cfg('TOTAL_DURATION')} 秒") + + # 主线程不参与业务决策,只等待 `running` 被任一线程或异常路径清掉。 + while running.is_set(): + time.sleep(1) + + except KeyboardInterrupt: + print("\n用户手动中断程序...") + running.clear() + if arm_controller is not None: + arm_controller.return_home() + except Exception as exc: + print(f"系统初始化或运行失败: {exc}") + running.clear() + if arm_controller is not None: + arm_controller.return_home() + finally: + print("\n释放系统资源...") + running.clear() + + if vision_thread is not None and vision_thread.is_alive(): + vision_thread.join(timeout=2.0) + elif pipeline is not None: + try: + pipeline.stop() + except Exception: + pass + + if agv_thread is not None and agv_thread.is_alive(): + 
def signal_handler(sig: int, frame: Any) -> None:
    """Handle a termination signal: stop the workers, try to park the arm, exit.

    Clears ``running`` so the worker loops end, then makes a best-effort
    attempt to return the arm to Home. It uses the current arm controller
    when one exists, otherwise a temporary controller for the first robot
    name reported by the RPC client.

    Args:
        sig: Signal number (only logged).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: Always, via ``sys.exit(0)`` after a one-second pause.
    """
    _ = frame
    print(f"\n接收到信号 {sig},正在终止程序并返回 Home...")
    running.clear()

    try:
        if _current_arm_controller is not None:
            _current_arm_controller.return_home()
        else:
            robot_names = robot_rpc_client.getRobotNames()
            if robot_names:
                RobotArmController(robot_names[0], robot_rpc_client, placement_manager).return_home()
    except Exception as exc:
        # Still best effort: shutdown must go on. But report the failure
        # so an arm left away from Home is not silently ignored.
        print(f"信号处理:返回 Home 失败: {exc}")

    time.sleep(1)
    sys.exit(0)
# ctypes layout of the Windows console screen-buffer info structure.
# When the program runs on Windows, it is used to read the console content
# written by the background thread, so that content can be forwarded to
# the Tkinter log panel.
# The field order and types define the binary layout and must not change.
class CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure):
    _fields_ = [
        ("dwSize", wintypes._COORD),
        ("dwCursorPosition", wintypes._COORD),
        ("wAttributes", wintypes.WORD),
        ("srWindow", wintypes.SMALL_RECT),
        ("dwMaximumWindowSize", wintypes._COORD)
    ]
self.vision_test_stop_event = threading.Event() + self.vision_test_pipeline = None + self.show_background_only = True + self.scissors_enabled = tk.BooleanVar(value=True) + + # 日志队列 + self.log_queue = queue.Queue(maxsize=500) + self.log_line_count = 0 + self.terminal_queue = queue.Queue(maxsize=1000) + self.terminal_line_count = 0 + self.log_records = [] + self.terminal_message_count = 0 + self.collapsible_sections = {} + + # 相机画面变量 + self.camera_frame = None + self.camera_image = None + self.camera_bg_image = None # 保存背景图引用(关键:防止回收) + self.raw_bg_image = None # 新增:保存原始1.png,用于后续尺寸调整 + + # 创建界面 + self.create_widgets() + + # 初始化参数 + self.init_parameters() + + # 加载相机背景图(延迟100ms,确保窗口渲染完成后再加载) + self.root.after(100, self.load_camera_background) + + # 启动日志处理线程 + self.root.after(200, self.process_log_queue) + + # 信号处理 + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + + def create_widgets(self): + """构建整个界面布局。""" + # 设置主窗口背景 + self.set_background_image() + + main_frame = ttk.Frame(self.root, padding=(16, 12, 16, 16), style="App.TFrame") + main_frame.pack(fill=tk.BOTH, expand=True) + + header_card = tk.Frame( + main_frame, + bg=self.colors["panel"], + bd=0, + highlightthickness=1, + highlightbackground=self.colors["border"] + ) + header_card.pack(fill=tk.X, pady=(0, 10)) + header_top = tk.Frame(header_card, bg=self.colors["panel"]) + header_top.pack(fill=tk.X, padx=18, pady=10) + self.status_badge = tk.Label( + header_top, + text="系统就绪", + font=self.fonts["button"], + bg=self.colors["accent_soft"], + fg=self.colors["accent"], + padx=14, + pady=6 + ) + self.status_badge.pack(side=tk.RIGHT) + title_wrap = tk.Frame(header_top, bg=self.colors["panel"]) + title_wrap.pack(side=tk.LEFT, fill=tk.X, expand=True) + self.title_label = tk.Label( + title_wrap, + text="番茄采摘系统", + font=self.fonts["title"], + bg=self.colors["panel"], + fg=self.colors["text"] + ) + self.title_label.pack(side=tk.LEFT) + self.subtitle_label = tk.Label( 
+ title_wrap, + text="参数配置 · 日志 · 相机 · 终端", + font=self.fonts["subtitle"], + bg=self.colors["panel"], + fg=self.colors["muted"] + ) + self.subtitle_label.pack(side=tk.LEFT, padx=(14, 0), pady=(4, 0)) + + body_frame = ttk.Frame(main_frame, style="App.TFrame") + body_frame.pack(fill=tk.BOTH, expand=True) + + # 三栏布局(侧边配置 + 中间工作区 + 终端) + paned_window = ttk.PanedWindow(body_frame, orient=tk.HORIZONTAL) + paned_window.pack(fill=tk.BOTH, expand=True) + self.paned_window = paned_window + + # 左侧参数区 + params_frame = ttk.Frame(paned_window, width=470, style="App.TFrame") + paned_window.add(params_frame, weight=2) + + # 中间日志和相机区 + center_frame = ttk.Frame(paned_window, style="App.TFrame") + paned_window.add(center_frame, weight=3) + + # 最右侧终端区 + terminal_frame = ttk.Frame(paned_window, width=280, style="App.TFrame") + paned_window.add(terminal_frame, weight=1) + + # 参数区布局:使用可滚动容器,避免窗口高度不足时底部控件被裁切 + params_canvas = tk.Canvas( + params_frame, + highlightthickness=0, + bd=0, + relief="flat", + bg=self.colors["bg"] + ) + params_scrollbar = ttk.Scrollbar(params_frame, orient=tk.VERTICAL, command=params_canvas.yview) + params_canvas.configure(yscrollcommand=params_scrollbar.set) + params_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + params_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + self.params_canvas = params_canvas + self.params_scrollbar = params_scrollbar + + params_inner = ttk.Frame(params_canvas, style="App.TFrame") + self.params_canvas_window = params_canvas.create_window((0, 0), window=params_inner, anchor="nw") + params_inner.bind( + "", + lambda event: params_canvas.configure(scrollregion=params_canvas.bbox("all")) + ) + params_canvas.bind( + "", + lambda event: params_canvas.itemconfigure(self.params_canvas_window, width=event.width) + ) + + param_card = ttk.Frame(params_inner, style="Transparent.TFrame") + param_card.pack(fill=tk.BOTH, expand=True, pady=0, padx=(4, 8)) + param_grid = ttk.Frame(param_card, style="Transparent.TFrame") + 
param_grid.pack(fill=tk.BOTH, expand=True) + param_grid.columnconfigure(0, weight=1) + param_grid.columnconfigure(1, weight=1) + + # 1. 机械臂配置 + robot_group = ttk.LabelFrame(param_grid, text="机械臂配置", padding="6", style="SidebarCard.TLabelframe") + robot_group.grid(row=0, column=0, columnspan=2, sticky=tk.W+tk.E, pady=4, padx=4) + robot_group.columnconfigure(1, weight=1) + robot_group.columnconfigure(3, weight=1) + ttk.Label(robot_group, text="机械臂IP:").grid(row=0, column=0, sticky=tk.W, pady=2) + self.robot_ip = ttk.Entry(robot_group, width=16) + self.robot_ip.grid(row=0, column=1, sticky=tk.EW, pady=2, padx=4) + ttk.Label(robot_group, text="端口:").grid(row=0, column=2, sticky=tk.W, pady=2) + self.robot_port = ttk.Entry(robot_group, width=8) + self.robot_port.grid(row=0, column=3, sticky=tk.EW, pady=2, padx=4) + + # 2. AGV配置 + agv_group = ttk.LabelFrame(param_grid, text="AGV配置", padding="6", style="SidebarCard.TLabelframe") + agv_group.grid(row=1, column=0, columnspan=2, sticky=tk.W+tk.E, pady=4, padx=4) + agv_group.columnconfigure(1, weight=1) + agv_group.columnconfigure(3, weight=1) + ttk.Label(agv_group, text="AGV IP:").grid(row=0, column=0, sticky=tk.W, pady=2) + self.agv_ip = ttk.Entry(agv_group, width=16) + self.agv_ip.grid(row=0, column=1, sticky=tk.EW, pady=2, padx=4) + ttk.Label(agv_group, text="端口:").grid(row=0, column=2, sticky=tk.W, pady=2) + self.agv_port = ttk.Entry(agv_group, width=8) + self.agv_port.grid(row=0, column=3, sticky=tk.EW, pady=2, padx=4) + ttk.Label(agv_group, text="前进(m/s):").grid(row=1, column=0, sticky=tk.W, pady=2) + self.agv_speed_forward = ttk.Entry(agv_group, width=8) + self.agv_speed_forward.grid(row=1, column=1, sticky=tk.EW, pady=2, padx=4) + ttk.Label(agv_group, text="停止(m/s):").grid(row=1, column=2, sticky=tk.W, pady=2) + self.agv_speed_stop = ttk.Entry(agv_group, width=8) + self.agv_speed_stop.grid(row=1, column=3, sticky=tk.EW, pady=2, padx=4) + + # 3. 
运行参数 + run_group = ttk.LabelFrame(param_grid, text="运行参数", padding="6", style="SidebarCard.TLabelframe") + run_group.grid(row=2, column=0, columnspan=2, sticky=tk.W+tk.E, pady=4, padx=4) + run_group.columnconfigure(1, weight=1) + run_group.columnconfigure(3, weight=1) + ttk.Label(run_group, text="总时长(s):").grid(row=0, column=0, sticky=tk.W, pady=2) + self.total_duration = ttk.Entry(run_group, width=8) + self.total_duration.grid(row=0, column=1, sticky=tk.EW, pady=2, padx=4) + ttk.Label(run_group, text="停止超时(s):").grid(row=0, column=2, sticky=tk.W, pady=2) + self.agv_stop_timeout = ttk.Entry(run_group, width=8) + self.agv_stop_timeout.grid(row=0, column=3, sticky=tk.EW, pady=2, padx=4) + # YOLO模型路径 + model_frame = ttk.Frame(run_group) + model_frame.grid(row=1, column=0, columnspan=4, sticky=tk.W+tk.E, pady=2) + ttk.Label(model_frame, text="YOLO模型:").pack(side=tk.LEFT, pady=2) + self.yolo_model_path = ttk.Entry(model_frame) + self.yolo_model_path.pack(side=tk.LEFT, pady=2, padx=4, fill=tk.X, expand=True) + self.browse_btn = ttk.Button(model_frame, text="浏览", command=self.browse_model, width=5) + self.browse_btn.pack(side=tk.LEFT, padx=(3, 0), pady=2) + bg_frame = ttk.Frame(run_group) + bg_frame.grid(row=2, column=0, columnspan=4, sticky=tk.W+tk.E, pady=2) + ttk.Label(bg_frame, text="背景图:").pack(side=tk.LEFT, pady=2) + self.camera_bg_path = ttk.Entry(bg_frame) + self.camera_bg_path.pack(side=tk.LEFT, pady=2, padx=4, fill=tk.X, expand=True) + self.browse_bg_btn = ttk.Button(bg_frame, text="浏览", command=self.browse_camera_background, width=5) + self.browse_bg_btn.pack(side=tk.LEFT, padx=(3, 0), pady=2) + + # 4. 
放置位置 + place_group = ttk.LabelFrame(param_grid, text="放置位置", padding="6", style="SidebarCard.TLabelframe") + place_group.grid(row=3, column=0, columnspan=2, sticky=tk.W+tk.E, pady=4, padx=4) + # 位置参数输入 + place_params_frame = ttk.Frame(place_group) + place_params_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 2)) + place_params_frame.columnconfigure(1, weight=1) + place_params_frame.columnconfigure(3, weight=1) + self.place_entries = [] + left_labels = ["X(m):", "Y(m):", "Z(m):"] + right_labels = ["Roll:", "Pitch:", "Yaw:"] + for i, label in enumerate(left_labels): + ttk.Label(place_params_frame, text=label).grid(row=i, column=0, sticky=tk.W, pady=2, padx=(2, 4)) + entry = ttk.Entry(place_params_frame, width=9) + entry.grid(row=i, column=1, sticky=tk.EW, pady=2, padx=2) + self.place_entries.append(entry) + for i, label in enumerate(right_labels): + ttk.Label(place_params_frame, text=label).grid(row=i, column=2, sticky=tk.W, pady=2, padx=(8, 4)) + entry = ttk.Entry(place_params_frame, width=9) + entry.grid(row=i, column=3, sticky=tk.EW, pady=2, padx=2) + self.place_entries.append(entry) + self.save_place_btn = ttk.Button(place_group, text="保存位置", command=self.save_place_position) + self.save_place_btn.pack(pady=4) + + # 5. 
控制按钮 + control_frame = ttk.LabelFrame(param_grid, text="系统控制", padding="6", style="SidebarCard.TLabelframe") + control_frame.grid(row=4, column=0, columnspan=2, sticky=tk.W+tk.E, pady=4, padx=4) + btn_frame = ttk.Frame(control_frame) + btn_frame.pack(fill=tk.X, pady=(3, 1)) + btn_frame.columnconfigure(0, weight=1) + btn_frame.columnconfigure(1, weight=1) + btn_frame.columnconfigure(2, weight=1) + btn_frame.columnconfigure(3, weight=1) + self.start_btn = ttk.Button(btn_frame, text="启动程序", command=self.start_program, style="Primary.TButton") + self.start_btn.grid(row=0, column=0, padx=3, pady=3, sticky=tk.EW) + self.vision_test_btn = ttk.Button(btn_frame, text="视觉测试", command=self.start_vision_test, style="Secondary.TButton") + self.vision_test_btn.grid(row=0, column=1, padx=3, pady=3, sticky=tk.EW) + self.stop_btn = ttk.Button(btn_frame, text="停止程序", command=self.stop_program, state=tk.DISABLED, style="Danger.TButton") + self.stop_btn.grid(row=0, column=2, padx=3, pady=3, sticky=tk.EW) + self.save_btn = ttk.Button(btn_frame, text="保存参数", command=self.save_parameters, style="Secondary.TButton") + self.save_btn.grid(row=0, column=3, padx=3, pady=3, sticky=tk.EW) + self.scissors_enabled_check = ttk.Checkbutton( + control_frame, + text="启用末端剪刀", + variable=self.scissors_enabled, + ) + self.scissors_enabled_check.pack(anchor=tk.W, padx=5, pady=(3, 1)) + + # 中间区域布局 + # 日志区域(上方) + log_frame = ttk.LabelFrame(center_frame, text="运行日志", padding="8", style="Card.TLabelframe") + log_frame.pack(fill=tk.X, expand=False, pady=(4, 8), padx=5) + # 日志控制栏 + log_control = ttk.Frame(log_frame) + log_control.pack(fill=tk.X, pady=(2, 4)) + ttk.Label(log_control, text="日志显示:").pack(side=tk.LEFT, padx=5) + self.log_level = tk.StringVar(value="全部") + log_level_combo = ttk.Combobox(log_control, textvariable=self.log_level, width=10, state="readonly") + log_level_combo['values'] = ("全部", "信息", "警告", "错误") + log_level_combo.pack(side=tk.LEFT, padx=5) + log_level_combo.bind("<>", 
self.on_log_level_changed) + ttk.Button(log_control, text="清空日志", command=self.clear_log).pack(side=tk.RIGHT, padx=5) + # 日志文本框 + self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD, height=4, bd=0, relief="flat", highlightthickness=0) + self.log_text.pack(fill=tk.BOTH, expand=True, pady=(0, 2)) + self.log_text.config(state=tk.DISABLED, bg="#F0F0F0", fg="#333333", insertbackground="black") + + # 相机画面区域(下方) + camera_card = ttk.LabelFrame(center_frame, text="相机画面", padding="8", style="Card.TLabelframe") + camera_card.pack(fill=tk.BOTH, expand=True, pady=(0, 5), padx=5) + # 关键修改1:相机显示框架取消默认背景色 + self.camera_display_frame = ttk.Frame(camera_card) + self.camera_display_frame.pack(fill=tk.BOTH, expand=True) + # 绑定窗口大小变化事件(窗口拉伸时,图片自动适配) + self.camera_display_frame.bind("", self.on_camera_frame_resize) + + # 关键修改2:相机标签不继承全局背景色,且填充整个框架 + self.camera_label = ttk.Label(self.camera_display_frame) + self.camera_label.pack(fill=tk.BOTH, expand=True) + + # 终端输出区域(最右侧) + terminal_card = ttk.LabelFrame(terminal_frame, text="终端输出", padding="8", style="TerminalCard.TLabelframe") + terminal_card.pack(fill=tk.BOTH, expand=True, pady=5, padx=5) + terminal_control = ttk.Frame(terminal_card) + terminal_control.pack(fill=tk.X, pady=5) + ttk.Label(terminal_control, text="输出类型: 标准输出 / 标准错误").pack(side=tk.LEFT, padx=5) + ttk.Button(terminal_control, text="清空终端", command=self.clear_terminal).pack(side=tk.RIGHT, padx=5) + self.terminal_text = scrolledtext.ScrolledText( + terminal_card, + wrap=tk.WORD, + font=self.fonts["mono"], + bd=0, + relief="flat", + highlightthickness=0, + bg="#111111", + fg="#D8F3DC", + insertbackground="#D8F3DC" + ) + self.terminal_text.pack(fill=tk.BOTH, expand=True, pady=5) + self.terminal_text.config(state=tk.DISABLED) + + # 设置样式(核心修改:取消全局标签背景色) + self.setup_styles() + self.bind_mousewheel_to_params() + self.root.after(100, self.set_initial_layout) + self.root.after(120, self.apply_font_scale) + + + def init_fonts(self): + """创建一组可随窗口缩放的字体对象。""" + 
self.font_specs = { + "default": {"family": "微软雅黑", "size": 9}, + "title": {"family": "微软雅黑", "size": 18, "weight": "bold"}, + "subtitle": {"family": "微软雅黑", "size": 10}, + "section": {"family": "微软雅黑", "size": 9, "weight": "bold"}, + "button": {"family": "微软雅黑", "size": 9, "weight": "bold"}, + "button_normal": {"family": "微软雅黑", "size": 9}, + "mono": {"family": "微软雅黑", "size": 9}, + } + self.font_min_sizes = { + "default": 7, + "title": 14, + "subtitle": 8, + "section": 8, + "button": 8, + "button_normal": 8, + "mono": 8, + } + self.fonts = {} + for name, spec in self.font_specs.items(): + self.fonts[name] = tkfont.Font( + self.root, + family=spec["family"], + size=spec["size"], + weight=spec.get("weight", "normal"), + ) + self.root.option_add("*Font", self.fonts["default"]) + self.root.bind("", self.on_root_resize) + + def setup_styles(self): + """集中配置 ttk 控件样式,避免样式定义散落在各个方法中。""" + style = ttk.Style() + available_themes = style.theme_names() + if "clam" in available_themes: + style.theme_use("clam") + elif "vista" in available_themes: + style.theme_use("vista") + + style.configure("TFrame", background=self.colors["panel"]) + style.configure("App.TFrame", background=self.colors["bg"]) + style.configure("Transparent.TFrame", background=self.colors["bg"]) + style.configure("TPanedwindow", background=self.colors["bg"]) + style.configure("TLabel", background=self.colors["panel"], foreground=self.colors["text"], font=self.fonts["default"]) + style.configure( + "TEntry", + fieldbackground=self.colors["panel"], + background=self.colors["panel"], + foreground=self.colors["text"], + padding=5, + font=self.fonts["default"], + relief="flat", + borderwidth=1 + ) + style.configure( + "TCombobox", + fieldbackground=self.colors["panel"], + background=self.colors["panel"], + foreground=self.colors["text"], + arrowsize=14, + padding=3, + font=self.fonts["default"] + ) + style.configure( + "TButton", + font=self.fonts["button_normal"], + padding=(10, 6), + relief="flat", + 
borderwidth=0 + ) + style.configure( + "SidebarShell.TLabelframe", + background=self.colors["panel_soft"], + borderwidth=1, + relief="solid", + padding=6 + ) + style.configure( + "SidebarShell.TLabelframe.Label", + background=self.colors["panel_soft"], + foreground=self.colors["text"], + font=self.fonts["section"] + ) + style.configure( + "SidebarCard.TLabelframe", + background=self.colors["panel"], + borderwidth=1, + relief="solid", + padding=6 + ) + style.configure( + "SidebarCard.TLabelframe.Label", + background=self.colors["panel"], + foreground=self.colors["text"], + font=self.fonts["section"] + ) + style.configure( + "Card.TLabelframe", + background=self.colors["panel"], + borderwidth=1, + relief="solid", + padding=6 + ) + style.configure( + "Card.TLabelframe.Label", + background=self.colors["panel"], + foreground=self.colors["text"], + font=self.fonts["section"] + ) + style.configure( + "TerminalCard.TLabelframe", + background=self.colors["panel"], + borderwidth=1, + relief="solid", + padding=6 + ) + style.configure( + "TerminalCard.TLabelframe.Label", + background=self.colors["panel"], + foreground=self.colors["text"], + font=self.fonts["section"] + ) + style.configure("Primary.TButton", background=self.colors["accent"], foreground="#FFFFFF", font=self.fonts["button"]) + style.map("Primary.TButton", background=[("active", "#0B5E58"), ("disabled", "#A8D7D2")], foreground=[("disabled", "#F7FAFC")]) + style.configure("Danger.TButton", background="#C2410C", foreground="#FFFFFF", font=self.fonts["button"]) + style.map("Danger.TButton", background=[("active", "#9A3412"), ("disabled", "#F3B99A")], foreground=[("disabled", "#FEF2F2")]) + style.configure("Secondary.TButton", background=self.colors["info_soft"], foreground=self.colors["info"], font=self.fonts["button_normal"]) + style.map("Secondary.TButton", background=[("active", "#C8DFF7"), ("disabled", "#E7EEF5")]) + + def on_root_resize(self, event): + """窗口尺寸变化时,延迟刷新字体缩放。""" + if event.widget is not self.root 
or self._is_applying_font_scale: + return + if self._font_resize_job is not None: + self.root.after_cancel(self._font_resize_job) + self._font_resize_job = self.root.after(80, self.apply_font_scale) + + def apply_font_scale(self): + """根据窗口当前尺寸动态缩放字体。""" + self._font_resize_job = None + current_w = max(self.root.winfo_width(), 1) + current_h = max(self.root.winfo_height(), 1) + scale = min(current_w / self.base_window_width, current_h / self.base_window_height) + scale = max(0.72, min(1.12, scale)) + self._is_applying_font_scale = True + try: + for name, font in self.fonts.items(): + base_size = self.font_specs[name]["size"] + min_size = self.font_min_sizes[name] + font.configure(size=max(min_size, int(round(base_size * scale)))) + if hasattr(self, "log_text"): + self.log_text.configure(font=self.fonts["default"]) + if hasattr(self, "terminal_text"): + self.terminal_text.configure(font=self.fonts["mono"]) + finally: + self._is_applying_font_scale = False + + def set_status_badge(self, text, tone="idle"): + """更新顶部状态徽标。""" + tone_map = { + "idle": (self.colors["info_soft"], self.colors["info"]), + "running": (self.colors["accent_soft"], self.colors["accent"]), + "warning": (self.colors["danger_soft"], self.colors["danger"]), + "error": ("#FEE2E2", "#B91C1C"), + } + bg_color, fg_color = tone_map.get(tone, tone_map["idle"]) + if hasattr(self, "status_badge"): + self.status_badge.configure(text=text, bg=bg_color, fg=fg_color) + + def bind_mousewheel_to_params(self): + """让左侧参数区支持鼠标滚轮滚动。""" + def _on_mousewheel(event): + if hasattr(self, "params_canvas"): + self.params_canvas.yview_scroll(int(-event.delta / 120), "units") + + def _bind(_event): + self.params_canvas.bind_all("", _on_mousewheel) + + def _unbind(_event): + self.params_canvas.unbind_all("") + + self.params_canvas.bind("", _bind) + self.params_canvas.bind("", _unbind) + + def set_initial_layout(self): + """根据当前窗口尺寸设置更稳定的初始三栏宽度。""" + try: + total_width = self.paned_window.winfo_width() + if total_width > 600: 
+ left_width = max(480, int(total_width * 0.38)) + terminal_width = max(250, int(total_width * 0.20)) + self.paned_window.sashpos(0, left_width) + self.paned_window.sashpos(1, total_width - terminal_width) + except Exception: + pass + + def get_resource_candidates(self, filename): + """返回项目内可能存放资源文件的候选路径。""" + return [ + self.base_dir / filename, + self.base_dir / "tools" / filename, + Path.cwd() / filename, + Path.cwd() / "tools" / filename, + ] + + def resolve_path(self, path_value, fallback_name=None): + """解析绝对路径或相对项目根目录的路径。""" + if path_value: + candidate = Path(path_value).expanduser() + if not candidate.is_absolute(): + candidate = self.base_dir / candidate + candidate = candidate.resolve() + if candidate.exists(): + return candidate + if fallback_name: + for candidate in self.get_resource_candidates(fallback_name): + if candidate.exists(): + return candidate.resolve() + return None + + def load_ui_settings(self): + """读取本地保存的 UI 配置。""" + if not self.settings_file.exists(): + return {} + try: + with self.settings_file.open("r", encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, dict) else {} + except Exception as e: + self.log(f"读取 UI 配置失败:{str(e)}", level="警告") + return {} + + def save_ui_settings(self): + """保存 UI 配置,方便下次启动直接复用。""" + data = { + "ROBOT_IP": self.robot_ip.get(), + "ROBOT_PORT": self.robot_port.get(), + "AGV_IP": self.agv_ip.get(), + "AGV_PORT": self.agv_port.get(), + "AGV_SPEED_FORWARD": self.agv_speed_forward.get(), + "AGV_SPEED_STOP": self.agv_speed_stop.get(), + "TOTAL_DURATION": self.total_duration.get(), + "AGV_STOP_TIMEOUT": self.agv_stop_timeout.get(), + "YOLO_MODEL_PATH": self.yolo_model_path.get(), + "CAMERA_BG_PATH": self.camera_bg_path.get(), + "SCISSORS_ENABLED": bool(self.scissors_enabled.get()), + } + with self.settings_file.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def get_camera_background_path(self): + """返回当前相机背景图的有效路径。""" + return 
self.resolve_path(self.camera_bg_path.get().strip(), fallback_name="1.png") + + + def load_camera_background(self): + """加载1.png并确保全屏显示,保存原始图片用于后续缩放""" + try: + bg_path = self.get_camera_background_path() + if bg_path is None: + raise FileNotFoundError("1.png") + # 这里只保存原始图片,不直接反复对已缩放图片继续缩放, + # 这样可以减少多次窗口拉伸后的画质损失。 + self.raw_bg_image = Image.open(bg_path) + self.camera_bg_path.delete(0, tk.END) + self.camera_bg_path.insert(0, str(bg_path)) + self.adjust_bg_image_size() + except FileNotFoundError: + searched_paths = "\n".join(str(path) for path in self.get_resource_candidates("1.png")) + error_msg = f"未找到背景图 1.png。\n可在界面中手动选择图片,或将文件放到以下位置之一:\n{searched_paths}" + self.log(error_msg, level="警告") + messagebox.showwarning("图片缺失", error_msg) + self.raw_bg_image = None + self.camera_bg_image = None + self.camera_label.config(background="#E8E8E8") # 仅异常时显示灰色 + except Exception as e: + self.log(f"加载1.png失败:{str(e)}", level="错误") + self.raw_bg_image = None + self.camera_bg_image = None + self.camera_label.config(background="#E8E8E8") + + def adjust_bg_image_size(self): + """根据相机区域尺寸,调整1.png大小并显示""" + if self.raw_bg_image is None: + return + # 获取当前相机显示区域的实际尺寸(窗口渲染完成后的真实尺寸) + display_w = self.camera_display_frame.winfo_width() + display_h = self.camera_display_frame.winfo_height() + # 确保尺寸有效(避免窗口未渲染完成的情况) + if display_w <= 10 or display_h <= 10: + return + # 缩放图片:填充整个相机区域(保持图片比例可改用 Image.ANTIALIAS + 计算比例) + resized_bg = self.raw_bg_image.resize((display_w, display_h), Image.LANCZOS) + # 转换为Tkinter格式并更新标签 + self.camera_bg_image = ImageTk.PhotoImage(resized_bg) + self.camera_label.config(image=self.camera_bg_image) + + def restore_initial_camera_view(self): + """将相机区域恢复为初始背景图,并忽略后续残留画面刷新。""" + self.show_background_only = True + if self.raw_bg_image is None: + self.load_camera_background() + return + self.adjust_bg_image_size() + self.camera_image = None + if self.camera_bg_image is not None: + self.camera_label.config(image=self.camera_bg_image) + + def 
on_camera_frame_resize(self, event): + """相机区域窗口拉伸时,自动调整1.png尺寸""" + # 使用 after 做轻微延迟,避免用户拖拽窗口时触发过于频繁的 resize 计算。 + self.root.after(50, self.adjust_bg_image_size) + + def set_background_image(self): + """设置主窗口背景色。""" + self.root.configure(bg=self.colors["bg"]) + + def init_parameters(self): + """填充界面默认参数。 + + 这些值会作为用户第一次打开界面时的初始配置, + 同时也让界面与 control.py 的默认参数保持基本一致。 + """ + self.robot_ip.insert(0, "192.168.192.100") + self.robot_port.insert(0, "30004") + self.agv_ip.insert(0, "192.168.192.100") + self.agv_port.insert(0, "30104") + self.agv_speed_forward.insert(0, "-0.2") + self.agv_speed_stop.insert(0, "0.0") + self.total_duration.insert(0, "300") + self.agv_stop_timeout.insert(0, "10") + self.yolo_model_path.insert(0, "best.pt") + default_bg_path = self.resolve_path("", fallback_name="1.png") + if default_bg_path: + self.camera_bg_path.insert(0, str(default_bg_path)) + import control + + self.scissors_enabled.set(self.parse_bool_setting(getattr(control, "SCISSORS_ENABLED", True), default=True)) + self.place_position = self.normalize_place_position(control.place_positions) + settings = self.load_ui_settings() + if settings: + self.robot_ip.delete(0, tk.END) + self.robot_ip.insert(0, str(settings.get("ROBOT_IP", "192.168.192.100"))) + self.robot_port.delete(0, tk.END) + self.robot_port.insert(0, str(settings.get("ROBOT_PORT", "30004"))) + self.agv_ip.delete(0, tk.END) + self.agv_ip.insert(0, str(settings.get("AGV_IP", "192.168.192.100"))) + self.agv_port.delete(0, tk.END) + self.agv_port.insert(0, str(settings.get("AGV_PORT", "30104"))) + self.agv_speed_forward.delete(0, tk.END) + self.agv_speed_forward.insert(0, str(settings.get("AGV_SPEED_FORWARD", "-0.2"))) + self.agv_speed_stop.delete(0, tk.END) + self.agv_speed_stop.insert(0, str(settings.get("AGV_SPEED_STOP", "0.0"))) + self.total_duration.delete(0, tk.END) + self.total_duration.insert(0, str(settings.get("TOTAL_DURATION", "300"))) + self.agv_stop_timeout.delete(0, tk.END) + self.agv_stop_timeout.insert(0, 
str(settings.get("AGV_STOP_TIMEOUT", "10"))) + self.yolo_model_path.delete(0, tk.END) + self.yolo_model_path.insert(0, str(settings.get("YOLO_MODEL_PATH", "best.pt"))) + self.camera_bg_path.delete(0, tk.END) + saved_bg_path = settings.get("CAMERA_BG_PATH", "") + resolved_bg_path = self.resolve_path(saved_bg_path, fallback_name="1.png") + if resolved_bg_path: + self.camera_bg_path.insert(0, str(resolved_bg_path)) + elif saved_bg_path: + self.camera_bg_path.insert(0, str(saved_bg_path)) + self.scissors_enabled.set( + self.parse_bool_setting(settings.get("SCISSORS_ENABLED", self.scissors_enabled.get()), default=True) + ) + self.update_place_fields() + + def parse_bool_setting(self, value, default=True): + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in ("1", "true", "yes", "on"): + return True + if normalized in ("0", "false", "no", "off"): + return False + return default + + def normalize_place_position(self, position): + """把外部读取到的放置位统一规整为 [x,y,z,roll,pitch,yaw]。""" + if isinstance(position, (list, tuple)): + if len(position) >= 6 and not isinstance(position[0], (list, tuple)): + try: + return [float(value) for value in position[:6]] + except (TypeError, ValueError): + pass + if position and isinstance(position[0], (list, tuple)): + return self.normalize_place_position(position[0]) + return [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + + def update_place_fields(self, event=None): + """刷新单一放置位的 6 个位姿输入框。""" + def update(): + for i, entry in enumerate(self.place_entries): + entry.delete(0, tk.END) + entry.insert(0, str(self.place_position[i])) + self.root.after(0, update) + + def save_place_position(self): + """把当前输入框中的单一放置位姿保存回内存。""" + try: + position = [float(entry.get()) for entry in self.place_entries] + if len(position) != 6: + messagebox.showerror("错误", "请填写所有位置参数") + return + self.place_position = position + self.save_ui_settings() + messagebox.showinfo("成功", "放置位已保存") + self.log("放置位已保存") + 
except ValueError: + messagebox.showerror("错误", "请输入有效的数值") + self.log("保存放置位失败:无效的数值", level="错误") + + def browse_model(self): + """弹出文件选择框,选择 YOLO 模型文件。""" + filename = filedialog.askopenfilename( + title="选择YOLO模型文件", + filetypes=[("PyTorch模型", "*.pt"), ("所有文件", "*.*")] + ) + if filename: + self.yolo_model_path.delete(0, tk.END) + self.yolo_model_path.insert(0, filename) + + def browse_camera_background(self): + """弹出文件选择框,选择相机背景图。""" + filename = filedialog.askopenfilename( + title="选择相机背景图", + filetypes=[("图片文件", "*.png;*.jpg;*.jpeg;*.bmp"), ("所有文件", "*.*")] + ) + if filename: + self.camera_bg_path.delete(0, tk.END) + self.camera_bg_path.insert(0, filename) + self.load_camera_background() + + def save_parameters(self): + """把界面中的参数写回 control 模块。 + + 这里采用“直接修改模块级变量”的方式,而不是写配置文件。 + 好处是后台启动时可以立刻读取最新参数,代价是参数只在本次进程生命周期内生效。 + """ + try: + import control + params = { + "ROBOT_IP": self.robot_ip.get(), + "ROBOT_PORT": int(self.robot_port.get()), + "AGV_IP": self.agv_ip.get(), + "AGV_PORT": int(self.agv_port.get()), + "AGV_SPEED_FORWARD": float(self.agv_speed_forward.get()), + "AGV_SPEED_STOP": float(self.agv_speed_stop.get()), + "TOTAL_DURATION": int(self.total_duration.get()), + "AGV_STOP_TIMEOUT": int(self.agv_stop_timeout.get()), + "YOLO_MODEL_PATH": self.yolo_model_path.get(), + "SCISSORS_ENABLED": bool(self.scissors_enabled.get()), + "place_positions": [self.place_position] + } + for param_name, value in params.items(): + if value in (None, ""): + raise ValueError(f"参数 {param_name} 不能为空") + for param_name, value in params.items(): + setattr(control, param_name, value) + self.save_ui_settings() + self.log("参数保存成功") + messagebox.showinfo("成功", "参数已保存") + return True + except Exception as e: + self.log(f"参数保存失败: {str(e)}", level="错误") + messagebox.showerror("错误", f"参数保存失败: {str(e)}") + return False + + @contextlib.contextmanager + def capture_console_output(self): + """捕获后台程序的标准输出/错误输出,并转发到日志队列。""" + def push_console_output(message, channel="stdout"): + if not 
message: + return + normalized = message.strip() + if not normalized: + return + level = "错误" if "error" in normalized.lower() or "失败" in normalized else \ + "警告" if "warning" in normalized.lower() or "警告" in normalized else "信息" + self.push_terminal_message(normalized, channel) + if channel == "stderr" or level in ("警告", "错误"): + self.log_queue.put((normalized, "错误" if channel == "stderr" else level)) + + old_stdout = sys.stdout + old_stderr = sys.stderr + + class TeeConsole(io.TextIOBase): + def __init__(self, original_stream, push_fn, channel): + self.original_stream = original_stream + self.push_fn = push_fn + self.channel = channel + self._buffer = "" + + def write(self, message): + if not message: + return 0 + if self.original_stream: + self.original_stream.write(message) + self.original_stream.flush() + self._buffer += message + while "\n" in self._buffer: + line, self._buffer = self._buffer.split("\n", 1) + line = line.rstrip("\r") + if line.strip(): + self.push_fn(line, self.channel) + return len(message) + + def flush(self): + if self.original_stream: + self.original_stream.flush() + if self._buffer.strip(): + self.push_fn(self._buffer.strip(), self.channel) + self._buffer = "" + + sys.stdout = TeeConsole(old_stdout, push_console_output, "stdout") + sys.stderr = TeeConsole(old_stderr, push_console_output, "stderr") + try: + yield + finally: + try: + sys.stdout.flush() + sys.stderr.flush() + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + def start_program(self): + """启动后台采摘程序线程。""" + try: + if self.running: + messagebox.showwarning("提示", "已有任务正在运行,请先停止当前任务") + return + if not self.save_parameters(): + messagebox.showwarning("警告", "参数保存失败,无法启动程序") + return + self.running = True + self.runtime_mode = "full" + self.show_background_only = False + self.set_status_badge("运行中", "running") + self.start_btn.config(state=tk.DISABLED) + self.vision_test_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + 
self.push_terminal_message("终端监听已连接,等待后台标准输出...", "system") + def main_thread(): + exited_with_error = False + try: + import control + # 将 UI 的画面更新函数注入到控制模块, + # 这样 control.py 中的视觉线程就能直接把最新画面推送回来。 + control.set_ui_callback(self.update_camera_frame) + with self.capture_console_output(): + control.main() + except Exception as e: + exited_with_error = True + self.log(f"后台线程异常退出: {str(e)}", level="错误") + self.push_terminal_message(f"后台线程异常退出: {str(e)}", "stderr") + self.root.after(0, lambda: self.set_status_badge("异常退出", "error")) + finally: + self.running = False + self.runtime_mode = None + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.vision_test_btn.config(state=tk.NORMAL)) + if not exited_with_error: + self.root.after(0, lambda: self.set_status_badge("已停止", "idle")) + self.thread = threading.Thread(target=main_thread) + self.thread.daemon = True + self.thread.start() + self.log("程序启动成功") + except Exception as e: + self.set_status_badge("启动失败", "error") + self.log(f"程序启动失败: {str(e)}", level="错误") + messagebox.showerror("错误", f"程序启动失败: {str(e)}") + self.start_btn.config(state=tk.NORMAL) + self.vision_test_btn.config(state=tk.NORMAL) + self.stop_btn.config(state=tk.DISABLED) + + def start_vision_test(self): + """启动仅做相机采集和 YOLO 推理的视觉测试。""" + try: + if self.running: + messagebox.showwarning("提示", "已有任务正在运行,请先停止当前任务") + return + if not self.save_parameters(): + messagebox.showwarning("警告", "参数保存失败,无法启动视觉测试") + return + + self.running = True + self.runtime_mode = "vision_test" + self.vision_test_stop_event.clear() + self.show_background_only = False + self.set_status_badge("视觉测试中", "running") + self.start_btn.config(state=tk.DISABLED) + self.vision_test_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.push_terminal_message("视觉测试启动:仅加载相机和 YOLO,不启动机械臂。", "system") + + def vision_test_thread(): + pipeline = None + 
exited_with_error = False + try: + import control_core + with self.capture_console_output(): + pipeline, align, _ = control_core.init_camera() + self.vision_test_pipeline = pipeline + model = control_core.init_tomato_detector(self.yolo_model_path.get().strip()) + self.log("视觉测试已启动,可直接查看相机检测结果") + + while not self.vision_test_stop_event.is_set(): + try: + frames = pipeline.wait_for_frames(timeout_ms=3000) + if not frames: + continue + aligned_frames = align.process(frames) + color_frame = aligned_frames.get_color_frame() + if not color_frame: + continue + + color_image = np.asanyarray(color_frame.get_data()) + results = model(color_image, classes=0, conf=control_core.YOLO_DETECT_CONF, verbose=False) + annotated_image = self.build_vision_test_frame( + color_image, + results[0], + control_core.PICK_CONFIDENCE_THRESHOLD, + ) + self.update_camera_frame(annotated_image) + except Exception as frame_exc: + if self.vision_test_stop_event.is_set(): + break + self.log(f"视觉测试帧处理失败: {frame_exc}", level="警告") + time.sleep(0.2) + self.push_terminal_message("视觉测试已停止。", "system") + except Exception as exc: + exited_with_error = True + self.log(f"视觉测试启动失败: {exc}", level="错误") + self.push_terminal_message(f"视觉测试异常退出: {exc}", "stderr") + self.root.after(0, lambda: self.set_status_badge("视觉测试失败", "error")) + finally: + if pipeline is not None: + try: + pipeline.stop() + except Exception: + pass + self.vision_test_pipeline = None + self.running = False + self.runtime_mode = None + self.vision_test_stop_event.set() + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.vision_test_btn.config(state=tk.NORMAL)) + if not exited_with_error: + self.root.after(0, lambda: self.set_status_badge("已停止", "idle")) + + self.thread = threading.Thread(target=vision_test_thread, daemon=True) + self.thread.start() + except Exception as e: + self.running = False + self.runtime_mode = None 
def stop_program(self):
    """Stop the background task and restore the idle button state.

    If nothing is running (``self.running`` is false and no worker thread is
    alive) the UI is simply reset to idle. Otherwise the method flips the
    running flag, resets the buttons, detaches the UI callbacks from
    ``control`` / ``control_core``, signals the vision-test loop to stop,
    stops the vision-test RealSense pipeline if one is open, and finally
    waits for the worker thread on a helper thread so this call itself
    does not block on the join.
    """
    # Fast path: nothing is running, so only the widgets need resetting.
    if not self.running and not (self.thread and self.thread.is_alive()):
        self.restore_initial_camera_view()
        self.start_btn.config(state=tk.NORMAL)
        self.vision_test_btn.config(state=tk.NORMAL)
        self.stop_btn.config(state=tk.DISABLED)
        self.runtime_mode = None
        self.set_status_badge("已停止", "idle")
        return

    self.running = False
    self.set_status_badge("停止中", "warning")
    self.log("正在停止程序...")
    self.restore_initial_camera_view()
    self.start_btn.config(state=tk.NORMAL)
    self.vision_test_btn.config(state=tk.NORMAL)
    self.stop_btn.config(state=tk.DISABLED)

    # Take a local reference before clearing the attribute so the waiter
    # below can still join the thread.
    thread_to_wait = self.thread
    self.thread = None

    # Replace the UI callbacks with no-ops; failures here are ignored.
    try:
        import control
        control.set_ui_callback(lambda *_args, **_kwargs: None)
    except Exception:
        pass

    try:
        import control_core
        control_core.set_ui_callback(lambda *_args, **_kwargs: None)
        # NOTE(review): presumably this clears the runtime events in
        # control_core so its worker loops exit - confirm in control_core.
        control_core.reset_runtime_state()
    except Exception as e:
        self.log(f"发送停止信号失败: {e}", level="警告")

    # Signal the vision-test loop and stop its pipeline from this side too.
    self.vision_test_stop_event.set()
    if self.vision_test_pipeline is not None:
        try:
            self.vision_test_pipeline.stop()
        except Exception:
            pass
        finally:
            self.vision_test_pipeline = None

    self.runtime_mode = None

    if thread_to_wait and thread_to_wait.is_alive():
        def wait_for_shutdown():
            # Bounded join (3 s) on a helper thread; results are posted back
            # through root.after instead of touching widgets directly.
            thread_to_wait.join(timeout=3)
            if thread_to_wait.is_alive():
                self.root.after(
                    0,
                    lambda: self.log("后台线程仍未完全退出,请检查相机或控制器是否阻塞", level="警告"),
                )
            else:
                self.root.after(0, lambda: self.log("程序已停止"))
                self.root.after(0, lambda: self.set_status_badge("已停止", "idle"))

        threading.Thread(target=wait_for_shutdown, daemon=True).start()
    else:
        self.set_status_badge("已停止", "idle")
        self.log("程序已停止")
def process_log_queue(self):
    """Drain the log and terminal queues into their text widgets, then re-arm.

    Runs as a self-rescheduling ``root.after`` callback (every 200 ms).

    Each log message is timestamped, appended to ``self.log_records``
    (bounded to ``MAX_LOG_LINES * 2`` entries) and, if it passes the level
    filter, appended to ``self.log_text``. Each terminal message is prefixed
    with its stream name and appended to ``self.terminal_text``.

    The re-arm happens in a ``finally`` clause: previously it was the last
    statement of the method, so a single exception while rendering one
    message stopped all further log and terminal output for the session.
    The exception itself still propagates to Tk's callback error handler.
    """
    # Static lookup tables, built once per tick instead of once per message.
    terminal_prefixes = {
        "stdout": "STDOUT",
        "stderr": "STDERR",
        "system": "SYSTEM",
    }
    terminal_tags = {
        "stdout": "terminal_stdout",
        "stderr": "terminal_stderr",
        "system": "terminal_system",
    }

    try:
        if not self.log_queue.empty():
            # Tag colours do not depend on the message; configure them once.
            self.log_text.tag_config("info", foreground=self.colors["text"])
            self.log_text.tag_config("warning", foreground="#D97706")
            self.log_text.tag_config("error", foreground="#DC2626")
        while not self.log_queue.empty():
            message, level = self.log_queue.get()
            try:
                timestamp = time.strftime('%H:%M:%S')
                self.log_records.append((timestamp, message, level))
                if len(self.log_records) > MAX_LOG_LINES * 2:
                    self.log_records.pop(0)
                if self.should_display_log(level):
                    tag, display_text = self.format_log_entry(timestamp, message, level)
                    self.append_text_line(
                        self.log_text, display_text, "log_line_count", MAX_LOG_LINES, tag=tag
                    )
            finally:
                # Keep the queue's unfinished-task count balanced even when
                # rendering this message failed.
                self.log_queue.task_done()

        if not self.terminal_queue.empty():
            self.terminal_text.tag_config("terminal_stdout", foreground=self.colors["terminal_fg"])
            self.terminal_text.tag_config("terminal_stderr", foreground=self.colors["terminal_err"])
            self.terminal_text.tag_config("terminal_system", foreground=self.colors["terminal_muted"])
        while not self.terminal_queue.empty():
            kind, message = self.terminal_queue.get()
            try:
                self.append_text_line(
                    self.terminal_text,
                    f"[{time.strftime('%H:%M:%S')}] {terminal_prefixes.get(kind, 'OUT')}: {message}\n",
                    "terminal_line_count",
                    MAX_TERMINAL_LINES,
                    tag=terminal_tags.get(kind, "terminal_stdout"),
                )
            finally:
                self.terminal_queue.task_done()
    finally:
        # Always schedule the next tick so one bad message cannot kill the pump.
        self.root.after(200, self.process_log_queue)
def build_vision_test_frame(self, image, result, confidence_threshold):
    """Build the vision-test preview frame.

    Draws, on a copy of ``image``, a "Vision Test" header and, for every
    class-0 box in ``result`` whose confidence is at least
    ``confidence_threshold``: the bounding box, its centre point, a
    "Tomato-N conf" label and the pose keypoints for that box.

    Args:
        image: Frame to annotate; it is copied, not modified.
            (assumes a BGR OpenCV array - TODO confirm against caller)
        result: Detection result; only its ``boxes`` and ``keypoints``
            attributes are read, and either may be missing.
        confidence_threshold: Minimum box confidence to draw.

    Returns:
        The annotated copy. Per-box and per-keypoint failures are logged as
        warnings and do not abort the frame.
    """
    annotated = image.copy()
    frame_h, frame_w = annotated.shape[:2]
    # Header baseline: 5% of the frame height, but never above y=30.
    header_y = max(30, int(frame_h * 0.05))
    cv2.putText(
        annotated,
        "Vision Test",
        (12, header_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (38, 166, 154),
        2,
    )

    boxes = getattr(result, "boxes", None)
    keypoints = getattr(result, "keypoints", None)
    keypoint_xy = getattr(keypoints, "xy", None)
    keypoint_conf = getattr(keypoints, "conf", None)
    # Counts only the boxes that are actually drawn; used in the label.
    tomato_index = 0
    iterable_boxes = boxes if boxes is not None else []

    for box_index, box in enumerate(iterable_boxes):
        try:
            # box.conf may be a sized container or a scalar; accept both.
            confidence = float(box.conf[0]) if hasattr(box.conf, "__len__") else float(box.conf)
            if int(box.cls) != 0 or confidence < confidence_threshold:
                continue

            tomato_index += 1
            x1, y1, x2, y2 = [int(value) for value in box.xyxy[0].cpu().numpy()]
            # Clamp the box to the frame, then drop it if it became degenerate.
            x1 = max(0, min(frame_w - 1, x1))
            x2 = max(0, min(frame_w - 1, x2))
            y1 = max(0, min(frame_h - 1, y1))
            y2 = max(0, min(frame_h - 1, y2))
            if x2 <= x1 or y2 <= y1:
                continue

            pixel_x = int((x1 + x2) / 2)
            pixel_y = int((y1 + y2) / 2)
            label = f"Tomato-{tomato_index} {confidence:.2f}"

            cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 165, 255), 2)
            cv2.circle(annotated, (pixel_x, pixel_y), 5, (0, 0, 255), -1)

            text_scale = 0.55
            text_thickness = 2
            (text_w, text_h), baseline = cv2.getTextSize(
                label,
                cv2.FONT_HERSHEY_SIMPLEX,
                text_scale,
                text_thickness,
            )
            text_x = max(6, min(x1, frame_w - text_w - 10))
            # Prefer the label above the box; if it would run into the header
            # area, place it below the box instead.
            text_y = y1 - 10
            if text_y - text_h < header_y + 8:
                text_y = min(frame_h - 8, y2 + text_h + 10)
            box_top = max(0, text_y - text_h - baseline - 4)
            box_bottom = min(frame_h, text_y + baseline + 2)
            box_right = min(frame_w, text_x + text_w + 8)
            # Filled background rectangle behind the label text.
            cv2.rectangle(annotated, (text_x - 4, box_top), (box_right, box_bottom), (0, 165, 255), -1)
            cv2.putText(
                annotated,
                label,
                (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                text_scale,
                (255, 255, 255),
                text_thickness,
            )

            if keypoint_xy is not None:
                try:
                    # Keypoints are indexed by the box's position in `boxes`.
                    current_points = keypoint_xy[box_index]
                    if hasattr(current_points, "cpu"):
                        current_points = current_points.cpu().numpy()
                    current_conf = None
                    if keypoint_conf is not None:
                        current_conf = keypoint_conf[box_index]
                        if hasattr(current_conf, "cpu"):
                            current_conf = current_conf.cpu().numpy()

                    for point_index, point in enumerate(current_points):
                        px, py = [int(v) for v in point[:2]]
                        # Skip points whose x and y are both <= 0.
                        # NOTE(review): presumably the model's "no keypoint"
                        # marker - confirm.
                        if px <= 0 and py <= 0:
                            continue
                        if current_conf is not None and point_index < len(current_conf):
                            # Skip keypoints with confidence below 0.2.
                            if float(current_conf[point_index]) < 0.2:
                                continue
                        px = max(0, min(frame_w - 1, px))
                        py = max(0, min(frame_h - 1, py))
                        cv2.circle(annotated, (px, py), 4, (0, 255, 255), -1)
                        cv2.circle(annotated, (px, py), 7, (0, 120, 255), 1)
                except Exception as kp_exc:
                    self.log(f"视觉测试关键点绘制失败: {kp_exc}", level="警告")
        except Exception as exc:
            self.log(f"视觉测试标注失败: {exc}", level="警告")

    return annotated
def overlay_detection_metadata(self, image, payload):
    """Overlay detection boxes, confidences and pick points onto a frame.

    Args:
        image: Frame to annotate; returned untouched when there is nothing
            to draw, otherwise copied.
        payload: Dict describing the frame. Keys read here: ``detections``
            (list of dicts), ``freeze`` (truthy selects the "Pick Freeze"
            header and box colour) and ``stage`` (``"vision_test"`` selects
            the "Vision Test" header). Keys read from each detection:
            ``bbox``, ``cutpoint``, ``end_point``, ``confidence``,
            ``label``, ``angle_deg`` and ``pick_xyz``.

    Returns:
        ``image`` itself if ``payload`` is not a dict or has no detections,
        otherwise an annotated copy.
    """
    if not isinstance(payload, dict):
        return image

    detections = payload.get("detections") or []
    if not detections:
        return image

    annotated = image.copy()
    is_freeze_frame = bool(payload.get("freeze"))
    stage = payload.get("stage", "")
    header_text = "Pick Freeze" if is_freeze_frame else ("Vision Test" if stage == "vision_test" else "Detection")
    header_y = max(30, int(annotated.shape[0] * 0.05))
    # Rectangles already covered by text; consulted when placing the XYZ
    # label so it avoids earlier labels.
    occupied_rects = []
    (header_w, header_h), header_baseline = cv2.getTextSize(
        header_text,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        2,
    )
    occupied_rects.append(
        self._overlay_text_rect((12, header_y), (header_w, header_h), header_baseline, padding=6)
    )
    cv2.putText(
        annotated,
        header_text,
        (12, header_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (38, 166, 154),
        2,
    )

    for index, item in enumerate(detections, start=1):
        bbox = item.get("bbox") or []
        cutpoint = item.get("cutpoint") or []
        end_point = item.get("end_point") or []
        # Detections without a 4-value bbox are skipped entirely.
        if len(bbox) != 4:
            continue

        x1, y1, x2, y2 = [int(v) for v in bbox]
        confidence = float(item.get("confidence", 0.0))
        label = item.get("label", f"Tomato-{index}")
        box_color = (41, 121, 255) if is_freeze_frame else (64, 181, 246)

        cv2.rectangle(annotated, (x1, y1), (x2, y2), box_color, 2)

        text = f"{label} {confidence:.2f}"
        (text_w, text_h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2)
        text_x = max(6, min(x1, annotated.shape[1] - text_w - 10))
        # Prefer the label above the box; fall back to below the box when it
        # would run into the header area.
        text_y = y1 - 10
        if text_y - text_h < header_y + 8:
            text_y = min(annotated.shape[0] - 8, y2 + text_h + 10)
        box_top = max(0, text_y - text_h - baseline - 4)
        box_bottom = min(annotated.shape[0], text_y + baseline + 2)
        box_right = min(annotated.shape[1], text_x + text_w + 8)
        cv2.rectangle(annotated, (text_x - 4, box_top), (box_right, box_bottom), box_color, -1)
        occupied_rects.append((text_x - 4, box_top, box_right, box_bottom))
        cv2.putText(
            annotated,
            text,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.55,
            (255, 255, 255),
            2,
        )

        # The cut line is drawn only when both endpoints are 2-value points.
        if len(cutpoint) == 2 and len(end_point) == 2:
            p1 = tuple(int(v) for v in cutpoint)
            p2 = tuple(int(v) for v in end_point)
            cv2.circle(annotated, p1, 5, (0, 255, 255), -1)
            cv2.circle(annotated, p2, 5, (255, 215, 0), -1)
            cv2.line(annotated, p1, p2, (0, 220, 0), 2)
            cv2.putText(
                annotated,
                "cutpoint",
                (p1[0] + 8, max(18, p1[1] - 8)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.45,
                (0, 255, 255),
                1,
            )
            (cut_w, cut_h), cut_baseline = cv2.getTextSize(
                "cutpoint",
                cv2.FONT_HERSHEY_SIMPLEX,
                0.45,
                1,
            )
            occupied_rects.append(
                self._overlay_text_rect(
                    (p1[0] + 8, max(18, p1[1] - 8)),
                    (cut_w, cut_h),
                    cut_baseline,
                )
            )
            cv2.putText(
                annotated,
                "end_point",
                (p2[0] + 8, min(annotated.shape[0] - 10, p2[1] + 16)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.45,
                (255, 215, 0),
                1,
            )
            (end_w, end_h), end_baseline = cv2.getTextSize(
                "end_point",
                cv2.FONT_HERSHEY_SIMPLEX,
                0.45,
                1,
            )
            occupied_rects.append(
                self._overlay_text_rect(
                    (p2[0] + 8, min(annotated.shape[0] - 10, p2[1] + 16)),
                    (end_w, end_h),
                    end_baseline,
                )
            )
            angle_deg = item.get("angle_deg")
            if angle_deg is not None:
                # NOTE(review): the angle label is measured and its rectangle
                # is reserved in occupied_rects, but no putText call draws
                # angle_text. The reserved spot nearly coincides with the
                # "end_point" label. Confirm whether the draw call was
                # dropped on purpose or by accident.
                angle_text = f"Angle: {float(angle_deg):.2f}deg"
                (angle_w, angle_h), angle_baseline = cv2.getTextSize(
                    angle_text,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    2,
                )
                angle_x = max(6, min(p2[0] + 8, annotated.shape[1] - angle_w - 6))
                angle_y = max(18, min(p2[1] + 18, annotated.shape[0] - angle_baseline - 6))
                occupied_rects.append(
                    self._overlay_text_rect(
                        (angle_x, angle_y),
                        (angle_w, angle_h),
                        angle_baseline,
                    )
                )

        pick_xyz = item.get("pick_xyz") or []
        # The XYZ read-out is drawn only on freeze frames that carry a cut point.
        if is_freeze_frame and len(pick_xyz) >= 3 and len(cutpoint) == 2:
            try:
                xyz = [float(value) for value in pick_xyz[:3]]
                p1 = tuple(int(v) for v in cutpoint)
                xyz_text = f"D405 XYZ: {xyz[0]:.3f}, {xyz[1]:.3f}, {xyz[2]:.3f} m"
                text_scale = 0.48
                text_thickness = 1

                (xyz_w, xyz_h), _xyz_baseline = cv2.getTextSize(
                    xyz_text,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    text_scale,
                    text_thickness,
                )
                # Candidate positions in priority order: the four spots
                # beside the bbox first, then four spots around the cut point.
                anchors = [
                    (x2 + 12, y1 + xyz_h + 8),
                    (x1 - xyz_w - 12, y1 + xyz_h + 8),
                    (x2 + 12, y2 - 8),
                    (x1 - xyz_w - 12, y2 - 8),
                    (p1[0] + 14, p1[1] - 18),
                    (p1[0] + 14, p1[1] + xyz_h + 22),
                    (p1[0] - xyz_w - 14, p1[1] - 18),
                    (p1[0] - xyz_w - 14, p1[1] + xyz_h + 22),
                ]
                origin, xyz_rect = self._place_overlay_text(
                    xyz_text,
                    anchors,
                    occupied_rects,
                    annotated.shape,
                    scale=text_scale,
                    thickness=text_thickness,
                )
                if origin and xyz_rect:
                    # Leader line from the cut point to the nearest point on
                    # the label rectangle (cut point clamped into the rect).
                    leader_end = (
                        max(xyz_rect[0], min(p1[0], xyz_rect[2])),
                        max(xyz_rect[1], min(p1[1], xyz_rect[3])),
                    )
                    cv2.line(annotated, p1, leader_end, (0, 255, 255), 1, cv2.LINE_AA)
                    self._draw_outlined_text(
                        annotated,
                        xyz_text,
                        origin,
                        text_scale,
                        (0, 255, 255),
                        text_thickness,
                    )
                    occupied_rects.append(xyz_rect)
            except Exception as xyz_exc:
                self.log(f"pick xyz overlay failed: {xyz_exc}", level="警告")

    return annotated
def main():
    """Connect to the AUBO arm and print joint angles and TCP pose periodically.

    Logs in over RPC, then opens an RTDE session subscribed to
    ``R1_actual_q`` and ``R1_actual_TCP_pose``. The subscription callback
    stores the newest sample in the module-level ``latest_*`` variables under
    ``lock``; the main loop prints a rounded snapshot every
    ``PRINT_INTERVAL`` seconds until interrupted. Teardown always runs and
    ignores errors from each individual step.
    """
    global latest_joint_positions, latest_tcp_pose

    def _quietly(step):
        # Best-effort teardown: one failing step must not block the next.
        try:
            step()
        except Exception:
            pass

    def _on_sample(parser):
        # RTDE callback: pop values in the same order the topic declares them.
        global latest_joint_positions, latest_tcp_pose

        joint_sample = parser.popVectorDouble()
        pose_sample = parser.popVectorDouble()

        with lock:
            latest_joint_positions = joint_sample
            latest_tcp_pose = pose_sample

    rpc = RpcClient()
    rtde = RtdeClient()
    topic_id = None

    try:
        # RPC session.
        rpc.connect(ROBOT_IP, RPC_PORT)
        rpc.login(USERNAME, PASSWORD)

        robot_names = rpc.getRobotNames()
        if not robot_names:
            print("未找到机器人")
            return

        print(f"已连接机器人: {robot_names[0]}")

        # RTDE session and subscription.
        rtde.connect(ROBOT_IP, RTDE_PORT)
        rtde.login(USERNAME, PASSWORD)

        subscribed_fields = ["R1_actual_q", "R1_actual_TCP_pose"]
        topic_id = rtde.setTopic(False, subscribed_fields, 50, 0)
        rtde.subscribe(topic_id, _on_sample)

        print(f"开始监听(每 {PRINT_INTERVAL:.0f}s 输出一次)... Ctrl+C 退出")

        while True:
            time.sleep(PRINT_INTERVAL)

            # Copy under the lock so printing never races with the callback.
            with lock:
                joints = None if latest_joint_positions is None else latest_joint_positions[:]
                tcp = None if latest_tcp_pose is None else latest_tcp_pose[:]

            if joints is None or tcp is None:
                print("尚未收到完整数据...")
                continue

            joints_fmt = [round(value, 4) for value in joints]
            tcp_fmt = [round(value, 4) for value in tcp]

            print("\n" + "=" * 60)
            print(f"关节角 (rad): {joints_fmt}")
            print(f"末端位姿 : {tcp_fmt}")
            print("=" * 60)

    except KeyboardInterrupt:
        print("\n程序退出")
    except Exception as e:
        print(f"\n运行出错: {e}")
    finally:
        # Teardown order: topic, RTDE link, RPC logout, RPC link.
        if topic_id is not None:
            _quietly(lambda: rtde.removeTopic(False, topic_id))
        _quietly(lambda: rtde.disconnect())
        _quietly(lambda: rpc.logout())
        _quietly(lambda: rpc.disconnect())
def get_color_stream_candidates() -> list[ColorStreamConfig]:
    """Return every usable color stream profile of the first RealSense device.

    A profile is usable when it is a color stream, its pixel format is one of
    ``PREFERRED_COLOR_FORMATS``, and it can be viewed as a video stream
    profile.

    Raises:
        RuntimeError: If no device is connected, or the device exposes no
            usable color profile.
    """
    connected = rs.context().query_devices()
    if len(connected) == 0:
        raise RuntimeError("No RealSense device detected.")

    def _usable_profiles(device):
        # Walk every sensor and yield one config per acceptable profile.
        for sensor in device.sensors:
            for profile in sensor.get_stream_profiles():
                is_color = profile.stream_type() == rs.stream.color
                if not is_color or profile.format() not in PREFERRED_COLOR_FORMATS:
                    continue

                try:
                    as_video = profile.as_video_stream_profile()
                except RuntimeError:
                    # Not a video profile after all; ignore it.
                    continue

                yield ColorStreamConfig(
                    width=as_video.width(),
                    height=as_video.height(),
                    fps=profile.fps(),
                    stream_format=profile.format(),
                )

    found: list[ColorStreamConfig] = list(_usable_profiles(connected[0]))
    if not found:
        raise RuntimeError("No usable color stream profile found for the RealSense device.")

    return found
def draw_status(frame: np.ndarray, is_recording: bool, output_path: Path | None) -> np.ndarray:
    """Return a copy of ``frame`` with the recorder's status overlay drawn on it.

    While recording, a red dot and "REC" are shown; otherwise the start hint.
    The quit hint is always shown, and the output file name is shown at the
    bottom once ``output_path`` is set. The input frame is left unmodified.
    """
    canvas = frame.copy()
    bottom_line_y = canvas.shape[0] - 20

    # Each caption is (text, origin, font scale, BGR colour), in draw order.
    captions = []
    if is_recording:
        cv2.circle(canvas, (25, 30), 8, (0, 0, 255), -1)
        captions.append(("REC", (40, 36), 0.8, (0, 0, 255)))
    else:
        captions.append(("Press 's' to start recording", (20, 36), 0.7, (0, 255, 255)))

    captions.append(("Press 'q' to stop and quit", (20, 70), 0.7, (255, 255, 255)))

    if output_path is not None:
        captions.append((f"Saving: {output_path.name}", (20, bottom_line_y), 0.6, (255, 255, 255)))

    for text, origin, scale, color in captions:
        cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, color, 2)

    return canvas
def main() -> None:
    """Run the interactive D405 recorder: preview, 's' to record, 'q' to quit.

    With ``--list-profiles`` only the supported profiles are printed.
    Otherwise the requested color stream is opened and shown in a preview
    window; every frame is written to an ``.mp4`` under ``SAVE_DIR`` once
    recording has been started. The writer, pipeline and window are released
    when the loop exits.
    """
    options = parse_args()
    if options.list_profiles:
        list_profiles()
        return

    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    stream_config = select_color_stream(options.width, options.height, options.fps)
    pipeline = init_camera(stream_config)
    writer: cv2.VideoWriter | None = None
    output_path: Path | None = None
    is_recording = False

    separator = "-" * 40
    banner = (
        separator,
        "RealSense D405 video recorder",
        "Selected color stream: "
        f"{stream_config.width}x{stream_config.height} @ {stream_config.fps} FPS "
        f"({stream_config.stream_format})",
        "Press 's' to start recording",
        "Press 'q' to stop recording and quit",
        f"Video files will be saved to: {SAVE_DIR.resolve()}",
        separator,
    )
    for banner_line in banner:
        print(banner_line)

    try:
        while True:
            frameset = pipeline.wait_for_frames(timeout_ms=2000)
            color_frame = frameset.get_color_frame()
            if not color_frame:
                continue

            bgr_image = frame_to_bgr(color_frame, stream_config)

            # Write the raw frame; the status overlay goes on the preview only.
            if is_recording and writer is not None:
                writer.write(bgr_image)

            cv2.imshow(WINDOW_NAME, draw_status(bgr_image, is_recording, output_path))

            key = cv2.waitKey(1) & 0xFF

            if key == ord("q"):
                print("Recording stopped." if is_recording else "Exit without recording.")
                break
            elif key == ord("s") and not is_recording:
                started_at = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_path = SAVE_DIR / f"d405_record_{started_at}.mp4"
                writer = build_video_writer(output_path, stream_config)
                is_recording = True
                print(f"Recording started: {output_path}")
    finally:
        if writer is not None:
            writer.release()
        pipeline.stop()
        cv2.destroyAllWindows()
        if output_path is not None and output_path.exists():
            print(f"Saved video: {output_path.resolve()}")
def extract_frames(video_path: Path, output_dir: Path, target_fps: float) -> None:
    """Sample frames from a video at roughly ``target_fps`` and save them as JPEGs.

    Frames are written to ``output_dir`` as ``000001.jpg``, ``000002.jpg``, ...
    A frame is kept whenever its index reaches the next sampling position,
    which advances by ``source_fps / target_fps`` (never less than one frame).

    Args:
        video_path: Video file to read.
        output_dir: Directory for the extracted images; created if missing.
        target_fps: Desired extraction rate; must be greater than 0.

    Raises:
        ValueError: If ``target_fps`` is not a positive number.
        RuntimeError: If the video cannot be opened or a frame cannot be
            written.
    """
    # Validate here as well as in the CLI: direct callers used to hit a
    # ZeroDivisionError for 0, and a negative rate was silently clamped to
    # "save every frame". `not x > 0` also rejects NaN.
    if not target_fps > 0:
        raise ValueError(f"target_fps must be greater than 0, got {target_fps!r}.")

    cap = cv2.VideoCapture(str(video_path))
    saved_count = 0
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 FPS when the reported rate is missing, zero,
        # negative or NaN.
        if not source_fps > 0:
            source_fps = 30.0

        frame_interval = max(source_fps / target_fps, 1.0)
        next_frame_to_save = 0.0
        frame_index = 0

        output_dir.mkdir(parents=True, exist_ok=True)

        print("-" * 40)
        print(f"Input video : {video_path.resolve()}")
        print(f"Output dir : {output_dir.resolve()}")
        print(f"Source FPS : {source_fps:.2f}")
        print(f"Target FPS : {target_fps:.2f}")
        print("-" * 40)

        while True:
            success, frame = cap.read()
            if not success:
                break

            # Small epsilon absorbs float drift in the accumulated position.
            if frame_index + 1e-6 >= next_frame_to_save:
                image_path = output_dir / f"{saved_count + 1:06d}.jpg"
                # imwrite reports failure through its return value; count a
                # frame only after it has really been written.
                if not cv2.imwrite(str(image_path), frame):
                    raise RuntimeError(f"Failed to write frame image: {image_path}")
                saved_count += 1
                next_frame_to_save += frame_interval

            frame_index += 1
    finally:
        # Release the capture on every path, including the open failure.
        cap.release()

    print(f"Saved {saved_count} RGB frames.")
def main() -> None:
    """Resolve the CLI arguments and extract frames from the chosen video.

    Without ``--video`` the newest recording in ``DEFAULT_VIDEO_DIR`` is used;
    without ``--output-dir`` the frames go to a folder named after the video
    under ``DEFAULT_OUTPUT_DIR``.

    Raises:
        FileNotFoundError: If the selected video does not exist.
        ValueError: If ``--target-fps`` is not greater than 0.
    """
    options = parse_args()

    if options.video is None:
        video_path = find_latest_video(DEFAULT_VIDEO_DIR)
    else:
        video_path = options.video
    if not video_path.exists():
        raise FileNotFoundError(f"Video file does not exist: {video_path.resolve()}")

    output_dir = (
        options.output_dir
        if options.output_dir is not None
        else DEFAULT_OUTPUT_DIR / video_path.stem
    )

    if options.target_fps <= 0:
        raise ValueError("--target-fps must be greater than 0.")

    extract_frames(video_path, output_dir, options.target_fps)