上一篇说我们最后是”启发式 + LLM + RL”三套决策并存。这就引出一个工程问题——怎么让三套决策器共一个接口,让上层调度无脑切换?

答案听上去老套:Strategy 模式。但 LLM 和 RL 这两位选手的”行为习惯”和经典 OO 教科书里的策略可不太一样,硬套模板会被现实教做人。这篇说说我们最后怎么落地的。

为什么必须共一个接口

三套决策器各有各的强项(上一篇详细写了),但调用方——也就是游戏自动化的任务调度——不应该关心今天用的是哪套。它的诉求很朴素:

给我一个当前局面,我告诉你出什么。

如果三套各暴露各的 API,上层就得到处写:

1
2
3
4
5
6
7
if mode == "algo":
action = algo_solver.choose(...)
elif mode == "llm":
action = llm_strategy.ask(prompt=..., context=...)
elif mode == "rl":
obs = build_observation(...)
action = rl_policy.predict(obs)

每加一种策略,整个调用链都得改。调度逻辑被三家细节污染,谁都不开心。

接口长什么样

抽象到最后就一个方法:

1
2
3
4
class BaseStrategy(ABC):
@abstractmethod
def decide(self, state: GameState) -> Decision:
"""根据当前游戏状态做决策"""

GameState 是一个 dataclass,封装”当前局面所需的一切”——手牌、状态、回合数、可选动作等。Decision 也是 dataclass,包含”选哪个动作 + 决策理由 + 元数据”。

三个具体实现:

1
2
3
4
5
gameplay/strategy/
├── base_strategy.py # 抽象基类
├── algo_strategy.py # 启发式
├── llm_strategy.py # LLM 决策
└── rl_strategy.py # RL 决策

听起来很标准对吧?真做下来三个具体类各有各的脾气。

algo_strategy:最听话的那个

启发式策略是最贴合 Strategy 模板的——纯函数,输入 state 输出 decision,没有任何隐藏状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
class AlgoStrategy(BaseStrategy):
def decide(self, state: GameState) -> Decision:
best_action = None
best_eval = -inf
for action in state.legal_actions:
ev = evaluate(state, action)
if ev > best_eval:
best_eval, best_action = ev, action
return Decision(
action=best_action,
reason=f"evaluation={best_eval:.2f}",
meta={"all_evals": {...}},
)

可重入、线程安全、没有外部依赖。教科书都不舍得这么干净的实现。

如果三套都能这么写,这篇就不用发了。

llm_strategy:把”对话感”塞进无状态接口

LLM 自带”对话历史”的概念——你跟它聊得越多,它越懂你。但 Strategy 接口要求无状态——同样的 state 进来,应该产生(可比较的)输出。

这两者本质冲突。

解决办法是把”上下文”显式放进 state 里,而不是藏在 strategy 内部:

1
2
3
4
5
6
7
8
9
10
@dataclass
class GameState:
# 当前局面
hand: list[Card]
hp: int
turn: int
...
# 历史上下文(由上层维护,显式传入)
session_history: SessionHistory
insights: list[Insight]

session_history 怎么压缩、insights 怎么挑(参考前面两篇 insight 和上下文压缩),都是 state 构造期间完成的,strategy 拿到时已经是”压缩好的快照”。

LLM strategy 内部就变成无状态了:

1
2
3
4
5
class LLMStrategy(BaseStrategy):
def decide(self, state: GameState) -> Decision:
prompt = self._build_prompt(state) # 完全由 state 构造
response = self.llm_client.chat(prompt)
return self._parse(response, state)

关键约束:strategy 内部不允许存任何跨 decide 调用的状态。所有”历史”“记忆”“上下文”都从外面传进来。这条规矩破了,后面 RL 那位选手就跟你急。

rl_strategy:远程推理服务

RL 这位最特别——训练时它在 PyTorch / SB3 那一套里,部署时神经网络要做实时推理。

最朴素的做法是在 strategy 里直接加载模型:

1
2
3
4
5
6
7
8
class RLStrategy(BaseStrategy):
def __init__(self, ckpt_path):
self.policy = load_policy(ckpt_path)

def decide(self, state):
obs = state_to_observation(state)
action = self.policy.predict(obs)
return Decision(action=action, ...)

能跑。但很快就被现实教育了——

问题一:模型几百 MB,每个进程都加载一份,主程序内存爆炸。 问题二:模型用 GPU 推理才划算,但主程序在 CPU 机器上跑。 问题三:模型版本升级要重启整个主程序,运维痛苦。 问题四:训练侧和部署侧的 Python 依赖打架——SB3、Ray、Torch 这些跟主程序的依赖经常冲突。

最后改成了远程无状态推理服务

1
2
3
4
5
[主程序]
↓ HTTP/RPC(传 observation,收 action)
[RL 推理服务(独立部署)]

[加载好的策略网络]

代码层面引入一个 client:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# rl_inference_client.py
class RLInferenceClient:
def predict(self, observation: dict) -> int:
resp = self.http.post(self.endpoint, json=observation)
return resp.json()["action"]

# rl_strategy.py
class RLStrategy(BaseStrategy):
def __init__(self, client: RLInferenceClient):
self.client = client

def decide(self, state):
obs = state_to_observation(state)
action = self.client.predict(obs)
return Decision(action=action, ...)

主程序不需要装 PyTorch、不需要管模型版本、不需要 GPU。RL 服务挂了就降级——上层调度发现 RL strategy 拿不到结果,自动 fallback 到 algo strategy。

项目规范文档里有一条规则正好对应:

决策逻辑需要以无状态方式对接到:train/gakumas_rl

这就是为什么早期我们就强制”无状态”——后来要拆服务、要并发、要降级,都靠这条约束才不至于推倒重来。

三种策略凭什么能并存

三个具体实现完全不同——一个纯函数、一个 LLM 调用、一个 RPC 调用——能塞进同一个接口的原因只有一个:state 和 decision 这两个数据契约定得够稳

GameState 设计原则:

  • 完整:调用方传一次 state,strategy 就能做决策,不需要再去问别的地方
  • 可序列化:dataclass + 基础类型,能 JSON 化(RL 推理服务要走网络)
  • 版本化:加字段用 Optional + 默认值,老的 strategy 不会被新字段搞挂

Decision 设计原则:

  • 统一动作类型:不管哪个 strategy,最后输出的 action 都是同一个枚举
  • 必带 reason:哪怕 RL 给个 “rl_policy_prediction” 也行,调试时要追问
  • 保留 meta:每个 strategy 想塞什么调试信息进去都可以,上层不强制读

这些约束写在 dataclass 里:

1
2
3
4
5
@dataclass(frozen=True)
class Decision:
action: Action # 必须是已知 Action 枚举
reason: str # 必须有
meta: dict[str, Any] = field(default_factory=dict) # 各家自由发挥

frozen=True 是个小心思——decision 一旦产出不允许修改,避免被下游悄悄篡改。

上层调度怎么选

有了三个 strategy,最后还要决定”这次用哪个”。我们的做法很土,但稳:

1
2
3
4
5
6
7
def get_strategy(scene: Scene, config: Config) -> BaseStrategy:
# 配置层面的强制指定优先
if config.force_strategy:
return STRATEGIES[config.force_strategy]

# 按场景默认
return SCENE_DEFAULT_STRATEGY[scene]

然后在 SCENE_DEFAULT_STRATEGY 里把每个游戏场景默认用哪套写死:

1
2
3
4
5
6
SCENE_DEFAULT_STRATEGY = {
Scene.DAILY_TASK: "algo", # 日常任务,启发式够用
Scene.HIGH_RANK_EXAM: "rl", # 高难度,RL 冲分
Scene.STORY_DIALOGUE: "llm", # 剧情对话,LLM 选项
...
}

加新场景时只需要决定”默认走哪条路”,不动 strategy 实现。这种”配置驱动 + 默认值显式”的模式比”代码里 if-else”清爽太多。

降级链

三个 strategy 在线上不可避免会出问题:

  • LLM 服务超时 / 限流
  • RL 推理服务挂了
  • 启发式遇到没考虑过的边界情况

降级链设计:

1
2
3
RL fail → fallback to algo
LLM fail → fallback to algo
algo fail → fallback to "什么都不做"(safe action)

为什么 algo 是兜底——因为它没有外部依赖,纯函数,最不容易挂。这条降级链就一句话:

1
2
3
4
5
6
7
def decide_with_fallback(state, primary, fallbacks=[algo, safe]):
for strategy in [primary, *fallbacks]:
try:
return strategy.decide(state)
except StrategyError as e:
log.warning(f"{strategy} failed: {e}, trying next")
raise RuntimeError("All strategies failed")

至少三个都挂的情况,我们到目前没遇到过。

复盘几条原则

写完这套之后总结的经验:

接口要早定,并且不许 strategy 内部存状态。 任何”我先在 strategy 里缓存一下”都是定时炸弹——后面拆远程服务、并发调用,全爆。

state 和 decision 的 dataclass 是核心契约,比 strategy 实现重要十倍。 实现可以换,契约不能轻易动。

支持降级链,特别是对外部依赖的 strategy。 LLM 和 RL 都有外部依赖,没有降级等于线上裸奔。

让最简单、最没有依赖的实现做兜底。 启发式这种纯函数策略平时不显眼,关键时候救命。

默认策略选择写成数据,不写成 if-else。 加新场景时改个映射表就好,不要去翻代码。

收个尾

经典 Strategy 模式在教科书里就一页纸——抽象基类、几个实现、上下文持有引用,完事。

放到 LLM 和 RL 这种带外部依赖、带历史状态、带服务化诉求的现实场景里,模式还是那个模式,但周边的工程肉戏一大堆——状态外置、远程推理、降级链、数据契约、配置驱动——任何一块没做好,三套策略并存就变成三套互相打架。

接口设计这事儿,从来不只是写个 ABC 这么简单。

游戏自动化做出牌决策,路子大致两条:

  • A 路:把官方游戏里那个”自动出牌推荐”逆向出来,照着它的公式抄一遍
  • B 路:自己把游戏建成 Gym 环境,让 RL agent 从零学

两条都试过。这篇不写哪条赢——两边都有自己的舒适区,搞清楚各自适合什么场景比争胜负有用得多。

A 路:逆向官方公式

游戏里其实有内置的自动打牌——你把画面挂着不动几秒,会有一张手牌闪橙色边框,那就是它推荐的。

社区里有大佬已经把逆向做了大半(Vibbit’s Blog 有篇详细解析),核心方法叫 ExamRuleCalculator.Evaluate。算法本身不复杂:

预先分别计算每张手牌打出后的 evaluation 值,哪个大就出哪张。

evaluation 由 19 个参数加权求和而成。比如:

1
r1 = v1 × ProduceExamAutoEvaluation.evaluation

其中 v1 是当前局面的某个状态值(比如剩余回合数、当前集中、当前体力等),权重从主数据库里查表得到,跟”主打效果””剩余回合””玩法类型”三个维度索引。

主数据库里大概几千条这样的权重 entry。游戏每次更新可能调整。

抄进去之后好用吗

挺好用的。

直接抄一遍核心评估函数,接上自己写的”枚举每张手牌→算 evaluation→出最大值”,立马就能跑出一个能打的 agent。该过线过线,该高分高分。

好处太明显

  • 不用训练,写完即可用
  • 行为可解释——任何一次决策都能输出”为什么是这张”,权重一拉出来就懂
  • 不会”抽风”——纯算式,输入相同输出必相同
  • 调试友好——出问题能精确定位是哪个权重不对

那为啥还想搞 RL

因为这套启发式有几个绕不开的天花板。

第一,它只看当前回合的局部最优。 evaluation 计算的是”打出这张牌之后的状态分”,本质是贪心。有些局面下你需要”忍一回合不出强卡,攒到下回合 buff 起效再出”,启发式做不到——它每一步都挑当前最优解,整盘可能不是最优。

第二,它是开发者写出来的,不是为玩家爽点优化的。 游戏里那个自动推荐主要是”帮新手别打太烂”,目标是”过线”,不是”打高分”。你想冲分?这套启发式天花板就摆在那。

第三,新机制出来要重新逆向。 每次游戏大版本更新,新增了卡片效果、状态、机制,得重新挖一遍主数据库的权重表。维护成本不低。

第四,没法适配自己的目标。 我想训一个”专门刷高排名 fan vote”的 agent,启发式的目标函数和我的目标不一致,怎么调权重都拧巴。

B 路:自己训 RL

另一条路是把游戏建成 Gym 环境,自己训。

我们做的是这套:

  • 把游戏的核心战斗逻辑用纯 Python 重写一遍(这步是大头,几千行)
  • 包成 Gymnasium 接口
  • 用 MaskablePPO 训练(动作空间离散且有非法动作,必须用 mask)
  • 从课程学习入手:先学初级关卡,能过线了再学高级

代码组织大概这样:

1
2
3
4
5
6
7
8
9
train/gakumas_rl/src/
├── simulation/
│ ├── exam/runtime.py # 考试关卡的纯 Python 仿真
│ └── produce/runtime.py # 育成流程的仿真
├── training/
│ ├── autopilot.py # 课程学习的总调度
│ ├── self_bootstrap.py # BC 蒸馏 + RL 微调的自举
│ └── model.py
└── ...

卡在哪儿

仿真环境复现游戏机制,这一步比想象中难十倍。

游戏里看似简单的”一张卡的效果”,背后可能有触发顺序、buff 叠加优先级、上下取整规则、随机种子处理等一堆细节。我们前后对了几个月,对着真实游戏对局抓包反复校准,才把仿真精度调到”打 10 局,分数和真实游戏差异 ≤ 5%”。

仿真不准,训出来的 agent 一上真实游戏就抓瞎——它学的策略是基于错误规则的最优解,到了真规则下面就是次优解。

训练流水线

为了避免冷启动太痛苦,我们走了一套”模仿先行→RL 微调”的路线:

  1. 用启发式(A 路那套)跑大量数据,收集 (状态, 启发式选择) 对
  2. BC(Behavior Cloning)预训练——让神经网络先学会”像启发式那样打”
  3. PPO 微调——在启发式策略基础上继续探索更优解

效果:

  • 单跑 PPO 从零探索,前几百万步基本不收敛
  • BC 预训练 + PPO 微调,几十万步就能稳定超过纯启发式

启发式在这里反而成了 RL 的”师父”——给它一个像样的起点,比从随机策略开始强太多。

课程学习

游戏里有好几种难度档位:初级 Regular → 初级 Master → NIA Pro → NIA Master。直接训最高难度,agent 探索半天连过线都困难,PPO 收敛极慢。

课程的做法是按难度阶梯训:

1
2
3
4
5
6
7
初级中间考试 → 初级最终考试

NIA 中间考试 → NIA 最终考试 → NIA 选拔

初级 Regular 全流程 → 初级 Master 全流程

NIA Pro 全流程 → NIA Master 全流程

每阶段从上一阶段的 checkpoint 热启动。听起来繁琐,但比”直接训最难”快了一个量级。

两条路实测对比

挑几个能量化的维度比一下:

维度启发式(A 路)RL(B 路)
开发成本低(抄公式即可)高(仿真环境最贵)
维护成本中(版本更新要重新逆向)中(仿真要跟版本同步)
决策可解释性极强弱(黑箱)
决策稳定性完全可复现训练 seed 不同会有差异
过线率(初级关卡)接近 100%接近 100%
高分率(顶级关卡)中等较高(特别是 fan vote 那种长链优化)
应对新机制要重写公式要更新仿真,但策略可微调适应
异常局面处理严格按规则,遇到没考虑的情况会很离谱见过类似的状态有泛化能力

我们最后怎么用的

实战部署里两个都没扔——按场景挑。

日常自动化场景(清日常任务、刷材料、过普通关卡):用启发式。

这种场景诉求是”稳、快、能解释”。一个 evaluation 函数就搞定,没必要请 RL 出马。RL 模型加载、推理都比纯函数慢,没意义。

冲分场景(高难度活动、排行榜挑战):用 RL。

这种场景诉求是”在已知规则下逼近最优”。RL 训出来的策略能学到一些启发式想不到的”反直觉操作”——比如某些回合故意不打最强牌,留给后面叠 buff 后翻倍输出。

中间的判断由上层调度做,模型 strategy 接口都一致(这块的接口设计另外写了一篇)。

一些个人观察

写完两套之后我反过来想,启发式和 RL 不是对立关系,是上下游关系

  • 启发式给你一个”能解释、能调试、能稳定运行”的基线
  • RL 在这个基线之上”找那些人写不出来的细微策略”

一上来就想训 RL、不写启发式基线的项目,多半会卡死在”训了半天还不如随便写个规则”的尴尬阶段。

反过来,认为”有启发式就够了,RL 是花架子”也偏激——你的目标函数和启发式作者的目标函数永远不会完全一致,那一点点 gap,RL 真能补上。

下次有人问”游戏 AI 该走规则还是 RL”,我大概会回:”你两个都该有,然后按场景挑。”

让 LLM 接管游戏决策的早期版本,system prompt 写得相当”全能”——一份大 prompt,把规则、角色、所有可能遇到的场景全塞进去。开发的时候图省事,反正一份配置走遍全场。

跑了一段时间,发现这哥们儿”全能”得有点像班级里那种啥都会一点、啥都不精的同学。

最后的版本拆成了 23 个 prompt 模板。听起来夸张,但真正动手拆完,回头看那个万能 prompt——只能说”它能跑”已经是项目早期最大的功劳了。

万能 prompt 是怎么烂掉的

最开始那份 prompt 大概长这样:

你是一个游戏决策助手。游戏规则是 XXX。当前回合可能处于以下几种状态: - 出牌阶段:…… - 咨询阶段:…… - 课程选择阶段:…… - 奖励选择阶段:…… - 对话阶段:……

请根据当前状态做出合理决策。

这份 prompt 用了一段时间,问题逐渐暴露:

第一,模型注意力被稀释。 当前明明是”奖励选择”阶段,但 prompt 里同时挂着”出牌策略要点”,模型有时候会把出牌的逻辑误用过来——比如在选奖励的时候开始算”体力剩余”,但奖励选择压根不消耗体力。

第二,无关示例污染输出。 prompt 里给”出牌阶段”配了几个示例,但模型在做”对话选择”的时候会被那些示例的格式带跑,输出一堆莫名其妙的字段。

第三,迭代起来灾难性。 你想改”奖励选择”的策略,得在那一大坨 prompt 里翻定位、改完整体测一遍。改一个地方,怕影响另外八个场景。慢慢就不敢动了——典型的”屎山 prompt”。

拆开之后变成什么

游戏的决策本质上是个状态机。每个状态需要的上下文、可选动作、决策规则差别都很大。我们最后按 phase 把 prompt 全拆了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
gameplay/llm/prompts/
├── _common_terms.j2 # 公共词汇表,每个 prompt 都引一下
├── system_default.j2 # 兜底
├── system_consult.j2 # 咨询
├── system_dialogue.j2 # 对话
├── system_exam.j2 # 考试
├── system_lesson.j2 # 课程
├── system_p_drink.j2 # 饮料使用
├── system_schedule.j2 # 日程规划
├── system_skill_reward.j2 # 技能奖励
├── system_item_select.j2 # 道具选择
├── system_insight_generator.j2 # 经验生成
├── system_insight_reviewer.j2 # 经验复盘
├── system_insight_selector.j2 # 经验选择
├── system_session_memory_compactor.j2 # 会话压缩
├── action_select.j2 # 通用动作选择模板
├── insight_phase.j2
├── insight_step.j2
├── insight_review.j2
├── insight_select.j2
├── state_snapshot.j2 # 状态快照
└── ...

23 个 Jinja2 模板。每个只管一件事。

第一反应:这么多模板,维护成本不爆炸?

实际跑下来,维护成本反而降了。因为每个 prompt 只服务一个场景,改它的时候心理负担是零——改坏了顶多影响这一个场景,不会牵连别的。

为什么是 phase,而不是别的拆法

拆 prompt 的方式很多——按角色、按任务、按难度。我们最后选按 phase(也就是游戏当前所处的阶段),是因为:

1. phase 切换是离散事件,可以由代码确定

每个 phase 用哪个 prompt,是上层调度代码根据状态决定的,模型自己不参与”该用哪个 prompt”的判断。这一点很关键——让模型选 prompt 等于让模型给自己安排任务,可控性立马崩。

2. 每个 phase 的输出 schema 不一样

出牌阶段输出”选哪张卡”,对话阶段输出”选哪个选项”,课程选择输出”去哪个教室”。schema 不一样,prompt 自然不一样——硬要塞进同一份 prompt 还得用条件判断在 prompt 里翻译,反而把模型搞晕。

3. 每个 phase 的”重要约束”不重合

出牌阶段要算资源、留 buff,得反复强调”不要浪费集中”。 课程选择阶段要看周次、看角色成长曲线,得强调”别在最后周做没收益的事”。

这些约束塞在一起会互相干扰——出牌的时候被”考虑长期成长”带跑,开始打留长线的牌;选课程的时候被”高效用牌”带跑,开始算单次收益。各管各的反而专注。

公共词汇表怎么处理

拆完之后有个新问题——每个 prompt 都得解释一遍”集中是什么”“体力是什么”“buff 是什么”,重复且容易写歪。

解决办法是抽一个 _common_terms.j2,把游戏里的核心概念定义全放进去:

1
2
3
4
5
6
7
8
9
10
{# _common_terms.j2 #}
## 核心概念
- **集中**:决定输出倍率的资源,每点集中 +20% 输出
- **体力**:每回合行动消耗,归零强制结束
- **buff/debuff**:……

## 输出格式约束
- 所有数值用阿拉伯数字
- 决策理由控制在 50 字以内
- ……

每个具体的 system prompt 第一句就 include 它:

1
2
3
4
5
{% include "_common_terms.j2" %}

## 当前场景:技能奖励选择
你正在为这次育成选择一个技能奖励。
...

改”集中”的定义?改一处,23 份 prompt 全跟着变。这种小重构在万能 prompt 的时代是不敢想的。

几个具体收益

调试不再是猜谜

模型决策错了,能立刻定位是哪个 phase 的 prompt 出问题。日志里把当前调用用的 prompt 名字记下来:

1
2
[exam phase] -> system_exam.j2 + action_select.j2
[reward phase] -> system_skill_reward.j2

出问题打开对应的模板看就行。万能 prompt 时代是这样的:模型出 bug 了,你只能盯着一大坨 prompt 猜哪几行影响了它。

不同 phase 可以用不同模型

这个收益是拆完之后才意识到的——既然 prompt 都拆了,调用的时候完全可以按需要换模型。

简单的对话选择,丢给小模型,速度起飞、成本砍半。 复杂的出牌决策,留给大模型,质量优先。 经验复盘那种容错率高的,甚至可以用本地部署的开源模型。

万能 prompt 时代是不可能这么干的——一份 prompt 适配所有 phase,等于所有 phase 都得用最强模型,浪费严重。

多人协作变得可能

拆开之后,团队里不同人可以并行迭代不同 phase 的 prompt,互不干扰。改 system_skill_reward.j2 的人不需要去看 system_exam.j2,review 的时候 diff 也清爽。

万能 prompt 时代经常出现”你刚改完那段我刚改完这段,merge 完两个都不对了”的情况。

拆 prompt 的几条原则(踩出来的)

写下来给后面的人参考:

一份 prompt 解决一个明确的决策。 如果你发现自己在写 “如果是 A 情况则 X,如果是 B 情况则 Y” 这种分支,大概率就该拆了。

phase 切换在代码里做,不在 prompt 里做。 让模型负责”在这个 phase 下怎么决策”就够了,不要让它判断”现在是哪个 phase”。

公共概念抽到共享模板,但只抽真正稳定的部分。 那些”可能这个场景适用、可能那个场景不适用”的边界规则就别抽了,留在各自 prompt 里更安全。

每个 prompt 都有自己的输出 schema,文档化。 这一条非常重要——schema 写清楚,下游解析才稳。

prompt 文件命名要能从名字看出用途。 system_exam.j2prompt1.j2 强一万倍。看似废话,真有人会图省事用后者。

回头看

把万能 prompt 拆成 23 个之前,每次给模型加新场景都像在做”开颅手术”——动哪儿都怕带坏全身。拆完之后变成”加个新文件就行”,新增成本几乎归零。

这事儿和我们写代码非常像——一开始一个文件搞定,到了一定规模就得按职责拆模块。prompt 也是代码,只不过它是写给模型看的代码,但同样适用单一职责、显式依赖、可测试性这些老规矩。

下次再听到”一份 system prompt 就够了”这种话,可以友善地笑一下。

上一篇说三段式 insight 解决的是”打完一局怎么沉淀经验”。这篇聊一个更要命的事——一局还没打完,历史就已经撑爆了

游戏里一局训练动辄三四十回合,每回合都有局面快照、可选动作、模型决策、最后结果。要是把这一坨全堆进 message 数组里给 LLM 看,到第十回合左右上下文就开始告急;到第二十回合,要么自动截断丢前面(关键设定一起飞了),要么直接报 context length exceeded。

最早的版本就是这么炸的。

朴素方案为什么都不行

先说几个我试过、然后默默删掉的方案。

方案一:滑动窗口,只保留最近 N 轮

简单粗暴。问题也明显——开局阶段的人物设定、卡组构成、战术意图,全是后面决策的根基,丢了之后模型就开始”失忆型抽风”。第 15 回合突然忘了自己是哪个角色,出牌逻辑直接崩。

方案二:直接 summarize 全部历史

每隔 N 回合让 LLM 把前面的对话总结成一段话。听起来很美。

实际跑起来——总结质量飘忽。运气好的时候压缩比 5:1,运气差的时候把关键数值(“体力还剩多少”“场上 buff 是什么”)当背景细节给丢了。下一回合模型一看,“咦我的体力是多少来着”,开始瞎猜。

方案三:让模型自己决定记什么

试过,问题是模型太”贪心”,让它自己删它一条都舍不得删。

最后的做法:分层压缩

后来落地的方案,本质是把会话分层

1
2
3
4
5
6
7
┌─────────────────────────────┐
│ 永久层:人物 + 卡组 + 目标 │ ← 整局不变,开头注入
├─────────────────────────────┤
│ 压缩层:早期回合的结构化摘要 │ ← 由 compactor 定期生成
├─────────────────────────────┤
│ 最近层:最近 K 回合原始对话 │ ← 完整保留,K 大概 3~5
└─────────────────────────────┘

对应代码就两块:

1
2
3
4
gameplay/llm/
├── session_state.py # 维护三层结构
└── prompts/
└── system_session_memory_compactor.j2 # 干压缩活的 prompt

每一层职责清清楚楚。永久层只在开局写一次,整局只读不改;最近层是滑动窗口;中间的压缩层是关键——由 compactor 负责,每隔若干回合把最早的几轮”挤”成结构化摘要。

Compactor 的 prompt 长什么样

这块的 prompt 设计是反复改了七八版才稳的。最早写得很自由:

请总结以下回合的关键信息……

模型回你一段散文。下一次再总结的时候格式又变了,下游解析根本接不住。

最终版本走的是强结构化输出,让模型填表,不让它自由发挥:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
turns: 5-9
state_changes:
- turn: 5
hp: 80 65
buffs_gained: [集中+2]
cards_consumed: [快攻 x1]
- turn: 7
score: 1200 1800
key_decisions:
- turn: 6
chose: 集中堆叠
why: 为第 10 回合留高倍率窗口
notable_events:
- turn: 8: 对手出了 trap,被预判

填表式输出有几个好处:

  1. 数值不会丢——你让它填 hp,它就得给数;总结成”体力下降不少”这种话直接不合格
  2. 可机器读——压缩层本身可以再被读取、再被组合
  3. token 占用可预测——表格结构每条大致占多少 token 是已知的

实测压缩比稳定在 8:1 到 10:1 之间,关键信息保留率(事后人工抽检)大概九成五以上。

触发时机:不能压太勤也不能压太晚

压缩这事儿有调用成本(每次都是一次 LLM 调用),所以不能每回合都压。但拖太晚,原始历史已经长到 compactor 自己都吃不下了。

最后定的策略很土,但有效:

  • 回合数触发:每过 5 个回合压一次
  • token 数兜底:估算当前 message 数组的 token 数,超过阈值的 60% 就强制触发,不管回合数走到没

兜底很重要——有些卡片效果触发会刷一堆系统提示,单回合就能塞进来几千 token,光按回合数计算根本拦不住。

那些没想到的坑

压缩层也会越攒越长

跑长局的时候发现——压缩层本身也在涨。第 30 回合的时候,压缩层里堆了五六段摘要,加起来比原始最近层还长。

后来加了二级压缩:压缩层里的摘要超过 N 段,就触发 compactor 把最早的几段再合并成一段更粗的摘要。粗摘要里只留”宏观走势”——分数曲线、体力曲线、关键转折点,具体数值就不要了。

层级长这样:

1
2
3
4
5
最近层(5 回合原始)

压缩层(每段 = 5 回合摘要,最多 N 段)

归档层(每段 = 多段压缩层的再压缩,只留趋势)

听起来像金字塔,实际就是不停做”更粗的总结”,让总 token 始终可控。

压缩失败要有兜底,不能让整局崩

LLM 调用偶尔会抽风——超时、返回格式不合规、返回内容明显错乱。压缩失败如果直接抛异常,整局训练流程都得停。

兜底很简单:压缩失败就退化成滑动窗口,直接丢掉最早那几个回合。会损失信息,但至少游戏能继续打。日志里把这次失败标红,事后排查 prompt 或者切换模型。

宁可信息有损,也别让流程断。

永久层不是”system prompt 一塞了事”

最早把人物设定、卡组之类的写进 system message,以为这就是”永久”了。结果发现某些 LLM 对 system message 里靠前的内容关注度不均匀,越长越容易”忘”开头。

后来改成每次发请求的时候,把永久层内容显式地、简短地重复一次,附在最近层之前。冗余了一点 token,但确保模型每一轮都”看得见”自己是谁、打的什么卡组。

回过头看

LLM 上下文管理这事,本质是带预算的存档问题——你有固定的 token 预算,要在里面塞下最重要的信息。和数据库的冷热分层、操作系统的缓存淘汰其实是一个套路:高频访问的全量保留,低频但重要的压缩归档,纯历史的可以糊掉甚至丢弃。

工程化的关键不是”用了多牛的 LLM”,而是把这套分层逻辑显式写出来——哪一层放什么、什么时候压缩、压缩失败怎么兜底,全得有明确规则。把这些规则交给模型自己决定,多半翻车。

下一篇说说更隐藏的事——既然 prompt 这么金贵,那一份 system prompt 走天下到底行不行?为什么我们最后拆成了 23 个?

一开始让 LLM 接管游戏决策的时候,思路特别朴素:把当前局面塞给它,让它出招。打完一局,对话扔掉,下一局从头开始。

跑了几天就坐不住了——同一种局面,同一个坑,它能给你掉进去十几次。每次都是新人入职的状态,前面踩过的雷一概不记得。塞进 system prompt?没几局就把上下文撑爆了;塞进 RAG?流水账日志一堆,检索出来全是噪音。

最后落地的方案是”让它自己整理笔记本”,分了三个角色:generator 写,reviewer 改,selector 挑。听起来像办公室流程,实际跑起来还挺顺。

为什么不一步到位

最早试过最偷懒的版本:打完一局直接让 LLM 总结一句”经验”,扔库里。

效果非常糟糕。

总结出来的东西大概长这样:

本局我打得不错,注意了体力管理,下次继续保持。

这玩意儿写进笔记本,下次检索出来等于没看。再仔细一想,问题挺典型——你让一个刚打完一局、还沉浸在结果里的人当场写复盘,他给你的多半是情绪,不是结论。

人写复盘也得过几道:先把流水账记下来(发生了什么),再回头审视(这事儿哪里能改),最后挑出真正值得记住的(下次别忘了)。这三件事性质完全不同,硬塞给一个 prompt 就是难为模型。

拆开。

三个角色各干什么

我直接把 prompt 目录贴出来,一看就明白:

1
2
3
4
gameplay/llm/prompts/
├── system_insight_generator.j2 # 生产者:从一局轨迹里榨经验
├── system_insight_reviewer.j2 # 审查员:评估这条经验值不值得记
└── system_insight_selector.j2 # 选择器:下一局从库里挑哪几条来看

对应的代码也是三个独立文件:

1
2
3
4
gameplay/llm/
├── insight_generator.py
├── insight_store.py
└── insight_data.py

存储那块由 insight_store.py 兜着——所有经验条目带标签、带场景、带打分一起持久化。三个角色谁都不直接动文件,写读都走 store。

Generator:把一局压成几条候选经验

输入是这局的轨迹:每个回合的局面、自己出了什么牌、对手反应、最后得分。输出是若干条候选 insight。注意是候选,不是直接入库。

prompt 的核心约束就两条:

  • 不要写”我做得很好”这种自我评价,只写”在 X 情况下做了 Y,结果是 Z”
  • 每条带触发条件——后面 selector 才知道什么时候该拿出来

举个真实输出的味道:

触发:剩余回合 ≤ 3 且手牌中有”集中+3”类卡 经验:优先打出集中堆叠卡,把最后一回合留给高倍率输出卡,比平均分高 ~15%

这种带条件、带结果的描述,比”注意体力管理”有用十倍。

Reviewer:把噪音过滤掉

Generator 产出的东西不能全信。模型有时候会把偶然事件当规律——这一局对手出了张烂牌,generator 总结成”对手往往会失误”,扔进库就完蛋。

Reviewer 干的就是泼冷水。它拿到候选 insight 后做三件事:

  1. 样本量够不够?只出现过一次的”规律”直接拒
  2. 跟已有 insight 冲突没有?冲突的话要么合并,要么标记为”待验证”
  3. 触发条件清不清晰?模糊的打回 generator 重写

实测下来 reviewer 大概会刷掉 60%~70% 的候选。一开始觉得太狠,看了几轮通过的,反而觉得这个比例正合适——通过的那批确实条条都能用。

Selector:下一局开打前挑笔记

新一局开始,库里可能已经攒了几十上百条 insight。全塞 system prompt?token 烧不起。随机抽?没意义。

Selector 拿到当前初始局面,去库里挑最相关的 N 条(一般是 3~5 条)。挑的依据就是 generator 写下的”触发条件”——这就是为什么前面非要要求条件写清楚。

selector 自己也是个 LLM 调用,但 prompt 很短:当前局面长这样,候选 insight 有这些(带触发条件),挑出现在用得上的。

为什么不直接用向量检索

刚架这套的时候有人问,为啥不上 embedding 检索,效率高多了。

试过。问题是“触发条件相似”和”语义相似”不是一回事

举个例子。库里有条经验:

触发:体力 ≤ 30 且场上没有恢复牌 经验:优先打高倍率单击牌速攻

当前局面是”体力 25,手里全是 buff 牌”。embedding 检索基于文本相似度,可能给你拉出来一堆”体力低”相关的经验,但这条最关键的”没有恢复牌”被淹没了。

让 LLM 做 selector 反而准。它能真正读懂触发条件里的逻辑,而不是字面相似。代价是多一次 LLM 调用——但这次调用只挑笔记,prompt 短、模型可以用小的,成本可控。

那些踩过的坑

insight 越攒越多,selector 也会糊

库小的时候 selector 很准。攒到上百条之后,候选列表本身就长,模型注意力开始飘。后来加了一道前置过滤:先按 insight 自带的”适用场景”标签做粗筛(这块用字符串匹配就够了),再扔给 selector 精筛。

Generator 太”会写”,容易过度归纳

模型这毛病很顽固——你让它从一局总结经验,它非要总结出五条。明明只有一条有价值。

解决办法是在 prompt 里反复强调”宁缺毋滥,没有可总结的就返回空数组”,并且把 reviewer 的拒绝率作为 generator prompt 的迭代信号——如果某次迭代 reviewer 把 generator 的产出全枪毙了,下次 generator 就该收敛。

Reviewer 不能和 generator 用同一个模型实例

最早图省事,三个角色复用同一个 client,结果发现 reviewer 对 generator 的产出特别宽容——你猜怎么着,模型对自己刚写的东西有偏爱,会下意识地认可。

换成不同的 client 实例(甚至不同模型)之后,reviewer 立马严厉起来。这个现象在论文里也有讨论,但自己撞到才印象深刻。

整体看下来

三段式拆完之后,最直观的感受是每一段的 prompt 都变短了,模型每次只做一件事,输出质量肉眼可见地稳。

更隐性的好处是可观测性——哪条 insight 从哪局来的、被 reviewer 怎么评的、selector 在哪些局调出来用过,全都可以单独追踪。出了问题不再是”模型怎么又抽风了”,而是”哪一段的判断错了”。

调试 Agent 系统最缺的就是这种”哪一步出问题”的拆解能力。把决策流拆开、让每一段都能单独打分,比换更大的模型管用得多。

下一篇打算聊聊另一块更恼火的事——这游戏一局打几十回合,对话历史涨得比谁都快,怎么压缩才能不丢关键信息。那是另一个坑。

YOLO 默认的 NMS 是按类别分开做的——同一个位置如果检测到一个 A 类目标和一个 B 类目标,两个都会保留。

大部分场景下,这是对的。一只猫和它身边的项圈,本来就该都被检测出来嘛。

但在游戏里识别卡片时就完蛋。同一张技能卡只能是 Skill Card: ActiveSkill Card: MentalSkill Card: Trap 三选一,模型有时候会同时给同一个位置打两个标签,置信度还都挺高。按类别独立 NMS 一处理,俩都保留下来——传到业务层就懵了,一张卡顶俩身份证,处理逻辑当场崩盘。

这种情况需要”跨类别 NMS”:指定一组互斥类别,让它们之间也参与互相抑制。

对比一下

标准 NMS:

1
2
3
4
5
- Active 卡: (100, 100, 200, 200), 置信度 0.9
- Mental 卡: (105, 105, 205, 205), 置信度 0.8

→ 两类各自做 NMS
→ Active 保留,Mental 保留

Agnostic NMS(指定 Active/Mental/Trap 为一组):

1
2
3
4
5
- Active 卡: (100, 100, 200, 200), 置信度 0.9
- Mental 卡: (105, 105, 205, 205), 置信度 0.8

→ 三类作为一个整体做 NMS
→ Active 保留,Mental 被高 IoU 抑制掉

从模型元数据构建分组

类别名直接在 YOLO 导出的 meta 里。所以分组逻辑可以放在引擎层,自动识别哪些类别需要互斥:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class YoloModelFromONNX:
def __init__(self, model_path: str):
# ...
self._agnostic_nms_groups = self._build_agnostic_nms_groups()

def _build_agnostic_nms_groups(self) -> list | None:
skill_labels = {"Skill Card: Active", "Skill Card: Mental", "Skill Card: Trap"}
group = set()

for cid, name in self._model_meta.names.items():
if name in skill_labels:
group.add(cid)

return [group] if len(group) >= 2 else None

只有当组里有两个或以上类别时才返回——单一类别”自己抑制自己”没意义,直接交给标准 NMS 就行。

后处理实现

具体的 NMS 实现思路:把检测框先分桶——属于某个 agnostic 组的丢到同一个桶,普通类别按 class_id 自己一个桶。然后每个桶独立做 cv2.dnn.NMSBoxes

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import cv2
import numpy as np

def _postprocess(
self,
results: np.ndarray,
conf_threshold: float,
iou_threshold: float,
agnostic_nms_groups: list | None = None,
) -> YoloResult:
boxes = []
scores = []
class_ids = []

for detection in results:
class_scores = detection[5:]
max_score = np.amax(class_scores)
if max_score >= conf_threshold:
class_id = np.argmax(class_scores)
x, y, w, h = detection[:4]
boxes.append([x - w/2, y - h/2, w, h])
scores.append(max_score)
class_ids.append(class_id)

if not boxes:
return YoloResult.empty()

# 类别 → 组索引
agnostic_map = {}
if agnostic_nms_groups:
for gi, group in enumerate(agnostic_nms_groups):
for cid in group:
agnostic_map[cid] = gi

# 分桶
grouped = {}
for i, cid in enumerate(class_ids):
if cid in agnostic_map:
key = ("agnostic", agnostic_map[cid])
else:
key = ("class", int(cid))
grouped.setdefault(key, []).append(i)

# 每个桶独立 NMS
final_indices = []
for indices_in_group in grouped.values():
group_boxes = [boxes[i] for i in indices_in_group]
group_scores = [scores[i] for i in indices_in_group]

keep = cv2.dnn.NMSBoxes(
group_boxes, group_scores, conf_threshold, iou_threshold
)

if keep is not None:
keep = keep.flatten()
for idx in keep:
final_indices.append(indices_in_group[idx])

return YoloResult(
boxes=np.array([boxes[i] for i in final_indices]),
scores=np.array([scores[i] for i in final_indices]),
class_ids=np.array([class_ids[i] for i in final_indices]),
)

key 用元组 ("agnostic", gi) / ("class", cid)——是为了让两种分桶方式共存而不冲突。要是直接拿数字做 key,agnostic 组索引 1 会和 class_id 1 撞车,又是一个查到怀疑人生的 bug。

几种典型分组

技能卡的三种类型互斥:

1
2
skill_labels = {"Skill Card: Active", "Skill Card: Mental", "Skill Card: Trap"}
agnostic_nms_groups = [skill_labels]

同一按钮的不同状态互斥(按下/正常/禁用):

1
2
button_labels = {"Button: Normal", "Button: Pressed", "Button: Disabled"}
agnostic_nms_groups = [button_labels]

多组可以并存。不同组之间不互相抑制——按钮不会去抑制技能卡:

1
2
3
4
agnostic_nms_groups = [
{"Skill Card: Active", "Skill Card: Mental", "Skill Card: Trap"},
{"Button: Normal", "Button: Pressed"},
]

性能其实没啥好担心的

直觉上”分了 k 个桶要做 k 次 NMS”会觉得慢,但实际上:

  • 总框数没变,每个桶里的框数变少,NMS 是 O(n²) 的,反而每个桶各自更快
  • 分桶本身 O(n),可以忽略
  • 多出来的常数开销,就是初始化几个 list/dict

实际跑下来开销和标准 NMS 一个量级——别想太多。

没事别乱用

写出来这么方便,但默认 NMS 已经是经过深思熟虑的设计。Agnostic NMS 只在”同一位置物理上不能同时有多个目标”的场景才合适:

  • 一个像素只能属于一个 skill card 类型 → 适合
  • 一只猫和它的项圈在同一位置 → 不适合,会把项圈抑制掉,那就闹笑话了

错用会造成漏检,而漏检比误检更难调试——输出看着啥都没问题,业务逻辑就是间歇性出错,那种感觉,谁查谁知道。

游戏自动化里有个老大难问题:YOLO 框出来一张卡,到底是哪张?

模板匹配早就玩不转了——同一张卡换个分辨率立马失效,每出新卡你都得苦哈哈地重新挑模板。OCR 呢,倒是认字,但碰上立绘风格的卡面、花里胡哨的字体,那叫一个抓瞎。

折腾了一圈,最后换了套思路:让 CLIP 把每张图编成向量塞进库里,下次再来就用余弦相似度找。第一次见的存档,第二次见就能”认出来”——这就是这个项目里说的”视觉记忆库”。

工作流闭环那一块我另写了一篇,这里只聊几个做这个东西时被坑得不轻的工程细节。

CLIP 模型本体

用的是导出的 ONNX visual encoder,进 224×224,出特征向量。预处理没啥花活,但顺序得跟训练时严丝合缝——BGR → RGB → letterbox → 中心裁剪 → 归一化,mean/std 直接抄 OpenAI CLIP 官方那一套:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import onnxruntime as ort
import cv2

class CLIPModelFromONNX:
def __init__(self, model_path: str):
self.session = ort.InferenceSession(model_path)
self._input_name = self.session.get_inputs()[0].name

def _preprocess(self, image: np.ndarray) -> np.ndarray:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image, _, _ = letterbox(image, (224, 224))
image = center_crop(image)
image = image.astype(np.float32) / 255.0
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
image = (image - mean) / std
image = np.transpose(image, (2, 0, 1))
return image[np.newaxis, :].astype(np.float32)

def forward(self, image: np.ndarray) -> np.ndarray:
input_tensor = self._preprocess(image)
output = self.session.run(None, {self._input_name: input_tensor})
return output[0]

预处理顺序错一个,整个特征空间就跟训练时对不上——表现就是检索效果烂得跟瞎猜似的。这种问题最难查,因为啥都不报错,就是结果不对。

入库 / 检索

库本身就是”特征向量 + payload”的一堆条目。每次入库前都和已有的过一遍相似度,超过 0.96 就当成”见过”,直接跳过:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pickle
from dataclasses import dataclass

@dataclass
class CLIPRetrieveData:
payload: Any
similarity: float

class CLIPMemory:
def __init__(self, clip_name: str):
self._clip_name = clip_name
self._engine = CLIPModelFromONNX("clip_visual.onnx")
self._image_file_path = f"data/CLIP/{clip_name}"

@staticmethod
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
a, b = a.flatten(), b.flatten()
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def add_to_memory(self, image, payload, threshold=0.96) -> bool:
features = self._engine.forward(image)
for memory in self._load_memories():
existing = pickle.loads(memory.features)
if self._cosine_similarity(features, existing) > threshold:
return False
self._save_memory(features, payload)
return True

def retrieve(self, image, threshold=0.96):
features = self._engine.forward(image)
best_sim, best_payload = -1, None
for memory in self._load_memories():
existing = pickle.loads(memory.features)
sim = self._cosine_similarity(features, existing)
if sim > best_sim:
best_sim, best_payload = sim, memory.payload

if best_payload is not None and best_sim > threshold:
return CLIPRetrieveData(best_payload, best_sim)
return None

0.96 这个数纯属调出来的——0.99 太苛刻,同一张卡换个背景立马就漏;0.92 又太松,长得像的不同卡能给你撮合到一起去。每加一类新元素都得重新校一次,烦是真烦,可没办法。

数据增强:让特征更耐折腾

后来发现一个怪事:同一张卡换个分辨率截图,相似度能掉到 0.93,刚好卡在阈值边上,时灵时不灵。

解法是入库的时候就主动造几个变体一起存进去,免得运行时被各种”小变化”打个措手不及。常用的几种增强:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def augment_image(image: np.ndarray) -> list:
augmented = []
h, w = image.shape[:2]

# JPEG 压缩
_, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, 30])
augmented.append(cv2.imdecode(buf, cv2.IMREAD_COLOR))

# 高斯噪点
noise = np.random.normal(0, 15, image.shape).astype(np.float32)
noisy = np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)
augmented.append(noisy)

# 亮度偏移
augmented.append(cv2.convertScaleAbs(image, alpha=1.15, beta=20))

# 色调偏移
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.int16)
hsv[:, :, 0] = (hsv[:, :, 0] + 8) % 180
augmented.append(cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR))

# 缩放
small = cv2.resize(image, (int(w * 0.85), int(h * 0.85)))
augmented.append(cv2.resize(small, (w, h)))

return augmented

不过也不能无脑都塞。增强后的样本要是跟原图相似度还在 0.95 以上,那存进去也是占地方——所以加了个判断,相似度低于 0.95 才真入库:

1
2
3
4
5
6
7
def add_to_memory_with_augmentation(self, image, payload):
self.add_to_memory(image, payload)
orig_features = self._engine.forward(image)
for aug_image in augment_image(image):
aug_features = self._engine.forward(aug_image)
if self._cosine_similarity(orig_features, aug_features) < 0.95:
self._save_memory(aug_features, payload)

类内偏差验证:防手抖

如果某个 payload 已经攒了几个样本,新加图的时候多走一步——这张新图跟同类样本的最高相似度,不能太低。低于 0.72 基本可以判定是标错了,直接拒了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def validate_inclass_similarity(self, image, target_id) -> bool:
MIN_INCLASS_SIMILARITY = 0.72
image_features = self._engine.forward(image)
best_inclass = -1.0
for memory in self._load_memories():
if memory.payload_id != target_id:
continue
existing = pickle.loads(memory.features)
sim = self._cosine_similarity(image_features, existing)
best_inclass = max(best_inclass, sim)

if best_inclass >= 0 and best_inclass < MIN_INCLASS_SIMILARITY:
return False
return True

这条加进来之前出过一次事故。调试的时候手滑点错了一张卡,错的卡也跟着进了库,越用越歪——后面识别越来越离谱,回查才发现根子在这。从那以后这检查就一直留着了。

失效记忆清理

payload 是引用业务库里卡片记录的,但业务库会被用户删卡、改版本——记忆库这边就会堆出一堆”指向不存在记录”的死链。

懒得专门起后台任务跑清理,干脆放在检索路径上顺手收一下,反正检索远比入库频繁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def retrieve_with_cleanup(self, image, threshold=0.96):
features = self._engine.forward(image)
best_sim, best_payload = -1, None
stale_ids = []

for memory in self._load_memories():
try:
payload = memory.load_payload()
if payload is None:
stale_ids.append(memory.id)
continue
except Exception:
stale_ids.append(memory.id)
continue

existing = pickle.loads(memory.features)
sim = self._cosine_similarity(features, existing)
if sim > best_sim:
best_sim, best_payload = sim, payload

if stale_ids:
self._delete_memories(stale_ids)

if best_payload and best_sim > threshold:
return CLIPRetrieveData(best_payload, best_sim)
return None

共享一个模型会话

支援卡、技能卡、道具、偶像卡,各管各的记忆库,但用的都是同一个 visual encoder。要是每个识别器都各自加载一份 ONNX,内存能给你撑爆。

所以模型会话只持有一份,所有识别器共享:

1
2
3
4
5
6
7
class CLIPServiceManager:
def __init__(self):
self._model_session = CLIPModelFromONNX("clip_visual.onnx")
self.support_card_clip = SupportCardCLIP(self._model_session)
self.skill_card_clip = SkillCardCLIP(self._model_session)
self.item_clip = ItemCLIP(self._model_session)
self.idol_card_clip = IdolCardCLIP(self._model_session)

提醒一句:ONNX Runtime 的 session 写入操作不是线程安全的,但只读推理没问题——只要别让不同识别器同时去改 session options,共享就是安全的。

这个工具 macOS 和 Windows 都得用,但 OCR 引擎在两个平台上没有哪个是完美的。

macOS 上最舒服的,其实就是系统自带的 Vision 框架——不用装额外运行时、不占内存、中文识别还挺不错。但 Vision 在 Windows 上?根本不存在。

Windows 这边比较通用的方案是 RapidOCR(PaddleOCR 的 ONNX 版本)——跨平台、安装简单、识别质量也够用,就是占用比 Vision 大一些。Linux 上 Vision 不可用,那就只剩 RapidOCR 这条路了。

要是硬编码任一种,结果都很尬。所以做了个能切换的后端。

基类和数据结构

最小接口就两个方法:ocris_available。后者用来在初始化时筛掉当前平台跑不动的后端:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

@dataclass
class OCRResult:
text: str
confidence: float
bbox: tuple # (x1, y1, x2, y2)

@dataclass
class OCRResponse:
results: List[OCRResult]
backend: str

class BaseOCRBackend(ABC):
@abstractmethod
def ocr(self, image) -> OCRResponse: ...

@abstractmethod
def is_available(self) -> bool: ...

OCRResponse 里带个 backend 字段,是给调试时方便用的——日志里能直接看出来”这次识别是哪个后端跑出来的”,省得猜。

RapidOCR 后端

RapidOCR 是 pip 包,import 之后直接用。is_available 实际上就是看包能不能 import 进来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class RapidOCRBackend(BaseOCRBackend):
def __init__(self):
from rapidocr_onnxruntime import RapidOCR
self._engine = RapidOCR()

def ocr(self, image) -> OCRResponse:
result, _ = self._engine(image)
if result is None:
return OCRResponse(results=[], backend="rapidocr")

ocr_results = []
for line in result:
bbox, text, confidence = line
ocr_results.append(OCRResult(text=text, confidence=confidence, bbox=bbox))
return OCRResponse(results=ocr_results, backend="rapidocr")

def is_available(self) -> bool:
try:
import rapidocr_onnxruntime
return True
except ImportError:
return False

Vision 后端

Vision 这边走 PyObjC 桥接,坑稍微多点。图像得先转成 CIImage、识别请求是异步的但可以同步 wait、识别语言要显式声明否则识别率会掉——一步走错就给你来个屏幕全空:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class VisionOCRBackend(BaseOCRBackend):
def ocr(self, image) -> OCRResponse:
import Vision
import Quartz

ci_image = Quartz.CIImage.imageWithCGImage_(image)

request = Vision.VNRecognizeTextRequest.alloc().init()
request.setRecognitionLanguages_(["zh-Hans", "en"])
request.setRecognitionLevel_(Vision.VNRequestTextRecognitionLevelAccurate)

handler = Vision.VNImageRequestHandler.alloc().initWithCIImage_options_(
ci_image, None
)
success = handler.performRequests_error_([request], None)

if not success:
return OCRResponse(results=[], backend="vision")

ocr_results = []
for observation in request.results():
text = observation.topCandidates_(1)[0].string()
confidence = observation.topCandidates_(1)[0].confidence()
bbox = observation.boundingBox()
ocr_results.append(OCRResult(text=text, confidence=confidence, bbox=bbox))
return OCRResponse(results=ocr_results, backend="vision")

def is_available(self) -> bool:
import platform
return platform.system() == "Darwin"

VNRequestTextRecognitionLevelAccurateFast 慢三五倍,但识别率明显高一截。这种取舍很容易选——我宁愿慢 100ms,也不愿意识别错。

工厂 + 回退

工厂函数干两件事:根据 requested 解析出候选列表,按顺序尝试,第一个能用的就用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from enum import Enum

class OCRBackendType(str, Enum):
AUTO = "auto"
RAPIDOCR = "rapidocr"
VISION = "vision"

def create_ocr_backend(requested_backend: str = None) -> BaseOCRBackend:
requested = requested_backend or get_configured_backend()
candidates = resolve_candidates(requested)

errors = []
for candidate in candidates:
try:
backend = _build_backend(candidate)
if requested == "auto":
logger.info(f"Using OCR backend {candidate} (auto)")
elif candidate != requested:
logger.warning(f"OCR backend {requested} unavailable, fallback to {candidate}")
return backend
except Exception as exc:
errors.append(f"{candidate}: {exc}")
logger.warning(f"Failed to initialize OCR backend {candidate}: {exc}")

raise RuntimeError(f"Failed to initialize any OCR backend: {' | '.join(errors)}")

def _build_backend(backend_name: str) -> BaseOCRBackend:
if backend_name == "vision":
return VisionOCRBackend()
if backend_name == "rapidocr":
return RapidOCRBackend()
raise ValueError(f"Unknown OCR backend: {backend_name}")

按平台选默认顺序

auto 模式的候选顺序看平台脸色:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import platform

def resolve_candidates(requested: str) -> list:
current_platform = platform.system()

if requested == "auto":
if current_platform == "Darwin":
return ["vision", "rapidocr"] # macOS:Vision 优先
return ["rapidocr"]

if requested == "vision":
# 用户指名要 Vision,但失败了就降级到 RapidOCR
return ["vision", "rapidocr"]

return ["rapidocr"]

"vision" 也支持降级,是为了开发者方便——比如他本来在 Mac 上 debug,配置文件写死了 vision,临时塞给 Windows 同事跑也不会当场炸锅。

环境变量覆盖

环境变量优先级高于配置文件,方便临时切换:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import os

def get_configured_backend() -> str:
if env_backend := os.getenv("GAKUMAS_OCR_BACKEND"):
return normalize_backend_name(env_backend)

from config import ConfigService
return ConfigService().base.ocr_backend

def normalize_backend_name(value: str) -> str:
normalized = value.strip().lower()
valid_values = ["auto", "rapidocr", "vision"]
if normalized in valid_values:
return normalized
return "auto"

未知值就退回到 "auto",不报错——这个文件有可能是从老版本配置迁移过来的,宽松一点比直接挂掉好。

对外的统一入口

业务代码不直接用 backend,走 OCRService

1
2
3
4
5
6
7
8
9
10
class OCRService:
def __init__(self):
self._backend = create_ocr_backend()

def ocr(self, image) -> OCRResponse:
return self._backend.ocr(image)

def ocr_text(self, image) -> str:
response = self.ocr(image)
return " ".join(r.text for r in response.results)

后端切换对业务零感知。以后想加个 PaddleOCR-GPU 之类的新后端,也就是在工厂里多注册一个类的事——业务代码动都不用动。

游戏自动化要持续抓屏识别。一开始很自然地写成”主线程里循环抓图 → 推理 → 派发结果”——结果 UI 直接卡成 PPT。

原因不复杂:推理本身要十几毫秒,抓屏在某些情况下能阻塞几十毫秒。主线程一直被这俩活儿占着,UI 哪还有机会响应。

后来拆开了——推理线程在后台跑,主线程只读最新结果。这篇就讲讲这个引擎里几个关键设计的来龙去脉。NMS 相关的内容写在另一篇里了,这里专心聊线程模型。

推理在独立线程

主线程能做的只有一件事:拿一个最新的推理结果引用。所有的截屏、推理、回调,全都丢到后台线程里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import threading
from time import sleep
from typing import Callable, List

class YoloInferenceEngine:
def __init__(self, device):
self._device = device
self._engine = None
self._latest_results = None

self._infer_callback_list: List[Callable] = []
self._capture_failure_callback_list: List[Callable] = []

self.__flag_loop = False
self.__flag_pause = False

self.__action_lock = threading.Lock()
self.__result_write_lock = threading.Lock()

def start(self):
with self.__action_lock:
if self.__flag_loop:
return False
self.__flag_pause = False
self.__flag_loop = True
self._capture_thread = threading.Thread(
target=self._inference_loop,
daemon=True,
)
self._capture_thread.start()
return True

def stop(self):
with self.__action_lock:
if not self.__flag_loop:
return False
self.__flag_pause = False
self.__flag_loop = False
self._capture_thread.join(timeout=3)
return True

def pause(self):
with self.__action_lock:
if self.__flag_pause:
return False
self.__flag_pause = True
return True

def resume(self):
with self.__action_lock:
if not self.__flag_pause:
return False
self.__flag_pause = False
return True

daemon=True 是 GUI 应用必备——主进程退出的时候,可不希望被推理线程拽住卡半天。join(timeout=3) 是个保底,实际上 daemon 模式下不 join 也没事。

pause/resume 这一对方法,是为了切换模型时不必整个 stop 再 start。后面会用到。

推理循环本体

写得很直白——只在没暂停且能抓到帧的情况下做推理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def _inference_loop(self):
while self.__flag_loop:
if self.__flag_pause:
sleep(0.1)
continue

try:
frame = self._device.capture()
except Exception as e:
self._exec_capture_failure_callback(e)
self.__flag_loop = False
return

if frame is None or frame.size <= 0:
sleep(0.1)
continue

results = self._engine(
frame,
conf_threshold=0.6,
agnostic_nms_groups=self._agnostic_nms_groups,
)

with self.__result_write_lock:
self._latest_results = Yolo_Results(results, frame)

self._exec_infer_callback()

self.__flag_loop = False

抓帧失败是个比较严重的事——通常是设备掉了(虚拟摄像头被关、游戏退出)。这种时候直接终止循环比空转更合适,所以走的是 return 而不是 continue。回调那边可以让上层决定要不要重启。

空帧(黑屏的过场动画、设备临时无数据)只是 sleep(0.1) continue——不算异常,等下一帧就行。

回调注册

业务层不需要轮询 latest_results——它只需要在新结果出现时被通知一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def register_infer_callback(self, func: Callable):
with self.__action_lock:
if func not in self._infer_callback_list:
self._infer_callback_list.append(func)

def register_capture_failure_callback(self, func: Callable):
with self.__action_lock:
if func not in self._capture_failure_callback_list:
self._capture_failure_callback_list.append(func)

def _exec_infer_callback(self):
for callback in self._infer_callback_list:
try:
callback(self.latest_frame, self.latest_results)
except Exception as e:
logger.error(f"推理回调失败: {e}")

def _exec_capture_failure_callback(self, exc: Exception):
for callback in self._capture_failure_callback_list:
try:
callback(exc)
except Exception as callback_exc:
logger.error(f"截图失败回调失败: {callback_exc}")

每个回调单独 try/except 包起来——一个回调炸了不能拖累其他回调。

这是 callback 列表的常见坑。第一版没包,结果遇到过一次:某个调试用的可视化回调抛了异常,正常业务回调全部不执行。当场一脸黑线。

业务侧用起来:

1
2
3
4
5
6
7
8
9
engine = YoloInferenceEngine(device)

def on_inference_result(frame, results):
for box, score, class_id in zip(results.boxes, results.scores, results.class_ids):
if score > 0.8:
print(f"检测到: {class_id}, 置信度: {score:.2f}")

engine.register_infer_callback(on_inference_result)
engine.start()

状态机

四个状态:Stopped、Running、Paused、(Failed)。最后一个其实只是 Stopped + capture_failure callback——没单独定义。

1
2
3
4
5
6
7
         start()
Stopped ---------> Running
^ |
| stop() | pause()
| v
+--------- Paused <---+
resume()

只读访问器全部加锁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@property
def running(self) -> bool:
with self.__action_lock:
return self.__flag_loop

@property
def is_pause(self) -> bool:
with self.__action_lock:
return self.__flag_pause

@property
def latest_frame(self):
try:
return self._latest_results.frame
except Exception:
return None

@property
def latest_results(self):
try:
return self._latest_results
except Exception:
return None

注意 latest_frame/latest_results 这俩没拿 __result_write_lock——这里有个隐患:推理线程正在写的瞬间被读,可能拿到半成品对象。

不过实际上因为 Python 的引用赋值是原子的,self._latest_results = Yolo_Results(...) 这一行的执行,读端要么拿到旧引用要么拿到新引用,不会出现”半个对象”。

但前提是 Yolo_Results 一旦构造就不可变——这是个约定,靠这个约定不需要加锁。约定要是哪天破了,那就好玩了。

模型热切换

游戏不同界面(剧情 UI、培育 UI、考试 UI)用的 YOLO 模型不同。切换时要把当前推理打断,加载新模型,再恢复:

1
2
3
4
5
6
7
8
9
10
11
def load_model(self, model_type: str = YoloModelType.BASE_UI):
if self.__flag_loop:
self.pause()

with self.__action_lock:
self._engine = YoloModelFromONNX(config.model_config[model_type])
self._model_type = model_type
self._agnostic_nms_groups = self._build_skill_card_nms_group()

if self.__flag_loop:
self.resume()

要是直接 self._engine = YoloModelFromONNX(...) 不暂停,推理线程可能正在用旧的 _engine 跑——切换瞬间会冒出来一些”半新半旧”的怪结果。pause() → 改 _engineresume(),把这个空窗堵上。

不过这里还有个小漏洞:pause 内部只是设 flag,真正生效得等推理线程跑到下一次循环检查到 flag。所以 pause 返回后立马去切模型,理论上还能撞上一次正在跑的推理。

讲究的做法是等推理线程进入 pause 状态再切,但目前用户感知不到这个 race,先记着,没动。

ONNX Runtime 跨平台是真的能跨——同一份模型文件,Windows、macOS、Linux 都能跑。但加速后端就完全不通用了:Windows 用 DirectML,macOS 用 CoreML,Linux 大多数情况只剩 CPU 一条路。

要是代码里直接写死 CPUExecutionProvider——在 Mac 上能跑,但 Apple Silicon 的 ANE 完全没用上,浪费;在 Windows 上又用不到核显加速,亏。

下面是这个项目里管理 ORT 会话的几个关键点:自动探测、CoreML 缓存、动态维度覆盖、运行时回退。

优先级列表

按平台从高到低排个序,找到一个能用的就行:

1
2
3
4
5
6
7
import onnxruntime as ort

PREFERRED_EXECUTION_PROVIDERS = (
"DmlExecutionProvider", # Windows
"CoreMLExecutionProvider", # macOS
"CPUExecutionProvider", # 兜底
)

get_available_providers() 返回的是当前 ORT 编译时带的所有 provider 名字,但不代表全部都能真正初始化成功——这点很多人会忽略。所以要先用列表过滤一遍,再实际尝试创建会话:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class DMLManager:
@classmethod
def get_session_providers(cls) -> list[str]:
available_providers = set(ort.get_available_providers())
providers = [
provider
for provider in cls._preferred_execution_providers
if provider in available_providers
]
if "CPUExecutionProvider" not in providers:
providers.append("CPUExecutionProvider")
return providers

@classmethod
def create_session(cls, model_path: str) -> ort.InferenceSession:
providers = cls.get_session_providers()
so = cls._build_session_options()
try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as exc:
logger.warning(f"加速后端失败,回退到 CPU: {exc}")
return ort.InferenceSession(
model_path,
sess_options=so,
providers=["CPUExecutionProvider"],
)

CoreML 缓存

CoreML 第一次加载 ONNX 模型时,会做一次”编译”——把模型转成 CoreML 内部的表示。这一编译可不得了,一个 YOLO nano 模型在 M1 上能搞十几秒,大模型甚至跑好几分钟。用户那边看到的就是:软件启动条爬得跟便秘似的。

好在只要模型文件没变,编译产物就能复用。所以加了个 cache,用模型文件的哈希做 key:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import hashlib
from pathlib import Path

class DMLManager:
@classmethod
def _build_model_cache_key(cls, model_path: str, provider_options: dict) -> str:
model_file = Path(model_path).resolve()
digest = hashlib.sha256()

# 模型本体和 .data 外部权重都要算进去
for file_path in cls._iter_model_fingerprint_files(model_file):
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
digest.update(chunk)

digest.update(repr(sorted(provider_options.items())).encode())

if cls._free_dimension_overrides:
digest.update(repr(cls._free_dimension_overrides).encode())

return digest.hexdigest()[:24]

@classmethod
def _build_provider_config(cls, model_path: str) -> list:
providers = []
available = cls.get_session_providers()

if "DmlExecutionProvider" in available:
providers.append("DmlExecutionProvider")

if "CoreMLExecutionProvider" in available:
cache_root = cls._get_cache_root()
cache_dir = cache_root / "coreml-cache"

provider_options = {
"ModelFormat": "MLProgram",
"MLComputeUnits": "ALL",
"RequireStaticInputShapes": "0",
}
model_cache_key = cls._build_model_cache_key(model_path, provider_options)
cache_subdir = cache_dir / model_cache_key
cache_subdir.mkdir(parents=True, exist_ok=True)

# 顺手清旧版本的缓存
cls._prune_stale_coreml_cache(cache_dir, model_cache_key)

providers.append((
"CoreMLExecutionProvider",
{**provider_options, "ModelCacheDirectory": str(cache_subdir)},
))

if "CPUExecutionProvider" in available:
providers.append("CPUExecutionProvider")

return providers

哈希里得带上 provider_optionsfree_dimension_overrides——这两个变了,缓存的编译产物就废了。要是不带,会出现”明明改了配置但行为没变”的灵异事件。

缓存损坏要能重试

CoreML 缓存目录被意外搞坏的情况,是真的会发生:手动删了一半文件、磁盘满了写到一半、用户在版本之间倒退……最后都会变成 ORT 初始化时报错。

处理方式是:捕获异常 → 清掉缓存目录 → 再试一次 → 还失败就回 CPU:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@classmethod
def create_session_with_retry(cls, model_path: str) -> ort.InferenceSession:
providers = cls._build_provider_config(model_path)
so = cls._build_session_options()

try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as exc:
coreml_cache_dir = cls._extract_coreml_cache_dir(providers)
if coreml_cache_dir is not None:
cls._clear_coreml_cache_dir(coreml_cache_dir)
logger.warning(f"清除 CoreML 缓存后重试: {exc}")
providers = cls._build_provider_config(model_path)
try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as retry_exc:
logger.warning(f"重试失败,回退到 CPU: {retry_exc}")

return ort.InferenceSession(
model_path,
sess_options=so,
providers=["CPUExecutionProvider"],
)

动态维度覆盖

CoreML 不支持动态 batch/H/W——这是个老问题。

YOLO 模型导出时通常会保留 batch 维度为符号”batch”或者”batch_size”,输入 H/W 也可能是动态的。直接拿去给 CoreML 加载,立马报错给你看。

办法是用 add_free_dimension_override_by_name,在 session option 上把符号维度钉死:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DMLManager:
_free_dimension_overrides = [
("batch", 1),
("batch_size", 1),
]

@classmethod
def _build_session_options(cls, extra_overrides: dict = None) -> ort.SessionOptions:
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
so.intra_op_num_threads = 1
so.inter_op_num_threads = 1

for name, value in cls._free_dimension_overrides:
so.add_free_dimension_override_by_name(name, value)

if extra_overrides:
for name, value in extra_overrides.items():
so.add_free_dimension_override_by_name(name, value)
return so

@classmethod
def _build_dim_overrides_for_model(cls, model_path: str) -> dict:
overrides = {}
imgsz = cls._read_model_imgsz(model_path)
if imgsz is not None:
overrides["height"] = imgsz[0]
overrides["width"] = imgsz[1]
return overrides

H 和 W 的具体值不能写死——会从模型的 meta(YOLO 导出时会把 imgsz 写进 ONNX 的 metadata)里读。不同 YOLO 模型可能是 640、960、1280,统一一个常量肯定行不通。

运行时回退

加速后端在初始化时能跑起来,不代表运行时不会挂——见过 DirectML 在显存吃紧时直接抛 RuntimeException 的情况,那一刻服务整个就停摆了。

这种时候不应该让推理失败,得把会话静默换成 CPU 重跑一次:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class DMLManager:
_session_replacements: dict[int, ort.InferenceSession] = {}

@classmethod
def run(cls, session: ort.InferenceSession, feeds: dict) -> list:
with cls._lock:
active_session = cls._resolve_session(session)
try:
return active_session.run(None, feeds)
except Exception as exc:
fallback_session = cls._fallback_to_cpu_session(active_session, exc)
if fallback_session is None:
raise
return fallback_session.run(None, feeds)

@classmethod
def _fallback_to_cpu_session(cls, session, exc) -> ort.InferenceSession | None:
session_id = id(session)
model_path = cls._session_model_paths.get(session_id)
provider_names = cls._session_provider_names.get(session_id)

if not model_path or not provider_names:
return None
if tuple(provider_names) == ("CPUExecutionProvider",):
return None # 已经是 CPU 了

logger.warning(f"ONNX 推理失败,回退到 CPU: {exc}")

fallback_session = cls._create_cpu_session(model_path)
# 同一会话以后再来,直接走 CPU
cls._session_replacements[session_id] = fallback_session
return fallback_session

@classmethod
def _resolve_session(cls, session) -> ort.InferenceSession:
current = session
visited = set()
while True:
current_id = id(current)
replacement = cls._session_replacements.get(current_id)
if replacement is None or current_id in visited:
return current
visited.add(current_id)
current = replacement

_session_replacementsid(session) 做 key,因为 ORT 的 session 本身不可哈希。visited 这个 set 是防 replacement 链出现环——理论上不该发生,但写一行不亏。

回退之后,这个会话后续所有请求都自动走 CPU,不再尝试加速后端。这个语义是故意的——一次失败可能是偶发,但加速后端一旦不稳,反复尝试只会拖死业务。要恢复加速?只能重启进程。