这台车是给农田巡检用的——装上小型激光雷达和深度相机,自己在田垄之间走、自己建图、检测到的异常上报回服务端。主控选的树莓派 5 16G,跑 ROS2 Humble。

实话说,树莓派的算力是真紧张——SLAM 跑起来 CPU 长期 60%+,Nav2 的代价地图刷新跟着拉满,再叠加摄像头流和 WebSocket 客户端,一不留神就热到降频。这篇就记一下为了让它”勉强能跑顺”踩过的一堆坑。

用 Docker 部署 ROS2

不在树莓派上直接装。理由不是 Docker 多优雅,而是直接装 ROS2 Humble + colcon + 一堆 nav2 包,整个系统会被污染得跟战场似的;编译还慢。Docker 至少能把环境隔住、版本钉死:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
FROM ros:humble-ros-base

RUN apt-get update && apt-get install -y \
ros-humble-slam-toolbox \
ros-humble-navigation2 \
ros-humble-nav2-bringup \
ros-humble-robot-localization \
&& rm -rf /var/lib/apt/lists/*

COPY src/ /opt/rover_ws/src/
COPY scripts/ /opt/rover_ws/scripts/

RUN cd /opt/rover_ws && \
. /opt/ros/humble/setup.sh && \
colcon build --symlink-install

COPY scripts/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]

compose 配置要注意两件事:

  • network_mode: host——ROS2 的 DDS 必须用主机网络,桥接模式下节点之间互相发现不到,这个坑能让你查一下午
  • devices 映射——激光雷达和深度相机的设备文件得显式挂进去
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
version: '3.8'

services:
rover-edge:
build: .
container_name: rover-edge
privileged: true
network_mode: host
volumes:
- ./src:/opt/rover_ws/src
- ./data/dev_ws:/opt/rover_ws/data
- /dev:/dev
environment:
- ROS_DOMAIN_ID=0
- FORCE_COLCON_BUILD=1
devices:
- /dev/ttyUSB0:/dev/ttyUSB0
- /dev/video0:/dev/video0

privileged: true 其实不一定真需要,但实际部署中我没空一个个细抠 cap,加上图省事——不是好习惯,但够用。

SLAM 用 slam_toolbox

slam_toolbox 的 async 模式在 ARM 上明显比 cartographer 稳——后者在 Humble 这边的构建脚本一直有点幺蛾子,能不碰就不碰。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# launch/slam.launch.py
from launch import LaunchDescription
from launch_ros.actions import Node

def generate_launch_description():
return LaunchDescription([
Node(
package='slam_toolbox',
executable='async_slam_toolbox_node',
name='slam_toolbox',
output='screen',
parameters=[{
'use_sim_time': False,
'slam_toolbox.map_file_name': '',
'slam_toolbox.mode': 'mapping',
'slam_toolbox.scan_topic': '/scan',
'slam_toolbox.publish_map': True,
}],
),
])

建图流程很常规——先遥控走一圈,再保存:

1
2
3
4
ros2 launch agri_rover_edge slam.launch.py
ros2 run teleop_twist_keyboard teleop_twist_keyboard
ros2 service call /slam_toolbox/save_map slam_toolbox/srv/SaveMap \
"{name: {data: 'farm_map'}}"

农田场景里有个反复出现的坑——田间路径又长又直,回环少。SLAM 在长直线上累计误差大,回到起点时可能错位半米以上——地图直接画歪了。

后来加了几个明显的人工路标(停车桩、铁丝架),故意让车经过、回扫,靠这些路标让算法触发 loop closure,地图才慢慢稳下来。

Nav2 默认配置在 PC 上跑得飞起,搬到树莓派上等同自爆——costmap 刷新和路径规划都太勤快。把几个频率拍下来才能喘口气:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
local_costmap:
local_costmap:
ros__parameters:
update_frequency: 2.0
width: 3
height: 3
resolution: 0.05

global_costmap:
global_costmap:
ros__parameters:
update_frequency: 1.0
width: 20
height: 20
resolution: 0.05

planner_server:
ros__parameters:
GridBased:
tolerance: 0.5

controller_server:
ros__parameters:
FollowPath:
max_vel_x: 0.3
max_vel_theta: 0.5

tolerance: 0.5 这个比较关键——默认值是 0.25。农田地面凹凸不平,定位误差天然就在 20-30cm,容差给太小,Nav2 反复重规划,一直觉得”还没到、还没到”——车在原地疯狂调整角度,看着像跳广场舞。

最大速度限到 0.3 m/s 倒不是为了省算力,是因为车的底盘抖动跟速度强相关。再快一点,激光雷达点云的畸变肉眼可见——SLAM 这边就要骂街了。

覆盖巡检路径生成

农田巡检不是 A 到 B 那种简单导航,而是要把整块田走一遍——类似扫地机器人的”弓字形”。手写一段路径生成,输入是地图栅格 + 行间距:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
from typing import List, Tuple

def generate_coverage_route(
map_data: np.ndarray,
resolution: float,
origin: Tuple[float, float],
line_spacing: float = 1.0,
) -> List[Tuple[float, float]]:
traversable = map_data == 0
rows, cols = np.where(traversable)
min_row, max_row = rows.min(), rows.max()
min_col, max_col = cols.min(), cols.max()

def to_world(row, col):
x = origin[0] + col * resolution
y = origin[1] + row * resolution
return (x, y)

route = []
direction = 1
current_row = min_row

while current_row <= max_row:
if direction == 1:
cols_range = range(min_col, max_col + 1)
else:
cols_range = range(max_col, min_col - 1, -1)

for col in cols_range:
if traversable[current_row, col]:
route.append(to_world(current_row, col))

current_row += int(line_spacing / resolution)
direction *= -1

return route

这写法对凸多边形田块够用了。但要是地块是凹的(中间有障碍物),生成的路径会”穿障碍”——实际跑的时候 Nav2 会绕开重规划,体验就有点拉胯。

比较正确的做法是按可达区域分块再分别覆盖,但目前业务上还用不到,先这样将就着。

与服务端的 WebSocket

车上线 → 跟服务端建长连接 → 服务端下指令、车上报状态。WebSocket 客户端写得朴素,关键是断线重连和”消息按 action 分发”:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import asyncio
import json
import websockets

class RoverEdgeClient:
def __init__(self, server_url: str):
self.server_url = server_url
self._ws = None

async def connect(self):
while True:
try:
async with websockets.connect(self.server_url) as ws:
self._ws = ws
logger.success("已连接到服务端")
await self._message_handler()
except Exception as e:
logger.error(f"连接失败: {e}")
await asyncio.sleep(5)

async def _message_handler(self):
async for message in self._ws:
data = json.loads(message)
action = data.get("action")
if action == "navigate":
await self._handle_navigate(data["data"])
elif action == "start_mapping":
await self._handle_start_mapping()
elif action == "stop_mapping":
await self._handle_stop_mapping()
elif action == "update_config":
await self._handle_update_config(data["data"])

async def _handle_navigate(self, payload):
target_x = payload["x"]
target_y = payload["y"]
goal_pose = self._create_goal_pose(target_x, target_y)
self._nav_client.send_goal(goal_pose)
await self._report_navigation_status("navigating")

async def report_position(self, x, y, theta):
await self._ws.send(json.dumps({
"action": "position_update",
"data": {"x": x, "y": y, "theta": theta}
}))

async def report_anomaly(self, anomaly_type, details):
await self._ws.send(json.dumps({
"action": "anomaly",
"data": {"type": anomaly_type, "details": details}
}))

asyncio.sleep(5) 是个偷懒的固定间隔。讲究的话应该用指数退避,避免雷击效应——不过这个场景只有一台车,无所谓。

异常检测

巡检过程中要持续盯几件事:“我没卡死、我还能动、电池还能撑、激光雷达还在出数据”。一个简单的注册表把这些检查项管起来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class AnomalyDetector:
def __init__(self):
self._anomaly_types = {
"low_battery": self._check_low_battery,
"stuck": self._check_stuck,
"motor_stall": self._check_motor_stall,
"lidar_timeout": self._check_lidar_timeout,
"depth_stream_error": self._check_depth_stream,
"nav2_timeout": self._check_nav2_timeout,
}

def check(self, rover_state: dict) -> List[dict]:
anomalies = []
for anomaly_type, checker in self._anomaly_types.items():
if checker(rover_state):
anomalies.append({
"type": anomaly_type,
"position": rover_state["position"],
"timestamp": time.time(),
"details": rover_state,
})
return anomalies

def _check_low_battery(self, state: dict) -> bool:
return state.get("battery_voltage", 0) < 3.3

def _check_stuck(self, state: dict) -> bool:
if state.get("speed", 0) > 0.01:
self._last_moving_time = time.time()
return time.time() - self._last_moving_time > 30

检测到异常不只是发条消息,还会在地图上打一个 marker——农户后续可以照着这个地图位置直接走过去查看。这个比”发邮件告诉你东 35 北 12 处异常”实用太多了。

ServerManager 这个项目要管一堆远端节点——可能是云服务器、可能是边缘的工控机、也可能是机器人小车。中控服务端要给它们下指令、收它们的状态。HTTP 轮询太重,所以一开始就选了 WebSocket。

WebSocket 一行代码就连上了,难的是上线之后的那些事——网络抖一下连接断了、token 过期了得重新认证、消息怎么防篡改、命令怎么不让节点端乱执行。下面把这几块各自的思路捋一下。

永远在重连

节点客户端的主体是一个 while True:连接、跑消息循环、断开、等几秒、再连。所有异常都不能让它跳出最外层:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import aiohttp
import asyncio
import json

class NodeWebSocketClient:
def __init__(self):
self._session: aiohttp.ClientSession = None
self._ws: aiohttp.ClientWebSocketResponse = None

async def run_forever(self) -> None:
while True:
try:
self._session, auth_status, access_token = await self._authenticate()
if not auth_status:
await asyncio.sleep(5)
continue

async with self._session.ws_connect(
self._ws_url(access_token),
autoping=True,
) as ws:
logger.success("WebSocket 已连接")
self._ws = ws
await self._message_handler()

except aiohttp.ClientError as err:
logger.error(f"WebSocket 连接失败: {err}")
finally:
self._ws = None
await asyncio.sleep(5)

autoping=True 让 aiohttp 自动发 ping,对付 NAT 超时和负载均衡器主动断流。

固定 5 秒重连间隔,在节点规模一大就要换成带抖动的退避——想象一下,一万台机器同一秒全断了再同一秒全重连,服务端那边瞬间就被打挂了。

Bearer Token + 消息签名

光有 Token 是不够的。Token 只能证明”这个客户端有权连”,但连接建立之后,每一条消息是不是真的来自这个客户端、有没有被中间人改过——那是另一个问题。

所以关键操作的消息额外做一层 HMAC 签名。

HTTP 拿 token:

1
2
3
4
5
6
7
async def _authenticate(self):
async with aiohttp.ClientSession() as session:
async with session.post(self._auth_url, json=self._auth_data) as resp:
data = await resp.json()
if resp.status == 200:
return session, True, data["access_token"]
return session, False, None

消息签名用 HMAC-SHA256,key 是节点专属 token(和 access_token 不是一个东西,是节点注册时发的长期密钥):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import hashlib
import hmac

def generate_signature(node_token: str, action: str, payload: dict) -> str:
message = f"{action}:{json.dumps(payload, sort_keys=True)}"
signature = hmac.new(
node_token.encode(),
message.encode(),
hashlib.sha256,
).hexdigest()
return signature

async def websocket_send_json(self, data: dict) -> None:
action = data.get("action", "")
payload = data.get("data", {})

if self._needs_signature(action):
data["_sign"] = generate_signature(self.node_token(), action, payload)

await self._ws.send_str(json.dumps(data))

服务端验证:

1
2
3
def verify_signature(node_token: str, action: str, payload: dict, signature: str) -> bool:
expected = generate_signature(node_token, action, payload)
return hmac.compare_digest(expected, signature)

两个细节要划重点:

hmac.compare_digest 千万别图省事换成 ==——后者有时序攻击风险。这种事自己写很容易写错,能用标准库的就用,别炫技。

sort_keys=True 也是必须的——客户端和服务端要是用不同语言,dict 序列化的字段顺序可能不一致,签名当场对不上。

Action 分发

服务端下发的消息固定有个 action 字段。客户端按 action 分发到对应处理器,写法就是个注册表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class NodeActionDispatcher:
def __init__(self, client: NodeWebSocketClient):
self._client = client
self._handlers: Dict[str, Callable] = {
"execute_command": self._handle_execute_command,
"download_file": self._handle_download_file,
"update_config": self._handle_update_config,
"restart": self._handle_restart,
}

async def dispatch(self, action: str, payload: dict, raw_data: dict) -> bool:
handler = self._handlers.get(action)
if handler is None:
logger.warning(f"未知的 action: {action}")
return True

try:
await handler(payload)
except Exception as e:
logger.error(f"处理 action {action} 失败: {e}")
await self._client.websocket_send_json({
"action": f"{action}_response",
"data": {"success": False, "error": str(e)},
})

return True

async def _handle_execute_command(self, payload: dict):
command = payload.get("command")
timeout = payload.get("timeout", 30)

if not self._check_command_policy(command):
raise ValueError(f"命令被安全策略禁止: {command}")

result = await self._execute_command(command, timeout)

await self._client.websocket_send_json({
"action": "execute_command_response",
"data": {"success": True, "output": result},
})

注意未知 action 用的是 warning,不是 error,并且返回 True 继续处理后续消息——服务端可能下发了节点不认识的新 action(比如客户端版本太老),但这绝不应该让节点断连。

消息循环用了 Python 3.10 的 match 来分发消息类型,比一堆 if elif 清爽:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
async def _message_handler(self) -> None:
while self._ws and not self._ws.closed:
msg = await self._ws.receive()

match msg.type:
case aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
except json.JSONDecodeError:
logger.warning("收到非法 JSON")
continue

action = data.get("action")
payload = data.get("data")

if not isinstance(action, str) or not action:
logger.warning("收到缺少 action 的消息")
continue

should_continue = await self._dispatcher.dispatch(action, payload, data)
if not should_continue:
return

case aiohttp.WSMsgType.CLOSE:
logger.info("连接已断开")
return

命令策略

execute_command 这种 action 是 RCE 风险最大的地方——服务端被打穿,就能直接控制节点。所以节点这边自己也得有一道把关。

最低限度是一个黑名单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class CommandPolicy:
def __init__(self, config: dict):
self._disable_list = config.get("disable_command_list", "")
self._strict = config.get("strict", True)

def check(self, command: str) -> bool:
disabled = [c.strip() for c in self._disable_list.split(",") if c.strip()]

for pattern in disabled:
if pattern == "*":
return False
if pattern.startswith("*") and command.endswith(pattern[1:]):
return False
if pattern.endswith("*") and command.startswith(pattern[:-1]):
return False
if command == pattern:
return False

return True

说句实在话:黑名单的安全性永远低于白名单——总能找到等价命令绕过 (rm -rf 禁了还有 find ... -delete 嘛)。

但白名单在通用节点管理这种场景几乎没法定义——你哪知道用户会想跑啥命令?所以这里的策略只能定位为”防呆”,不是真正的安全边界。

真正的安全边界应该在节点之外——以受限用户运行节点、用 cgroup/容器隔离。

运行时服务

节点上有一堆模块——终端服务、文件服务、监控服务——都依赖同一个会话和连接。包成一个 NodeRuntimeServices 集中管理生命周期,省心:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class NodeRuntimeServices:
def __init__(self, client, session, connection, access_token, config_getter):
self.client = client
self.session = session
self.connection = connection
self.access_token = access_token
self.config_getter = config_getter

@classmethod
def create(cls, client, session, connection, access_token, config_getter, command_policy):
services = cls(client, session, connection, access_token, config_getter)
services.terminal_service = TerminalService(session, connection)
services.file_service = FileService(session, connection)
services.monitor_service = MonitorService(session, connection)
return services

async def close(self):
await self.terminal_service.close()
await self.file_service.close()
await self.monitor_service.close()

注意 config_getter 是个 callable 而不是 dict——配置可能在运行时被服务端下发更新,传 getter 才能拿到最新值。

要是图省事传了 dict,会出现”改了配置但节点还用旧的”这种说不清的灵异问题,做配置热更新的时候被这个坑撞过,痛苦面具直接戴稳。

最早的规则引擎,就直白的 if 条件 then 动作。演示环境下挺好,一上真设备就开始翻车:

  • 温度 31℃,规则每秒都在判断”大于 30”,风扇一秒一开一关——你听过风扇打电报的声音吗?我听过。
  • 服务重启之后,温度还是 31℃,可规则就是不触发,因为没有”变化”事件
  • 规则 A 执行动作改了某个设备状态,刚好命中规则 B 的条件,规则 B 又改回去——俩规则互相戳着死循环

这篇记一下后来逐步打上去的几个补丁:边沿触发、分支路由、启动补评估、防重入和防抖。

边沿触发

最朴素的修法:只在条件从 False 变 True 的那个瞬间触发,持续 True 不再重复发:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class AutomationEngine:
def __init__(self):
self._condition_last_results: Dict[str, bool] = {}

def _evaluate_conditions(self, rule: AutomationRule, event: TriggerEvent) -> bool:
if not rule.conditions:
return True

enabled_conditions = [c for c in rule.conditions if c.enabled]
if not enabled_conditions:
return True

all_pass = True
for condition in enabled_conditions:
if condition.condition_type == ConditionType.RANGE:
if not self._check_range_condition(condition, event):
all_pass = False
break
elif condition.condition_type == ConditionType.COMPARE:
if not self._check_compare_condition(condition, event):
all_pass = False
break

last_result = self._condition_last_results.get(rule.rule_id, False)
self._condition_last_results[rule.rule_id] = all_pass

if not all_pass:
return False

if last_result:
# 持续满足,跳过
return False

# False → True,触发
return True

对应回温度那个例子:

1
2
3
4
5
时间 1: 29℃ → False        → 不触发
时间 2: 31℃ → False→True → 触发
时间 3: 32℃ → True→True → 不触发
时间 4: 29℃ → True→False → 重置(下次再过 30 又能触发)
时间 5: 31℃ → False→True → 触发

边沿触发还藏了一层语义:要想再触发一次”开风扇”,温度必须先掉回去。要是用户想要”持续超过 30 就每隔 10 分钟通知一次”,那是另一种东西——周期触发,应该在 trigger 那一侧实现,别塞进条件评估里搅浑。

分支:不只是 if,还有 else

某些条件天然就有两种结果。比如”温度在 25-30 之间”——超出范围其实也是有意义的事件,可能要触发”开警报”,而不是简单的”开制冷”。

所以条件类型里加了几个支持双分支输出的类型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class ConditionType(str, Enum):
RANGE = "range"
COMPARE = "compare"
TIME_RANGE = "time_range"
LOGIC = "logic"
EXPRESSION = "expression"
IF_ELSE = "if_else"

# True / False 各自走哪个输出端口
CONDITION_TRUE_OUTPUT = {
ConditionType.RANGE: "inRange",
ConditionType.COMPARE: "result",
ConditionType.IF_ELSE: "true",
}
CONDITION_FALSE_OUTPUT = {
ConditionType.RANGE: "outRange",
ConditionType.IF_ELSE: "false",
}

分支路由

执行时按条件结果挑出”当前活跃分支上的动作”——其他分支上挂的动作这次就不执行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def _resolve_branched_actions(self, rule, event, enabled_conditions):
condition_results = {}
for cond in enabled_conditions:
if cond.condition_type == ConditionType.RANGE:
condition_results[cond.condition_id] = self._check_range_condition(cond, event)
elif cond.condition_type == ConditionType.COMPARE:
condition_results[cond.condition_id] = self._check_compare_condition(cond, event)

branch_map = {}
for cid, result in condition_results.items():
cond = cond_by_id.get(cid)
ctype = cond.condition_type if cond else ConditionType.EXPRESSION
if result:
branch_map[cid] = CONDITION_TRUE_OUTPUT.get(ctype, "result")
else:
false_output = CONDITION_FALSE_OUTPUT.get(ctype)
if false_output:
branch_map[cid] = false_output

# 分支也要做边沿检测,否则会持续触发
state_key = str(sorted(branch_map.items()))
cache_key = f"branch:{rule.rule_id}"
last_state = self._condition_last_results.get(cache_key)
self._condition_last_results[cache_key] = state_key

if state_key == last_state:
return None

executable = []
for action in rule.actions:
branch_info = action.config.get("_condition_branch")
if not branch_info:
executable.append(action) # 不挂分支的动作总是执行
continue
cid = branch_info.get("condition_id")
output = branch_info.get("output")
if cid in branch_map and branch_map[cid] == output:
executable.append(action)

return executable

注意分支也得做边沿检测——不然就算条件没变,每次触发器一到都会重新发一遍动作。

启动补评估

刚重启完,规则引擎里 _condition_last_results 是空的。这时候要是设备状态本来就已经超阈值了,因为没有”变化”事件,规则不会触发。

结果就是个挺尴尬的场景:你重启服务,本来开着的风扇被关掉之后再也开不回来——除非温度先掉下来再升上去。

补丁是启动后做一次全量扫描:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _bootstrap_startup_entity_evaluation(self):
if not self.device_manager:
return

state_machine = getattr(self.device_manager, "_state_machine", None)
if not state_machine:
return

watch_entity_ids = self._collect_startup_watch_entity_ids()
if not watch_entity_ids:
return

all_states = state_machine.get_all()

for state_entity_id, state in all_states.items():
if not state or not getattr(state, "attributes", None):
continue

extra = getattr(state.attributes, "extra", None) or {}
raw_entity_db_id = extra.get("entity_db_id")
if raw_entity_db_id is None:
continue

entity_db_id = int(raw_entity_db_id)
if entity_db_id not in watch_entity_ids:
continue

new_value = getattr(state, "state", None)
scale = extra.get("scale", 1.0)

self._on_trigger(TriggerEvent(
event_type="entity_value_changed",
entity_id=entity_db_id,
old_value=None, # 启动时没有旧值
new_value=new_value,
scale=scale,
additional_data={"bootstrap": True},
))

additional_data={"bootstrap": True} 这个标记给下游一个信号——这条事件是补评估出来的,不是真的状态变化。

举个例子:某些行为(比如发钉钉通知)可能不希望在启动补评估时触发,毕竟没人想半夜被一堆”温度已恢复正常”的通知吵醒——根据这个标记跳过即可。

防重入

规则 A 触发后写了设备 X,X 的状态变化又满足规则 A 自己的条件——经典自激振荡,俩规则你来我往,CPU 开始冒烟。

修法很俗,但有效:执行中的规则 ID 进一个 set,在 set 里的规则跳过新触发:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class AutomationEngine:
def __init__(self):
self._executing_rules: set = set()

def _on_trigger(self, event: TriggerEvent):
with self._lock:
for rule in self.rules.values():
if not rule.enabled:
continue
if rule.rule_id in self._executing_rules:
continue
# ... 匹配逻辑

def _execute_rule_async(self, rule, context, event):
rule_id = rule.rule_id
with self._lock:
self._executing_rules.add(rule_id)
try:
for action in actions:
await self._execute_action(action, context)
finally:
with self._lock:
self._executing_rules.discard(rule_id)

这招只防”规则触发自己”。规则 A → 规则 B → 规则 A 这种隔了一层的循环挡不住——要是真担心,得把执行栈传下去。目前没遇到这种规则配置,先在心里记着这个边界。

防抖和冷却

传感器抖一下,瞬间穿过阈值再回来,就能给你触发一次没意义的开关。Trigger 这一层加个防抖窗口:

1
2
3
4
5
6
7
8
9
10
11
class BaseTrigger:
def __init__(self):
self._last_trigger_time: float = 0
self._debounce_until: float = 0

def check_debounce(self) -> bool:
now = time.time()
if now < self._debounce_until:
return False
self._debounce_until = now + self.debounce_seconds
return True

规则级再加一个冷却时间和每日次数上限,给用户兜底——这样就算逻辑写错了,也不至于一天给某个设备来个上千次开关:

1
2
3
4
5
6
7
8
9
10
11
12
13
class AutomationEngine:
def _check_cooldown(self, rule: AutomationRule) -> bool:
stats = self.execution_stats.get(rule.rule_id, {})
last_execution = stats.get('last_execution_at')
if last_execution is None:
return True
return time.time() - last_execution >= rule.cooldown_seconds

def _check_daily_limit(self, rule: AutomationRule) -> bool:
if rule.daily_limit is None:
return True
stats = self.execution_stats.get(rule.rule_id, {})
return stats.get('total_executions', 0) < rule.daily_limit

防抖和冷却语义不一样:防抖是”短时间内只算一次”,冷却是”一次执行完之后多久不再执行”。两者都得有。

图编译器

前端有个可视化的节点编辑器——条件节点、动作节点、连接线,用户拖来拖去。后端规则用的是另一套模型(rule、condition、action 三张表)。所以中间需要把节点图”编译”成规则。

主要的活儿,是从连接线里反推每个动作挂在哪个条件的哪个输出端口上,写回 action.config 里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def _rebuild_condition_branches(self, rule: AutomationRule):
nodes_config = rule.nodes_config
if not nodes_config:
return

connections = nodes_config.get('connections', [])

node_to_condition = {}
for cond in rule.conditions:
node_id = cond.config.get("_node_id") if cond.config else None
if node_id:
node_to_condition[node_id] = cond.condition_id

node_to_action = {}
for action in rule.actions:
node_id = action.config.get("_node_id") if action.config else None
if node_id:
node_to_action[node_id] = action.action_id

for conn in connections:
from_node = conn.get('source')
from_output = conn.get('sourceOutput')
to_node = conn.get('target')

if from_node in condition_node_set and to_node in action_node_set:
backend_action_id = node_to_action[to_node]
backend_condition_id = node_to_condition[from_node]
for action in rule.actions:
if action.action_id == backend_action_id:
action.config["_condition_branch"] = {
"condition_id": backend_condition_id,
"output": from_output,
}

_condition_branch_node_id 这类带下划线前缀的字段,是后端内部用的标记,前端不会展示——这种约定哪天破了会让你查半天,最好早早写进文档里,免得自己挖坑自己跳。

这个农业 IoT 项目最早只接 Modbus 传感器,业务代码到处都是 client.read_holding_registers(...)——那时候挺潇洒。

后来加了 MQTT 智能插座,又来了 ESPHome 的 DIY 板子,再后来还得支持 OpenMQTTGateway 转出来的 RF 设备……每多一种协议,业务层就得跟着改一圈。改到第三种我就受不了了——再这么下去,业务代码迟早成一锅粥。

最后还是绕回去,老老实实做了一层设备抽象。这篇记一下这层是怎么一点点长出来的,以及踩到的两个比较重要的坑:Modbus 串口阻塞,还有启动时的全量预读取。

Driver 注册表

不同协议各写一个 Driver,外面统一拿。注册表本身就是个字典:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from enum import Enum
from typing import Dict, Type

class DeviceConfigType(str, Enum):
MODBUS = "modbus"
MQTT = "mqtt"
ESPHOME_NATIVE = "esphome_native"
ESPHOME_MQTT = "esphome_mqtt"
OMG_GATEWAY = "omg_gateway"

class DeviceManager:
def __init__(self):
self.drivers: Dict[DeviceConfigType, BaseDriver] = {
DeviceConfigType.MODBUS: ModbusDriver(),
DeviceConfigType.MQTT: MQTTDriver(),
DeviceConfigType.ESPHOME_NATIVE: ESPHomeNativeDriver(),
DeviceConfigType.ESPHOME_MQTT: ESPHomeMQTTDriver(),
DeviceConfigType.OMG_GATEWAY: OMGDriver(),
}

每个 Driver 实现一组最小接口:

1
2
3
4
5
6
7
8
9
10
11
12
class BaseDriver(ABC):
@abstractmethod
async def connect(self, config: dict) -> bool: ...

@abstractmethod
async def read_entities(self, config: dict) -> list: ...

@abstractmethod
async def write_entity(self, config: dict, entity_id: str, value: any) -> bool: ...

@abstractmethod
async def disconnect(self): ...

新协议加进来,就是写一个新 Driver、在注册表里登记一下——业务层一行不用改。这种感觉爽到没朋友。

统一实体模型

每个 Driver 内部怎么访问设备,它自己的事。但对外吐出来的,必须是统一的 Entity

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Entity:
entity_id: str
device_id: int
name: str
entity_type: str # sensor / switch / light / button
value: Any
unit: Optional[str]
readable: bool
writable: bool
last_update: float

业务代码只跟 Entity 打交道,根本不知道底下是 Modbus 寄存器还是 MQTT topic:

1
2
3
4
5
temp_entity = device_manager.get_entity("greenhouse_1.temperature")
print(f"当前温度: {temp_entity.value} {temp_entity.unit}")

relay_entity = device_manager.get_entity("greenhouse_1.relay_1")
await device_manager.write_entity(relay_entity.entity_id, True)

这套抽象后来接 Home Assistant 的时候也省了大事——HA 本身就是 Entity 概念,对得上号。

映射缓存

设备的协议地址映射(哪个 entity 对应哪个寄存器 / topic / key)属于典型的”读 1000 次、改 1 次”——每次都查数据库太亏。带个 5 分钟 TTL 的缓存:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from dataclasses import dataclass
import time

@dataclass
class CachedMappingInfo:
entity_id: int
entity_type: str
scale: float
register_address: Optional[int] # Modbus
mqtt_topic: Optional[str] # MQTT
esphome_key: Optional[str] # ESPHome

class DeviceManager:
def __init__(self):
self.mapping_cache: Dict[int, Dict[int, CachedMappingInfo]] = {}
self.config_cache: Dict[int, CachedDeviceConfig] = {}
self._cache_ttl = 300
self._cache_timestamps: Dict[str, float] = {}

def get_mapping(self, device_id: int, entity_id: int):
cache_key = f"mapping_{device_id}"
if self._is_cache_expired(cache_key):
self._refresh_mapping_cache(device_id)
return self.mapping_cache.get(device_id, {}).get(entity_id)

def _is_cache_expired(self, cache_key: str) -> bool:
timestamp = self._cache_timestamps.get(cache_key, 0)
return time.time() - timestamp > self._cache_ttl

5 分钟是拍脑袋定的——长一点会有”改了配置半天不生效”的体感问题;短一点意义又不大。后来加了主动 invalidate,改配置时直接清缓存,TTL 就只是个保底。

端口级独立轮询

这个坑是上线之后才暴露的。

Modbus 设备共用一根 RS-485 串口。某个设备响应慢(断线超时要等 1 秒),同一串口上的其他设备就全跟着卡。一台慢设备,能把整根串口的节奏拖垮——那一刻血压直接上来。

办法是按”端口”维度做隔离:每根串口/TCP 连接,配一个单线程的 ThreadPoolExecutor,互不干扰:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from concurrent.futures import ThreadPoolExecutor, Future

class DeviceManager:
def __init__(self):
self._modbus_port_workers: Dict[str, ThreadPoolExecutor] = {}
self._modbus_port_futures: Dict[str, Future] = {}

def _get_port_worker(self, port_key: str) -> ThreadPoolExecutor:
if port_key not in self._modbus_port_workers:
self._modbus_port_workers[port_key] = ThreadPoolExecutor(
max_workers=1,
thread_name_prefix=f"modbus-{port_key}",
)
return self._modbus_port_workers[port_key]

def poll_device(self, device: Device):
if device.config_type == DeviceConfigType.MODBUS:
port_key = self._get_modbus_port_key(device)
worker = self._get_port_worker(port_key)
future = worker.submit(self._poll_modbus_device, device)
self._modbus_port_futures[port_key] = future
else:
self._poll_device(device)

def _get_modbus_port_key(self, device: Device) -> str:
config = device.modbus_config
if config.serial_port:
return f"serial_{config.serial_port}"
return f"tcp_{config.host}:{config.port}"

max_workers=1 不是手抖——Modbus 在同一根串口上本来就是顺序协议,并行发指令立马乱码。这里说的”并行”,指的是不同串口之间。

启动预读取

启动后所有设备的当前值都是空的,得等第一轮轮询完才能填上。如果设备多,串行连接 + 串行预读取,能让你等几分钟——UI 上看就是满屏的”–“。

8 线程并发拉一下就好了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from concurrent.futures import ThreadPoolExecutor

class DeviceManager:
def __init__(self):
self._startup_preload_executor = ThreadPoolExecutor(
max_workers=8,
thread_name_prefix="startup-preload",
)
self._startup_preload_futures: Dict[int, Future] = {}

def connect_all_devices(self):
devices = Device.select().where(Device.enabled == True)
for device in devices:
future = self._startup_preload_executor.submit(
self._connect_and_preload, device
)
self._startup_preload_futures[device.id] = future

def _connect_and_preload(self, device: Device):
try:
driver = self.drivers[device.config_type]
driver.connect(device.config)
entities = driver.read_entities(device.config)
for entity in entities:
self._update_entity_state(device.id, entity)
except Exception as e:
logger.error(f"设备 {device.name} 连接失败: {e}")

注意,这个预读取池是按”设备”并发的,不是按”端口”。预读取阶段串口的串行约束依然在,会被端口轮询那一层挡住——所以不用担心冲突。

事件总线

设备状态变了之后,下游想做的事不止一种:UI 刷新、自动化规则触发、写历史库、推 MQTT 广播。要是在 Driver 里一个个 hook,耦合就乱成毛线团了。

所以加了个简单的事件总线:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from enum import Enum
from typing import Callable, Dict, List

class EventType(str, Enum):
STATE_CHANGED = "state_changed"
DEVICE_CONNECTED = "device_connected"
DEVICE_DISCONNECTED = "device_disconnected"

class EventBus:
def __init__(self):
self._listeners: Dict[EventType, List[Callable]] = {}

def listen(self, event_type: EventType, callback: Callable):
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(callback)

def unsubscribe():
self._listeners[event_type].remove(callback)
return unsubscribe

def emit(self, event_type: EventType, data: dict):
for callback in self._listeners.get(event_type, []):
try:
callback(data)
except Exception as e:
logger.error(f"事件处理失败: {e}")

事件总线最容易出的事,是回调里又触发别的事件——好家伙,循环了。这块靠业务约定撑着:回调里只读、不写设备。真有写需求,就走自动化引擎那一层(那边有防重入机制,下一篇细讲)。

游戏自动化里,能写规则就尽量写规则——便宜、可控、好调。但有些场景规则真写不动。

比如卡牌游戏的考试阶段,出牌顺序、资源分配会被手牌、对手压制、剩余回合等十几个变量同时拉扯,写 if-else 写到最后你只想砸键盘。

所以这个项目里把”培育”和”考试”两段都封成了 Gymnasium 环境,让 PPO 自己去摸索策略。训练好的模型通过 HTTP 服务对外暴露,主程序跑自动化的时候按需 query。下面是中间几个选择背后的考虑。

观测空间

游戏状态怎么塞进神经网络,是最头疼的一步。最后定的是三段结构:全局状态、每个动作的特征、动作掩码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import gymnasium as gym
from gymnasium import spaces

class GameExamEnv(gym.Env):
def __init__(self):
super().__init__()
self.global_dim = 60 # 全局状态维度
self.action_feature_dim = 100 # 每个动作的特征维度
self.max_actions = 50 # 手牌 + 饮料 + 结束回合

self.observation_space = spaces.Dict({
'global': spaces.Box(-20.0, 20.0, shape=(self.global_dim,), dtype=np.float32),
'action_features': spaces.Box(
-20.0, 20.0,
shape=(self.max_actions, self.action_feature_dim),
dtype=np.float32,
),
'action_mask': spaces.Box(0.0, 1.0, shape=(self.max_actions,), dtype=np.float32),
})

“每个动作一个向量”这种 action_features 设计,是为了让模型在动作集变化时还能复用——手牌每回合都不一样,但每张牌的特征结构是统一的。

全局状态

全局状态主要装”剩余进度”和”当前资源”。为了让训练稳一点,所有值都归一化到大致 [0, 1]:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def _global_observation(self) -> np.ndarray:
state = self.runtime.state

global_values = {
# 进度
'step_ratio': state['step'] / max(self.scenario.steps, 1),
'remaining_step_ratio': max(self.scenario.steps - state['step'], 0) / max(self.scenario.steps, 1),
'audition_progress': state['audition_index'] / max(len(self.scenario.audition_sequence), 1),

# 资源
'stamina_ratio': state['stamina'] / max(state['max_stamina'], 1.0),
'produce_point_ratio': state['produce_points'] / 150.0,
'fan_vote_ratio': state['fan_votes'] / 5000.0,

# 三维参数
'vocal_ratio': state['vocal'] / parameter_scale,
'dance_ratio': state['dance'] / parameter_scale,
'visual_ratio': state['visual'] / parameter_scale,

# 卡组
'deck_quality': state['deck_quality'] / 20.0,
'drink_quality': state['drink_quality'] / 10.0,
'deck_size_ratio': len(self.runtime.deck) / 40.0,
}

return np.array(
[float(global_values[name]) for name in self.global_feature_names],
dtype=np.float32,
)

那个 60 维不是精心算出来的——是把”模型可能用得上的状态量”全往里塞之后凑出的。事后看,估计有一半是冗余的,但训练效果还能看,就懒得动了。

每个动作的特征

每个候选动作(出哪张牌、喝啥饮料、结束回合)都编成定长向量。类型 one-hot,效果 one-hot,剩下都是数值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def _candidate_feature(self, candidate) -> np.ndarray:
action_type_vec = self.taxonomy.encode_actions([candidate.action_type])
effect_types = self._produce_effect_types(candidate)
effect_vec = self.taxonomy.encode_produce_effects(effect_types)

numeric = np.array([
candidate.stamina_delta / max(state['max_stamina'], 1.0),
candidate.produce_point_delta / 100.0,
candidate.success_probability,
len(candidate.produce_effect_ids) / 8.0,
1.0 if candidate.available else 0.0,
state['stamina'] / max(state['max_stamina'], 1.0),
], dtype=np.float32)

return np.concatenate([action_type_vec, effect_vec, numeric]).astype(np.float32)

动作掩码

非法动作(已经打掉的牌、体力不够发动的牌)必须屏蔽。不然模型会傻乎乎地一直选它,然后撞墙撞个没完:

1
2
3
4
5
def action_masks(self) -> np.ndarray:
return np.array([
bool(candidate.payload.get('available', False))
for candidate in self._candidates
], dtype=bool)

用的是 sb3-contrib 的 MaskablePPO,掩码直接进 policy 网络,违法动作的概率会被强行压成 0——干净利落。

课程学习

你要是头铁,直接把模型扔进最难的”NIA Master 全流程”开训,那场面挺惨的——它一路吃负奖励,啥也学不会,跟个迷茫的萌新一头撞墙似的。

所以搞了套课程,从最简单的”初中间考试”起步,一关一关往上爬:

1
2
3
4
5
6
7
8
9
10
11
CURRICULUM_STAGES = [
"初中间考试",
"初最终考试",
"NIA中间考试",
"NIA最终考试",
"NIA选拔",
"初Regular全流程",
"初Master全流程",
"NIA Pro全流程",
"NIA Master全流程",
]

每个阶段训若干 timesteps,然后评估、存 checkpoint,再进下一关:

1
2
3
4
5
6
7
def run_curriculum(self, stages, timesteps_per_stage=131072):
for stage_name in stages:
env = self._create_env_for_stage(stage_name)
model = MaskablePPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=timesteps_per_stage)
quality = self._evaluate_model(model, stage_name)
self._save_checkpoint(model, stage_name, quality)

timesteps_per_stage=131072 是 2^17,没啥特殊原因,纯属 PPO 跑下来手感比较顺的一个量级。

自举训练

没人类示范数据,也不想花钱让 GPT 当老师。于是搞了一套”自己教自己”:多个 checkpoint 在固定 seed 上各跑一遍,挑表现最好的轨迹做行为克隆,再 RL 微调一把:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def self_bootstrap(self, iterations=3):
for iteration in range(iterations):
# 多 checkpoint 跑固定种子
candidates = []
for checkpoint in self.checkpoints:
for seed in self.evaluation_seeds:
episode = self._run_episode(checkpoint, seed)
candidates.append(episode)

# 排序选最优
candidates.sort(key=lambda c: c.quality_key(), reverse=True)
best_trajectories = candidates[:10]

# BC 蒸馏
bc_model = self._bc_train(best_trajectories)

# RL 微调
rl_model = self._rl_finetune(bc_model, timesteps=50000)

# 谁稳留谁
if self._evaluate(rl_model) > self._evaluate(bc_model):
self.best_model = rl_model
else:
self.best_model = bc_model

排序 key 这事得多说一句——优先级是”无效动作越少越好 > 排名越高越好 > 分数越高越好”。

为啥不把”分数高”放第一位?因为那样会挑出一堆”靠运气打高分”的轨迹,BC 学完之后会被这些路径带歪,模型反倒更不稳了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@dataclass
class EpisodeCandidate:
checkpoint_path: Path
seed: int
records: list
total_reward: float
terminal_score: float
invalid_actions: int
clear_rank: int

def quality_key(self):
return (
-int(self.invalid_actions),
int(self.clear_rank),
float(self.terminal_score),
float(self.total_reward),
)

推理服务化

训练好的模型没塞进主程序——stable_baselines3 那套依赖太重,整进 GUI 应用包能膨胀得吓人。

所以单独起了个 FastAPI 服务,主程序通过 HTTP 调:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from fastapi import FastAPI
from sb3_contrib import MaskablePPO

app = FastAPI()
model = MaskablePPO.load("best_model.zip")

@app.post("/api/inference/predict")
async def predict(request: PredictRequest):
obs = build_observation(request.state, request.legal_actions)
action_mask = build_action_mask(request.legal_actions)

action_value, _ = model.predict(
obs,
deterministic=True,
action_masks=action_mask,
)

action_index = int(action_value)
return {
"action_index": action_index,
"action_id": request.legal_actions[action_index]["action_id"],
"confidence": float(action_value),
}

主程序这边就是个朴素 HTTP 客户端,超时定短点,失败了就 fallback 到规则策略:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class RLInferenceClient:
def __init__(self, base_url="http://127.0.0.1:8100"):
self._base_url = base_url
self._session = requests.Session()

def predict(self, exam_state, legal_actions, deterministic=True):
payload = {
**exam_state,
"legal_actions": legal_actions,
"deterministic": deterministic,
}
try:
resp = self._session.post(
f"{self._base_url}/api/inference/predict",
json=payload,
timeout=10.0,
)
resp.raise_for_status()
return resp.json()
except requests.RequestException as exc:
logger.warning(f"RL 推理请求失败: {exc}")
return None

score 模式 vs clear 模式

考试有两个互相打架的目标:分数尽量高、能否过线。一套 reward 配置很难两头都讨好,所以干脆分了两套:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def build_reward_config(mode: str) -> RewardConfig:
if mode == "score":
# 拼分数
return RewardConfig(
score_weight=1.0,
clear_bonus=0.5,
invalid_action_penalty=-0.25,
)
elif mode == "clear":
# 保过线
return RewardConfig(
score_weight=0.3,
clear_bonus=2.0,
overclear_penalty=-0.1, # 过线后再硬冲会被惩罚
invalid_action_penalty=-0.25,
)

overclear_penalty 这个是后加的——一开始 clear 模式训出来的策略,过线了还在硬打、停不下来。因为对它来说反正 reward 越多越好嘛。

加了过线后的负反馈,它才学会”得了得了,过了就停”。

中间最容易踩的一个大坑:reward 一改就得重训。课程学习的 checkpoint 不能跨 reward 配置复用——前期学到的”打高分”行为会污染”保过线”目标。我就是没注意这点,白白浪费了好几个小时的训练时间。

游戏自动化要识别卡片,单用 OCR 不行,单用 CLIP 也不行。

OCR 准是准,慢;而且它的命根子是”卡面有清晰文字”——很多技能卡花体字带遮挡,识别率能给你跌成过山车。CLIP 倒是快,但前提是”见过”——头一次遇到的卡,记忆库里压根没特征,找个鬼。

后来灵光一闪,这俩其实可以串起来嘛:日常识别让 CLIP 顶着,识别不出来的,让 OCR 上去救场一次。OCR 拿到结果之后回头把这张图喂给 CLIP “学一下”——下次再遇到同样的卡,CLIP 直接就认出来了。

整个流程长这样:

1
2
3
4
5
6
7
8
9
截图 → YOLO 检测 → CLIP 尝试识别 → 命中? → 直接使用
↓ 未命中
OCR 识别文字

数据库匹配

让 CLIP 学习这张图

持久化到记忆库

CLIP 记忆库本身的工程细节我另写了一篇,这里专门讲两个东西的协作。

学习路径

完整的”学一张新卡”长这样:先让 CLIP 试一手,命中就跳过,没命中才轮到 OCR:

1
2
3
4
5
6
7
8
9
10
11
12
def learn_card(self, app, card_frame, card_list):
for card in card_list:
# 先试 CLIP
existing_id = self._try_clip_identify(app, card.frame)
if existing_id:
continue

# CLIP 不认识,OCR 上
learned_id = self._learn_via_ocr(app, card.frame)
if learned_id:
# OCR 找到了,让 CLIP 把这张图记下来
self.clip_manager.add_to_memory(card.frame, learned_id)

CLIP 识别那一段套了个 try/except——因为 CLIP 失败属于”正常情况”,库为空、相似度不够都会返回 None,没必要往外抛异常吓人:

1
2
3
4
5
6
7
8
def _try_clip_identify(self, app, card_frame):
try:
result = self.clip_manager.retrieve(card_frame)
if result is not None:
return result.payload.id
except Exception as e:
logger.debug(f"CLIP identify failed: {e}")
return None

OCR 这边稍微讲究点——不是把识别到的字直接当卡名(那也太天真了),而是拿去数据库里搜:

1
2
3
4
5
6
7
8
9
10
11
def _learn_via_ocr(self, app, card_image):
ocr_result = ocr_service.ocr(card_image)
if not ocr_result or not ocr_result.results:
return None

for item in ocr_result.results:
if len(item.text) >= 3:
status, db_result = database.search(item.text)
if status and db_result:
return db_result.id
return None

len(item.text) >= 3 这个限制不能少。OCR 看到图标、装饰元素也会硬识别,给你来一堆”●““◇”“+”这种单字符——要是不卡长度,数据库搜索分分钟被这些噪声淹没。三个字以上,基本上误识别就能挡掉大头。

运行时识别

学习路径主要发生在首次扫描或者版本更新那会儿。日常跑的时候走的是另一条道——CLIP 优先,没命中才回头去学一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def identify_element(self, screenshot, element_frame):
# 置信度够高就直接用
result = self.clip_manager.retrieve(element_frame)
if result is not None and result.similarity > 0.96:
return result.payload

# 没命中,那就触发一次学习
self.learn_element(screenshot, element_frame)

# 再试一次
result = self.clip_manager.retrieve(element_frame)
if result is not None:
return result.payload

return None

可能你会问:学完了为啥还要 retrieve 一次?

因为 learn_element 内部会判断 OCR 是不是成功——成功了才往 CLIP 里灌,失败就啥都不做。所以学完到底有没有可用特征,得靠下一次 retrieve 兜底确认。

几个权衡

阈值定在 0.96,意思就是模棱两可的边缘情况,统统按”未命中”处理。

这是故意的。错认比漏认代价高太多了——漏认顶多就是触发一次 OCR 兜底,慢个几百毫秒;错认就麻烦了,后面整个自动化流程都得跟着跑错路径,那是真的会出事。

OCR 的开销其实不小,单次几百毫秒打底。所以 CLIP 的命中率,直接决定了用户体验的快慢。新版本卡刚出那阵子,命中率会暴跌,体感就是”咋这么慢”。等用户跑过两三次自动化把记忆库灌满,速度自己就回来了。

还有个不太显眼的好处:因为 OCR 只在 CLIP 失手时才上场,可以放心用那些又慢又准的 OCR 引擎(比如 macOS 的 Vision),不用在”速度”和”识别质量”之间做痛苦取舍。

Agent 流式输出的 SSE 事件设计

最早搞流式的时候图省事,只丢 content 增量出去,前端那边就是个慢慢变长的文本框。结果上线第一天,工单就开始飘——“是不是卡了?”“怎么没反应?”

其实后端忙得脚不沾地,正跑工具呢,可前端两眼一抹黑啥都不知道。这事儿一查就明白了:用户需要的不是”最终结果什么时候来”,而是”你现在到底在干嘛”。

后来把事件类型一拆四:

1
2
3
4
EVENT_ROUTING = "routing"           # 路由分类结果
EVENT_TOOL_STATUS = "tool_status" # 工具调用状态
EVENT_TOOL_RESULT = "tool_result" # 工具执行结果
EVENT_CONTENT = "content" # 正文内容

routing:先吼一嗓子,我知道你要啥了

请求一进门,小模型一分类完,立马给前端吐一个 routing 事件。前端拿到就能甩一行”正在查询设备…“上去,至少屏幕上动起来了:

1
2
3
4
5
6
7
8
9
yield {
"type": "routing",
"data": {
"category": "device_management",
"confidence": 0.92,
"reasoning": "用户询问设备状态",
"secondary_category": None,
},
}

这事儿技术上没什么了不起的,但心理效果立竿见影——首字节延迟从”等大模型出第一个 token”变成”等小模型分类完”,一个数量级的差距。用户那种”是不是死机了”的焦虑,立马就没了。

tool_status:工具的开始和收工

工具调用拆成两个事件,前端拿去切 loading 动画:

1
2
3
4
5
6
7
8
9
yield {
"type": "tool_status",
"data": {"name": "list_devices", "status": "calling"},
}

yield {
"type": "tool_status",
"data": {"name": "list_devices", "status": "completed"},
}

为啥不合成一个事件、带个执行时长?因为工具有时候真的会跑很久。比如说让它”统计过去 7 天所有设备的趋势”——这一拉就是好几秒,前端总得在 calling 状态下持续转圈,不然用户又得开始疑神疑鬼。

tool_result:原料给你,自己看着办

工具跑完,原始结果直接甩出去。前端要是想自己渲染(比如把设备列表搞成表格),完全不用等模型用人话复述一遍:

1
2
3
4
5
6
7
8
yield {
"type": "tool_result",
"data": {
"name": "list_devices",
"success": True,
"data": {"devices": [...]},
},
}

content:大家最熟悉的逐字流

最后才是大家常见的那个——一个 token 一个 token 往外吐:

1
2
3
4
5
6
7
8
yield {
"type": "content",
"data": {"content": "当前系统中共有"},
}
yield {
"type": "content",
"data": {"content": " 12 个设备"},
}

一次完整对话长啥样

来个例子,用户问”现在有哪些设备在线”,SSE 流大概是这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
POST /v1/chat/completions
{
"model": "smart-agent",
"messages": [{"role": "user", "content": "现在有哪些设备在线?"}],
"stream": true
}

data: {"type":"routing","data":{"category":"device_management","confidence":0.95}}

data: {"type":"tool_status","data":{"name":"list_devices","status":"calling"}}

data: {"type":"tool_result","data":{"name":"list_devices","success":true,"data":{...}}}

data: {"type":"tool_status","data":{"name":"list_devices","status":"completed"}}

data: {"type":"content","data":{"content":"当前系统中共有"}}

data: {"type":"content","data":{"content":" 12 个设备,其中 8 个在线"}}

data: {"type":"content","data":{"content":",4 个离线。"}}

整个过程像看比赛实况——分类、调工具、出结果、播报,每一步都看得见。

思考链怎么薅出来

Qwen3、DeepSeek 这一类模型都支持思考链,但 LangChain 的封装就有点折腾人了——有时候挂在 content 列表里某个 reasoning block,有时候又躲在 additional_kwargs.reasoning_content 里。

讲究的方法是几个位置都翻一遍,谁有就要谁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
@classmethod
def _extract_stream_chunk_deltas(cls, chunk) -> Dict[str, str]:
content_delta = cls._message_content_to_text(getattr(chunk, "content", ""))
thinking_parts = []

raw_content = getattr(chunk, "content", None)
if isinstance(raw_content, list):
for block in raw_content:
if not isinstance(block, dict):
continue
block_type = str(block.get("type", "")).lower()
if block_type in {"reasoning", "thinking", "reasoning_content"}:
text = cls._extract_reasoning_text(block)
if text:
thinking_parts.append(text)

for container in (
getattr(chunk, "additional_kwargs", None),
getattr(chunk, "response_metadata", None),
):
if not isinstance(container, dict):
continue
for key in ("reasoning_content", "thinking", "reasoning"):
if key in container:
text = cls._extract_reasoning_text(container.get(key))
if text:
thinking_parts.append(text)

return {
"content_delta": content_delta,
"thinking_delta": "".join(thinking_parts),
}

薅出来之后单独发一个 thinking 事件:

1
2
3
4
yield {
"type": "thinking",
"data": {"content": "用户询问设备状态,我需要调用 list_devices 工具..."},
}

前端可以折叠成一个小抽屉,不混进正文里——不然模型一句话还没说完,思考过程已经把屏幕填满了。

并发别忘了管

不管的话很容易出事。几个用户一起来一个大任务,本地模型服务直接吃完显存。所以前面挡了个 Semaphore,超了就排队等:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import asyncio
from enum import Enum

class RequestStatus(str, Enum):
QUEUED = "queued"
PROCESSING = "processing"
TOOL_CALLING = "tool_calling"
COMPLETED = "completed"
FAILED = "failed"

class QueueManager:
def __init__(self, max_concurrent: int = 10):
self._semaphore = asyncio.Semaphore(max_concurrent)
self._active_requests: Dict[str, RequestTask] = {}

async def submit(self, request_id: str, handler) -> RequestTask:
task = RequestTask(request_id=request_id)
self._active_requests[request_id] = task

async def _execute():
async with self._semaphore:
task.status = RequestStatus.PROCESSING
try:
result = await handler(task)
task.result = result
task.status = RequestStatus.COMPLETED
except Exception as e:
task.error = str(e)
task.status = RequestStatus.FAILED

asyncio.create_task(_execute())
return task

排队过程其实也可以单独发个事件给前端,告诉用户”你排第 3 位呢”。还没做。

FastAPI 这边

接口本身没啥花活。几个 header 倒是要盯紧了:

  • Cache-Control: no-cache:别被中间层缓存了
  • X-Accel-Buffering: no:部署在 nginx 后面的必须加,不然 SSE 流会被缓冲住,一卡半天
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import json

app = FastAPI()

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
if request.stream:
return StreamingResponse(
stream_generator(request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
result = await agent.chat(request.messages)
return result

async def stream_generator(request):
async for event in agent.chat_stream(request.messages):
yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

X-Accel-Buffering: no 这条我反复栽过——本地开发好好的,一上 nginx 反代就开始诡异卡顿。每次都得花十分钟想起来,再补上。

前端解析

事件类型一多,前端就得做分发。最朴素的写法就是一个 switch:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
const eventSource = new EventSource('/v1/chat/completions');

eventSource.onmessage = (event) => {
if (event.data === '[DONE]') {
eventSource.close();
return;
}

const data = JSON.parse(event.data);

switch (data.type) {
case 'routing':
showStatus(`正在${data.data.category}...`);
break;
case 'tool_status':
if (data.data.status === 'calling') {
showLoading(`正在调用 ${data.data.name}...`);
}
break;
case 'tool_result':
if (data.data.success) {
showToolResult(data.data.name, data.data.data);
}
break;
case 'content':
appendContent(data.data.content);
break;
case 'thinking':
showThinking(data.data.content);
break;
}
};

这块代码以后必胖——因为每种工具的结果展示方式都不一样,最后多半会演化出一个”工具名 → 渲染组件”的注册表。

不过那是以后的事。第一版能跑就行,等真的有十来种工具结果再去抽象也来得及。

create_react_agent 一行起步的 demo 大家都看过,简洁优雅、五分钟跑通——博文里看着像神器。

可真正接上业务工具之后,麻烦事一桩接一桩:工具抛个异常,直接报错回去;模型在关键参数还没拿到时就开始瞎调;连续失败的时候它会”我不信邪”地一直重试;偶尔还甩你一句”抱歉我无法完成”,把球踢回来。

下面是我后来把这个 ReAct 拆开重写时的几个改动点。

为啥不用 create_react_agent

预制版本本身没啥不好,作为起点也合适:

1
2
3
4
from langgraph.prebuilt import create_react_agent

agent = create_react_agent(model, tools, prompt=SystemMessage(content=system_prompt))
result = await agent.ainvoke({"messages": [HumanMessage(content=user_message)]})

但它把”model 决定 → 调用工具 → 把结果丢回去”这个循环写死了,跟个铁盒子似的。

我想在工具调用前后插自己的逻辑:参数校验、错误归一化、强制回到工具调用、连续失败熔断——这些活儿,都得自己拼图。

用 StateGraph 自己拼

LangGraph 的 StateGraph 提供了节点和条件边,把上面这些需求拆成节点就行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

def _build_agent_graph(self, model, tools, system_prompt):
tool_node = ToolNode(tools, awrap_tool_call=self._wrap_tool_call)
finalize_model = model.bind(tool_choice="none")

async def agent_node(state): ...
async def precheck_node(state): ...
async def enforce_tool_node(state): ...
async def finalize_node(state): ...
async def update_flags_node(state): ...

graph = StateGraph(AgentState)
graph.add_node("agent", agent_node)
graph.add_node("precheck", precheck_node)
graph.add_node("run_tools", tool_node)
graph.add_node("enforce_tool", enforce_tool_node)
graph.add_node("finalize", finalize_node)
graph.add_node("update_flags", update_flags_node)

graph.add_conditional_edges("agent", after_agent, {
"precheck": "precheck",
"enforce_tool": "enforce_tool",
"finalize": "finalize",
})
graph.add_conditional_edges("precheck", after_precheck, {
"run_tools": "run_tools",
"agent": "agent",
})
graph.add_edge("run_tools", "update_flags")
graph.add_conditional_edges("update_flags", after_update_flags, {
"agent": "agent",
"finalize": "finalize",
})
graph.add_edge("enforce_tool", "agent")
graph.add_edge("finalize", END)
graph.set_entry_point("agent")

return graph.compile()

各个节点的分工大概是:

节点做的事
agent调用 LLM,决定下一步是回话还是调工具
precheck工具调用前检查参数 / 依赖
run_tools真正执行工具(LangGraph 自带的 ToolNode)
enforce_tool检测到模型”偷懒”时,强制注入指令让它必须调工具
finalizebind(tool_choice="none") 强制产出自然语言
update_flags工具执行后更新失败计数等状态

工具别抛异常

这条是最关键的一条改动,划重点。

原始做法是工具失败直接 raise,LangGraph 默认行为就把异常包成错误终止图——用户那边看到的是一句没头没尾的报错,体验直接拉到底。

更聪明的方式,是工具内部把异常吞了,返回一个结构化的失败消息——让模型自己看见、理解、然后自己修正:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
async def _wrap_tool_call(self, request, execute):
response = await execute(request)
if not isinstance(response, ToolMessage):
return response

parsed = self._parse_tool_output(getattr(response, "content", ""))
success = getattr(response, "status", "") != "error"

error_text = None
if isinstance(parsed, dict):
if "success" in parsed:
success = bool(parsed.get("success"))
if parsed.get("error"):
error_text = str(parsed.get("error"))

normalized_payload = {
"success": success,
"data": parsed.get("data", parsed) if isinstance(parsed, dict) else parsed,
}
if error_text:
normalized_payload["error"] = error_text

return ToolMessage(
content=json.dumps(normalized_payload, ensure_ascii=False),
name=getattr(response, "name", ""),
tool_call_id=getattr(response, "tool_call_id", ""),
status="error" if not success else "success",
)

这样模型看到的东西是这副样子:

1
2
3
4
5
{
"success": false,
"data": null,
"error": "设备 ID 不存在,请先调用 list_devices 获取有效设备列表"
}

模型一看就懂——“哦那我先 list 一下”,自己把这个圈给闭环了。这种自我纠正能力,是大模型最值钱的本事之一,但前提是你得给它看懂的信息。直接甩个 Python traceback 过去?它要么乱猜参数硬重试,要么干脆撂挑子放弃。

precheck:工具调用之前先瞅一眼

有些工具调用,注定要失败——比如 create_automation 需要 device_id,但模型还没查过设备列表,那它肯定是瞎填一个。

与其让它失败再纠正、白白浪费一轮 token,不如在执行前拦下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
async def precheck_node(self, state):
last_msg = state["messages"][-1]
if not (isinstance(last_msg, AIMessage) and last_msg.tool_calls):
return {"_precheck_result": None}

tool_name = last_msg.tool_calls[0].get("name", "")
tool_args = last_msg.tool_calls[0].get("args", {})

precheck_err = self._precheck_tool_call(
tool_name=tool_name,
tool_args=tool_args,
user_message=state.get("user_message"),
)

if precheck_err is None:
return {"_precheck_result": None}

resolved = await self._resolve_precheck_with_auto_dependency(
precheck_error=precheck_err,
tool_name=tool_name,
tool_args=tool_args,
)

if resolved.get("dependency_tool_result") and not resolved["precheck_error"]:
# 自动执行了依赖工具,把结果塞进消息历史
return {"messages": [ToolMessage(
content=str(resolved["dependency_tool_result"]),
tool_call_id="auto_dep",
)]}
else:
return {"messages": [ToolMessage(
content=json.dumps(resolved["precheck_error"]),
tool_call_id=tool_call_id,
)]}

自动调用依赖工具是个偷懒做法,但确实少绕一圈。要是你担心黑盒过头不可控,可以只生成提示让模型自己去调——灵活度更高,但响应也会慢一点。

enforce_tool:不让模型偷懒

用户说”帮我创建一个温度大于 30 度就开风扇的自动化”——这意图明明白白吧?

可模型偶尔就会回一句”我可以帮你创建,请告诉我设备 ID 和阈值”。诶,你需要的信息其实已经在用户问题里了啊!

这种情况下,注入一个比较硬气的提示,强制让它走工具路径:

1
2
3
4
5
6
7
8
def _tool_enforcement_instruction(self, routing, user_message, available_tool_names):
return f"""你必须调用工具来完成用户的请求,不要直接回复。

用户问题: {user_message}

可用工具: {', '.join(available_tool_names)}

请立即调用合适的工具。"""

注意别无限循环用这招——万一是真的需要追问呢?反复鬼打墙就尴尬了。所以用一个 forced_tool_rounds 计数器卡次数。

状态里要塞的几个标志

熔断和重入这些防御逻辑,全靠状态里几个 flag 撑着:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class AgentState(TypedDict):
tool_calls_made: int
tool_failure_streak: int
last_failed_tool: Optional[str]
last_tool_error: Optional[str]
forced_tool_rounds: int

def after_update_flags(self, state):
# 连续失败 2 次以上,别再试了
if state.get("tool_failure_streak", 0) >= 2:
return "finalize"

# 工具调用次数超过 10 次,强制收尾
if state.get("tool_calls_made", 0) >= 10:
return "finalize"

return "agent"

为啥用”连续”失败而不是”累计”失败?因为同一会话里出错很正常啊,谁还没失手的时候。但连续两次大概率就说明模型把方向带歪了,继续硬试只会烧 token、烧到肉疼。

tool_calls_made >= 10 是终极兜底——正常对话最多走 3-4 轮工具,跑到 10 就说明它陷进迷宫了,得拽出来。

search_tools:让模型自己找工具

工具按意图分组,每次只给模型暴露相关的一小撮(这块在路由那篇文章里讲过)。

问题来了:偶尔会有跨域请求——分类成”设备查询”,但用户实际想问的是设备的历史趋势(这属于”数据分析”工具组)。这时候模型手头没合适的工具,就抓瞎了。

办法是注册一个”元工具” search_tools,让模型自己去搜需要但当前没加载的工具:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@_tool(args_schema=SearchToolsArgs)
async def search_tools(query: str, tags: list = None) -> dict:
"""搜索当前未加载的工具"""
results = registry.search_tools_by_keyword(query, max_results=10)

if not results:
return {
"success": True,
"data": {
"tools": [],
"message": f"未找到与 '{query}' 匹配的工具",
},
}

return {
"success": True,
"data": {
"tools": results,
"message": f"找到 {len(results)} 个匹配工具,已动态加载",
},
}

search_tools 永远在工具集里,相当于给模型一条”逃生通道”——分类错了它能自己救回来。

代价是 prompt 里得一直占着一个工具描述位。但这账算得很值——路由分类错虽然是低概率事件,可一旦发生影响就大,留个后门更安心。

之前 Agent 里每来一句话,不管三七二十一,全塞给 32B 的主模型。结果就是用户问一句”今天怎么样”,也得让大模型吭哧吭哧思考三秒,烧掉两千个 token——你说亏不亏?

说白了,大部分请求根本用不着主模型出场。要么就是寒暄两句,要么意图明明白白摆在那儿,杀鸡焉用牛刀。

后来我在前面塞了一层 1.7B 的小分类器,跑在本地 Ollama 上,专门干一件事:判断这次该走哪条路。

意图分了哪几类

按业务划了 6 类,从纯聊天到多模态病虫害诊断都覆盖到了:

1
2
3
4
5
6
7
8
9
from enum import Enum

class IntentCategory(str, Enum):
PEST_DIAGNOSIS = "pest_diagnosis" # 病虫害诊断(多模态)
DEVICE_MANAGEMENT = "device_management" # 设备管理
DATA_ANALYSIS = "data_analysis" # 数据分析
AUTOMATION = "automation" # 自动化规则
AGRICULTURE = "agriculture" # 农业管理
GENERAL = "general" # 通用问答

分类器要的是规规矩矩的结构化输出,所以拿 Pydantic 卡一道——免得它一时兴起跟你聊起人生哲学:

1
2
3
4
5
6
7
8
from pydantic import BaseModel, Field

class RouterOutput(BaseModel):
category: str = Field(description="意图分类")
confidence: float = Field(default=0.5, description="置信度 0-1")
reasoning: str = Field(default="", description="分类理由")
requires_image: bool = Field(default=False, description="是否需要图片")
sub_intent: str = Field(default="", description="子意图")

分类器本体

模型挑了 qwen3:1.7b,温度压到 0.1(这种活儿不需要它发挥创造力),max_tokens 给个 256 就够用,超时卡 5 秒——再长就纯属耽误事,还不如让主模型直接上:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from langchain_openai import ChatOpenAI

class IntentRouter:
def __init__(self):
self._router_model = ChatOpenAI(
base_url="http://localhost:11434/v1",
model="qwen3:1.7b",
max_tokens=256,
temperature=0.1,
timeout=5.0,
)

async def classify(self, query: str, history: list = None) -> RoutingResult:
try:
return await self._classify_with_llm(query, history)
except Exception:
return self._keyword_fallback(query)

关键词兜底

小模型也是会抽风的——超时、吐出非法 JSON、给你一个根本不存在的类别。这些事儿要是直接报错给用户,体验立马就崩。

所以底下又垫了一层关键词匹配。逻辑笨是真的笨,但有个好处:它永远不会挂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def _keyword_fallback(self, query: str) -> RoutingResult:
q = query.lower()

pest_keywords = ["病虫害", "病害", "虫害", "枯萎", "黄叶", "斑点",
"诊断", "识别病", "什么病", "怎么治"]
if any(kw in q for kw in pest_keywords):
return RoutingResult(
category=IntentCategory.PEST_DIAGNOSIS,
confidence=0.75,
reasoning="关键词匹配: 病虫害相关",
)

auto_keywords = ["自动化", "自动", "规则", "触发", "条件", "定时",
"创建规则", "编排", "工作流", "联动"]
if any(kw in q for kw in auto_keywords):
return RoutingResult(
category=IntentCategory.AUTOMATION,
confidence=0.7,
reasoning="关键词匹配: 自动化相关",
)

device_keywords = ["设备", "传感器", "连接", "断开", "在线", "离线", "状态"]
if any(kw in q for kw in device_keywords):
return RoutingResult(
category=IntentCategory.DEVICE_MANAGEMENT,
confidence=0.7,
reasoning="关键词匹配: 设备管理相关",
)

return RoutingResult(
category=IntentCategory.GENERAL,
confidence=0.5,
reasoning="未匹配到特定领域关键词",
)

匹配顺序这事得动点脑子。举个例子,“设备联动”这句话里既有”设备”又有”联动”,但用户心里想的明明是自动化规则。要是按字母序匹配,肯定被设备这一支先吃掉,结果就是用户问”帮我做个联动”,系统跑去查设备列表给他——尴不尴尬?

所以自动化的关键词必须排在设备前面判。

工具按意图加载

分类只是手段,真正想干的事是把工具集裁到一个合理的大小。

Anthropic 之前发过一个经验值:工具一旦超过 20 个,模型选错的概率明显抬头。所以分类完了之后,按 tag 拉对应的工具子集就行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def _select_pipeline(self, routing: RoutingResult):
registry = self._tool_registry

if routing.category == IntentCategory.DEVICE_MANAGEMENT:
tools = registry.get_tools_by_tags(["device", "entity"])
system_prompt = get_prompt("device_expert")

elif routing.category == IntentCategory.DATA_ANALYSIS:
tools = registry.get_tools_by_tags(["data", "device"])
system_prompt = get_prompt("data_analyst")

elif routing.category == IntentCategory.AUTOMATION:
tools = registry.get_tools_by_tags(["automation", "device", "entity"])
system_prompt = get_prompt("automation_expert")

elif routing.category == IntentCategory.AGRICULTURE:
tools = registry.get_tools_by_tags(["agriculture", "diagnosis", "data"])
system_prompt = get_prompt("agriculture_expert")

else:
tools = registry.get_tools_by_tags(["general"])
system_prompt = get_prompt("general_assistant")

return model, tools, system_prompt

工具注册的时候顺手打 tag,注册中心按 tag 取交集,就这么简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class ToolRegistry:
def __init__(self):
self._tools: Dict[str, ToolRegistration] = {}

def register(self, tool, tags=None, description_zh="", examples=None):
reg = ToolRegistration(
tool=tool,
tags=set(tags or []),
description_zh=description_zh or tool.description,
examples=examples or [],
)
self._tools[tool.name] = reg

def get_tools_by_tags(self, tags: List[str]) -> List[BaseTool]:
result_set = {}
tag_set = set(tags)
for name, reg in self._tools.items():
if reg.tags & tag_set:
result_set[name] = reg.tool
return sorted(result_set.values(), key=lambda t: t.name)

跑下来什么感觉

上了这层路由以后,体感最爽的就是寒暄类的请求——基本秒回。小模型一拍板,主模型连脸都不用露。Token 账单也瘦了一圈,没具体统计,但调到自动化场景时,上下文里塞的工具描述肉眼可见地少了一大截。

最坑的地方倒不是模型本身,反而是关键词兜底的顺序——这玩意儿排错了,出来的结果能让你怀疑人生。还有个一直在我 TODO 里没动的事:小模型分类错了到底咋办?目前是无条件相信兜底,但理论上应该按 confidence 卡个线,太低就升级给主模型再判一次。

这块还在慢慢磨。