把 MobileCLIP2 导出成 ONNX——视觉模型的几个真坑

CLIP 这两年在视觉任务里成了瑞士军刀——图搜图、零样本分类、跨模态检索,啥都能搭一手。但 OpenAI 那版 CLIP 模型偏大,跑在桌面用户的设备上有点吃力。

后来 Apple 出了 MobileCLIP,小、快、精度还不错。第二代 MobileCLIP2 更新了下,更香。

问题是——MobileCLIP2 官方代码基于 PyTorch,没现成的 ONNX 文件。我们的项目要跑在用户的桌面(Windows/macOS/Linux 全要支持),PyTorch 那一坨依赖装起来用户得疯——光 torch 包就上 G。ONNX Runtime 加一个几十 MB 的模型文件就完事,对比鲜明。

所以得自己导出。听起来一行 torch.onnx.export 的事,实际坑不少。

为什么不一行搞定

最朴素的版本:

1
2
3
4
5
6
7
import torch
from mobileclip import create_model

model = create_model("mobileclip_s2")
model.eval()
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy, "mobileclip2.onnx", opset_version=17)

跑一下,立马得到一堆错误。常见的几类:

  • 某个算子在 opset_version 不支持——升级 opset,或者绕过去
  • 动态 shape 没声明——batch 维度希望可变,得显式标 dynamic_axes
  • forward 返回的不是 tensor——是个 dict / namedtuple,ONNX 不认
  • 训练时和推理时分支不同——某个 if self.training 路径里 ONNX 跟踪出错

每个都得单独治。

坑一:模型 forward 输出结构

MobileCLIP 这种双塔模型,forward 一般同时返回 image embedding 和 text embedding。结构大概是这样:

1
2
3
4
5
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
logit_scale = self.logit_scale.exp()
return image_features, text_features, logit_scale

但我们的实际使用场景是只要 image encoder——用 CLIP 给图片编码进向量库,文本侧根本不参与。强行导出整模型,多带了一倍多的参数,纯浪费。

正确做法是只导出 visual encoder 子图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class VisualOnly(torch.nn.Module):
def __init__(self, base_model):
super().__init__()
self.encoder = base_model.image_encoder
# 如果有 projection 层,也带上
self.projection = base_model.image_projection

def forward(self, image):
feat = self.encoder(image)
feat = self.projection(feat)
return feat # 单个 tensor,ONNX 友好

visual = VisualOnly(model)
visual.eval()
torch.onnx.export(visual, dummy, "mobileclip2_visual.onnx", ...)

我们项目里 devtools/export_mobileclip2.py 就是这么干的。模型体积砍掉一半多,推理时也省去了文本侧那一坨无用计算。

坑二:动态 batch 维度

桌面应用里的图像识别,一次推理几张图是不固定的。一会儿单张(实时识别),一会儿几十张(批量学习一组新卡)。

如果导出时 batch 维度被固定成 1,运行时遇到批量推理就只能循环单张,性能拉胯。

声明动态维度的写法:

1
2
3
4
5
6
7
8
9
10
torch.onnx.export(
visual, dummy, "mobileclip2_visual.onnx",
opset_version=17,
input_names=["image"],
output_names=["features"],
dynamic_axes={
"image": {0: "batch"}, # 第 0 维(batch)允许变
"features": {0: "batch"},
},
)

不写 dynamic_axes,ONNX 默认把 dummy 输入的 shape 完全固化。运行时给个 batch=8 进去,直接报错——这是新手最常见的坑。

坑三:opset_version 怎么选

这个版本号说白了就是”ONNX 协议版本”。版本越高支持的算子越多、越新,但 ONNX Runtime 也得跟得上。

实战经验:

  • 太低(比如 opset 9)——很多现代模型用的算子(比如某些 attention 实现、Einsum、新版的归一化)不支持,导出报错
  • 太高(比如 opset 20+)——ONNX Runtime 部分版本不认,模型加载就挂

我们的稳妥选择是 opset 17——这是个甜蜜点:支持绝大多数视觉模型的算子,ORT 几乎所有近期版本都能加载。

要是模型里有 Flash Attention 之类的新东西,可能要到 opset 18+。出问题先把 opset 降一档试试。

坑四:导出后的精度漂移

模型导出后跑一下,结果跟 PyTorch 原模型不完全一致——这几乎是必然的。差异来源:

  • 算子实现差异——同一个 LayerNorm,PyTorch 和 ONNX Runtime 内部实现的数值路径略有差异
  • fp32/fp16 转换——如果你导出时做了量化,差异更大
  • 预处理对齐——图像归一化的 mean/std 写错一位就完蛋

差异大概多少算正常?

  • 小数值差异(每个分量差 1e-4 ~ 1e-3):正常,业务上感知不到
  • 明显差异(差 0.01+):不正常,去查实现差异
  • 方向差异(embedding 相似度算出来差异巨大):肯定哪里错了,回去捋一遍

验证方法很简单——挑十几张图,分别用 PyTorch 和 ONNX 推理一遍,算两边 embedding 的余弦相似度。正常应该 ≥ 0.999。如果只有 0.99 甚至更低,赶紧排查。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np
import torch
import onnxruntime as ort

torch_model.eval()
session = ort.InferenceSession("mobileclip2_visual.onnx")

for i, img in enumerate(test_images):
with torch.no_grad():
f_torch = torch_model(img).numpy()
f_onnx = session.run(["features"], {"image": img.numpy()})[0]

cos = (f_torch * f_onnx).sum() / (np.linalg.norm(f_torch) * np.linalg.norm(f_onnx))
print(f"img {i}: cos={cos:.6f}")

坑五:预处理对齐

CLIP 类模型对预处理特别敏感。OpenAI CLIP 的官方预处理是:

1
BGR → RGB → resize → center crop → ToTensor → Normalize(CLIP_MEAN, CLIP_STD)

MobileCLIP 用的是 ImageNet 标准 mean/std:

1
2
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

导出 ONNX 时一定要确认 mean/std 用对。预处理用错了 mean/std,模型推理出来的 embedding 直接乱套——和 PyTorch 端的差异会非常大。

我们的策略是把预处理写在 Python 侧而非 ONNX 内部——ONNX 模型只接受标准化后的 tensor,预处理由 OpenCV 完成:

1
2
3
4
5
6
7
8
def preprocess(img_bgr: np.ndarray) -> np.ndarray:
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_rgb = letterbox_resize(img_rgb, 256)
img_rgb = center_crop(img_rgb, 256)
img_f = img_rgb.astype(np.float32) / 255.0
img_f = (img_f - MEAN) / STD
img_f = img_f.transpose(2, 0, 1) # HWC → CHW
return img_f[np.newaxis, :] # 加 batch 维

为啥不把预处理也塞进 ONNX?两个原因:

  1. debug 难度——预处理在 Python 里改起来快,进了 ONNX 改一次重新导出一次,反复横跳痛苦
  2. 算子兼容——某些预处理操作(letterbox、bilinear interpolate 等)在不同 EP 上行为可能不一致

预处理放外面,ONNX 只管纯模型推理,逻辑清爽。

坑六:模型存档后还能再优化一刀

导出来的 ONNX 模型不是最终态。用 onnxsim 或 ORT 自带的优化器跑一遍,体积和速度都能再压一截:

1
2
3
4
5
6
7
import onnx
from onnxsim import simplify

model = onnx.load("mobileclip2_visual.onnx")
simplified, ok = simplify(model)
assert ok
onnx.save(simplified, "mobileclip2_visual.opt.onnx")

onnxsim 主要做常量折叠、算子融合、去除冗余节点。我们项目里实测能把模型从 60+ MB 压到 50 MB 左右,推理速度也快 5%~10%。

但有个前提:优化后必须重新跑一遍精度验证。优化器偶尔会引入数值差异(罕见但发生过),不验证就直接上线,等于裸奔。

实战推理代码

导出 + 优化完之后,业务里调用就很标准了:

1
2
3
4
5
6
7
8
9
10
11
class CLIPModelFromONNX:
def __init__(self, model_path: Path):
providers = self._select_providers() # CoreML / DirectML / CPU
self.session = ort.InferenceSession(str(model_path), providers=providers)

def encode(self, img_bgr: np.ndarray) -> np.ndarray:
tensor = preprocess(img_bgr)
feat = self.session.run(["features"], {"image": tensor})[0]
# L2 归一化,方便后续做余弦相似度
feat = feat / np.linalg.norm(feat, axis=1, keepdims=True)
return feat

得到的 feature 直接喂进项目里的”视觉记忆库”(另一篇专门聊了这块),就能玩各种相似度检索了。

收个尾

把 PyTorch 模型导出成 ONNX,这事儿写在博客里看着很简单——一行 export 命令搞定。

真正动手才知道,导出只是个起点。子图剥离、动态维度、精度验证、预处理对齐、优化器、版本兼容性,每一项都能让你卡几个小时。

但回报很值——一个能跑在用户桌面、不用装 PyTorch、占用几十 MB、推理几十毫秒的 CLIP,是把视觉智能塞进普通应用的关键。

下次想给项目加个图像理解能力,别一上来就 import torch,先想想这事儿能不能用一个 ONNX 模型搞定。多数时候能。

欢迎关注我的其它发布渠道