CLIP 视觉记忆库的几个工程细节

游戏自动化里有个老大难问题:YOLO 框出来一张卡,到底是哪张?

模板匹配早就玩不转了——同一张卡换个分辨率立马失效,每出新卡你都得苦哈哈地重新挑模板。OCR 呢,倒是认字,但碰上立绘风格的卡面、花里胡哨的字体,那叫一个抓瞎。

折腾了一圈,最后换了套思路:让 CLIP 把每张图编成向量塞进库里,下次再来就用余弦相似度找。第一次见的存档,第二次见就能”认出来”——这就是这个项目里说的”视觉记忆库”。

工作流闭环那一块我另写了一篇,这里只聊几个做这个东西时被坑得不轻的工程细节。

CLIP 模型本体

用的是导出的 ONNX visual encoder,进 224×224,出特征向量。预处理没啥花活,但顺序得跟训练时严丝合缝——BGR → RGB → letterbox → 中心裁剪 → 归一化,mean/std 直接抄 OpenAI CLIP 官方那一套:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import onnxruntime as ort
import cv2

class CLIPModelFromONNX:
def __init__(self, model_path: str):
self.session = ort.InferenceSession(model_path)
self._input_name = self.session.get_inputs()[0].name

def _preprocess(self, image: np.ndarray) -> np.ndarray:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image, _, _ = letterbox(image, (224, 224))
image = center_crop(image)
image = image.astype(np.float32) / 255.0
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
image = (image - mean) / std
image = np.transpose(image, (2, 0, 1))
return image[np.newaxis, :].astype(np.float32)

def forward(self, image: np.ndarray) -> np.ndarray:
input_tensor = self._preprocess(image)
output = self.session.run(None, {self._input_name: input_tensor})
return output[0]

预处理顺序错一个,整个特征空间就跟训练时对不上——表现就是检索效果烂得跟瞎猜似的。这种问题最难查,因为啥都不报错,就是结果不对。

入库 / 检索

库本身就是”特征向量 + payload”的一堆条目。每次入库前都和已有的过一遍相似度,超过 0.96 就当成”见过”,直接跳过:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pickle
from dataclasses import dataclass

@dataclass
class CLIPRetrieveData:
payload: Any
similarity: float

class CLIPMemory:
def __init__(self, clip_name: str):
self._clip_name = clip_name
self._engine = CLIPModelFromONNX("clip_visual.onnx")
self._image_file_path = f"data/CLIP/{clip_name}"

@staticmethod
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
a, b = a.flatten(), b.flatten()
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def add_to_memory(self, image, payload, threshold=0.96) -> bool:
features = self._engine.forward(image)
for memory in self._load_memories():
existing = pickle.loads(memory.features)
if self._cosine_similarity(features, existing) > threshold:
return False
self._save_memory(features, payload)
return True

def retrieve(self, image, threshold=0.96):
features = self._engine.forward(image)
best_sim, best_payload = -1, None
for memory in self._load_memories():
existing = pickle.loads(memory.features)
sim = self._cosine_similarity(features, existing)
if sim > best_sim:
best_sim, best_payload = sim, memory.payload

if best_payload is not None and best_sim > threshold:
return CLIPRetrieveData(best_payload, best_sim)
return None

0.96 这个数纯属调出来的——0.99 太苛刻,同一张卡换个背景立马就漏;0.92 又太松,长得像的不同卡能给你撮合到一起去。每加一类新元素都得重新校一次,烦是真烦,可没办法。

数据增强:让特征更耐折腾

后来发现一个怪事:同一张卡换个分辨率截图,相似度能掉到 0.93,刚好卡在阈值边上,时灵时不灵。

解法是入库的时候就主动造几个变体一起存进去,免得运行时被各种”小变化”打个措手不及。常用的几种增强:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def augment_image(image: np.ndarray) -> list:
augmented = []
h, w = image.shape[:2]

# JPEG 压缩
_, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, 30])
augmented.append(cv2.imdecode(buf, cv2.IMREAD_COLOR))

# 高斯噪点
noise = np.random.normal(0, 15, image.shape).astype(np.float32)
noisy = np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)
augmented.append(noisy)

# 亮度偏移
augmented.append(cv2.convertScaleAbs(image, alpha=1.15, beta=20))

# 色调偏移
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.int16)
hsv[:, :, 0] = (hsv[:, :, 0] + 8) % 180
augmented.append(cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR))

# 缩放
small = cv2.resize(image, (int(w * 0.85), int(h * 0.85)))
augmented.append(cv2.resize(small, (w, h)))

return augmented

不过也不能无脑都塞。增强后的样本要是跟原图相似度还在 0.95 以上,那存进去也是占地方——所以加了个判断,相似度低于 0.95 才真入库:

1
2
3
4
5
6
7
def add_to_memory_with_augmentation(self, image, payload):
self.add_to_memory(image, payload)
orig_features = self._engine.forward(image)
for aug_image in augment_image(image):
aug_features = self._engine.forward(aug_image)
if self._cosine_similarity(orig_features, aug_features) < 0.95:
self._save_memory(aug_features, payload)

类内偏差验证:防手抖

如果某个 payload 已经攒了几个样本,新加图的时候多走一步——这张新图跟同类样本的最高相似度,不能太低。低于 0.72 基本可以判定是标错了,直接拒了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def validate_inclass_similarity(self, image, target_id) -> bool:
MIN_INCLASS_SIMILARITY = 0.72
image_features = self._engine.forward(image)
best_inclass = -1.0
for memory in self._load_memories():
if memory.payload_id != target_id:
continue
existing = pickle.loads(memory.features)
sim = self._cosine_similarity(image_features, existing)
best_inclass = max(best_inclass, sim)

if best_inclass >= 0 and best_inclass < MIN_INCLASS_SIMILARITY:
return False
return True

这条加进来之前出过一次事故。调试的时候手滑点错了一张卡,错的卡也跟着进了库,越用越歪——后面识别越来越离谱,回查才发现根子在这。从那以后这检查就一直留着了。

失效记忆清理

payload 是引用业务库里卡片记录的,但业务库会被用户删卡、改版本——记忆库这边就会堆出一堆”指向不存在记录”的死链。

懒得专门起后台任务跑清理,干脆放在检索路径上顺手收一下,反正检索远比入库频繁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def retrieve_with_cleanup(self, image, threshold=0.96):
features = self._engine.forward(image)
best_sim, best_payload = -1, None
stale_ids = []

for memory in self._load_memories():
try:
payload = memory.load_payload()
if payload is None:
stale_ids.append(memory.id)
continue
except Exception:
stale_ids.append(memory.id)
continue

existing = pickle.loads(memory.features)
sim = self._cosine_similarity(features, existing)
if sim > best_sim:
best_sim, best_payload = sim, payload

if stale_ids:
self._delete_memories(stale_ids)

if best_payload and best_sim > threshold:
return CLIPRetrieveData(best_payload, best_sim)
return None

共享一个模型会话

支援卡、技能卡、道具、偶像卡,各管各的记忆库,但用的都是同一个 visual encoder。要是每个识别器都各自加载一份 ONNX,内存能给你撑爆。

所以模型会话只持有一份,所有识别器共享:

1
2
3
4
5
6
7
class CLIPServiceManager:
def __init__(self):
self._model_session = CLIPModelFromONNX("clip_visual.onnx")
self.support_card_clip = SupportCardCLIP(self._model_session)
self.skill_card_clip = SkillCardCLIP(self._model_session)
self.item_clip = ItemCLIP(self._model_session)
self.idol_card_clip = IdolCardCLIP(self._model_session)

提醒一句:ONNX Runtime 的 session 写入操作不是线程安全的,但只读推理没问题——只要别让不同识别器同时去改 session options,共享就是安全的。

欢迎关注我的其它发布渠道