ONNX Runtime 的 DirectML / CoreML / CPU 三段回退

ONNX Runtime 跨平台是真的能跨——同一份模型文件,Windows、macOS、Linux 都能跑。但加速后端就完全不通用了:Windows 用 DirectML,macOS 用 CoreML,Linux 大多数情况只剩 CPU 一条路。

要是代码里直接写死 CPUExecutionProvider——在 Mac 上能跑,但 Apple Silicon 的 ANE 完全没用上,浪费;在 Windows 上又用不到核显加速,亏。

下面是这个项目里管理 ORT 会话的几个关键点:自动探测、CoreML 缓存、动态维度覆盖、运行时回退。

优先级列表

按平台从高到低排个序,找到一个能用的就行:

1
2
3
4
5
6
7
import onnxruntime as ort

PREFERRED_EXECUTION_PROVIDERS = (
"DmlExecutionProvider", # Windows
"CoreMLExecutionProvider", # macOS
"CPUExecutionProvider", # 兜底
)

get_available_providers() 返回的是当前 ORT 编译时带的所有 provider 名字,但不代表全部都能真正初始化成功——这点很多人会忽略。所以要先用列表过滤一遍,再实际尝试创建会话:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class DMLManager:
@classmethod
def get_session_providers(cls) -> list[str]:
available_providers = set(ort.get_available_providers())
providers = [
provider
for provider in cls._preferred_execution_providers
if provider in available_providers
]
if "CPUExecutionProvider" not in providers:
providers.append("CPUExecutionProvider")
return providers

@classmethod
def create_session(cls, model_path: str) -> ort.InferenceSession:
providers = cls.get_session_providers()
so = cls._build_session_options()
try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as exc:
logger.warning(f"加速后端失败,回退到 CPU: {exc}")
return ort.InferenceSession(
model_path,
sess_options=so,
providers=["CPUExecutionProvider"],
)

CoreML 缓存

CoreML 第一次加载 ONNX 模型时,会做一次”编译”——把模型转成 CoreML 内部的表示。这一编译可不得了,一个 YOLO nano 模型在 M1 上能搞十几秒,大模型甚至跑好几分钟。用户那边看到的就是:软件启动条爬得跟便秘似的。

好在只要模型文件没变,编译产物就能复用。所以加了个 cache,用模型文件的哈希做 key:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import hashlib
from pathlib import Path

class DMLManager:
@classmethod
def _build_model_cache_key(cls, model_path: str, provider_options: dict) -> str:
model_file = Path(model_path).resolve()
digest = hashlib.sha256()

# 模型本体和 .data 外部权重都要算进去
for file_path in cls._iter_model_fingerprint_files(model_file):
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
digest.update(chunk)

digest.update(repr(sorted(provider_options.items())).encode())

if cls._free_dimension_overrides:
digest.update(repr(cls._free_dimension_overrides).encode())

return digest.hexdigest()[:24]

@classmethod
def _build_provider_config(cls, model_path: str) -> list:
providers = []
available = cls.get_session_providers()

if "DmlExecutionProvider" in available:
providers.append("DmlExecutionProvider")

if "CoreMLExecutionProvider" in available:
cache_root = cls._get_cache_root()
cache_dir = cache_root / "coreml-cache"

provider_options = {
"ModelFormat": "MLProgram",
"MLComputeUnits": "ALL",
"RequireStaticInputShapes": "0",
}
model_cache_key = cls._build_model_cache_key(model_path, provider_options)
cache_subdir = cache_dir / model_cache_key
cache_subdir.mkdir(parents=True, exist_ok=True)

# 顺手清旧版本的缓存
cls._prune_stale_coreml_cache(cache_dir, model_cache_key)

providers.append((
"CoreMLExecutionProvider",
{**provider_options, "ModelCacheDirectory": str(cache_subdir)},
))

if "CPUExecutionProvider" in available:
providers.append("CPUExecutionProvider")

return providers

哈希里得带上 provider_optionsfree_dimension_overrides——这两个变了,缓存的编译产物就废了。要是不带,会出现”明明改了配置但行为没变”的灵异事件。

缓存损坏要能重试

CoreML 缓存目录被意外搞坏的情况,是真的会发生:手动删了一半文件、磁盘满了写到一半、用户在版本之间倒退……最后都会变成 ORT 初始化时报错。

处理方式是:捕获异常 → 清掉缓存目录 → 再试一次 → 还失败就回 CPU:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@classmethod
def create_session_with_retry(cls, model_path: str) -> ort.InferenceSession:
providers = cls._build_provider_config(model_path)
so = cls._build_session_options()

try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as exc:
coreml_cache_dir = cls._extract_coreml_cache_dir(providers)
if coreml_cache_dir is not None:
cls._clear_coreml_cache_dir(coreml_cache_dir)
logger.warning(f"清除 CoreML 缓存后重试: {exc}")
providers = cls._build_provider_config(model_path)
try:
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
except Exception as retry_exc:
logger.warning(f"重试失败,回退到 CPU: {retry_exc}")

return ort.InferenceSession(
model_path,
sess_options=so,
providers=["CPUExecutionProvider"],
)

动态维度覆盖

CoreML 不支持动态 batch/H/W——这是个老问题。

YOLO 模型导出时通常会保留 batch 维度为符号”batch”或者”batch_size”,输入 H/W 也可能是动态的。直接拿去给 CoreML 加载,立马报错给你看。

办法是用 add_free_dimension_override_by_name,在 session option 上把符号维度钉死:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DMLManager:
_free_dimension_overrides = [
("batch", 1),
("batch_size", 1),
]

@classmethod
def _build_session_options(cls, extra_overrides: dict = None) -> ort.SessionOptions:
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
so.intra_op_num_threads = 1
so.inter_op_num_threads = 1

for name, value in cls._free_dimension_overrides:
so.add_free_dimension_override_by_name(name, value)

if extra_overrides:
for name, value in extra_overrides.items():
so.add_free_dimension_override_by_name(name, value)
return so

@classmethod
def _build_dim_overrides_for_model(cls, model_path: str) -> dict:
overrides = {}
imgsz = cls._read_model_imgsz(model_path)
if imgsz is not None:
overrides["height"] = imgsz[0]
overrides["width"] = imgsz[1]
return overrides

H 和 W 的具体值不能写死——会从模型的 meta(YOLO 导出时会把 imgsz 写进 ONNX 的 metadata)里读。不同 YOLO 模型可能是 640、960、1280,统一一个常量肯定行不通。

运行时回退

加速后端在初始化时能跑起来,不代表运行时不会挂——见过 DirectML 在显存吃紧时直接抛 RuntimeException 的情况,那一刻服务整个就停摆了。

这种时候不应该让推理失败,得把会话静默换成 CPU 重跑一次:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class DMLManager:
_session_replacements: dict[int, ort.InferenceSession] = {}

@classmethod
def run(cls, session: ort.InferenceSession, feeds: dict) -> list:
with cls._lock:
active_session = cls._resolve_session(session)
try:
return active_session.run(None, feeds)
except Exception as exc:
fallback_session = cls._fallback_to_cpu_session(active_session, exc)
if fallback_session is None:
raise
return fallback_session.run(None, feeds)

@classmethod
def _fallback_to_cpu_session(cls, session, exc) -> ort.InferenceSession | None:
session_id = id(session)
model_path = cls._session_model_paths.get(session_id)
provider_names = cls._session_provider_names.get(session_id)

if not model_path or not provider_names:
return None
if tuple(provider_names) == ("CPUExecutionProvider",):
return None # 已经是 CPU 了

logger.warning(f"ONNX 推理失败,回退到 CPU: {exc}")

fallback_session = cls._create_cpu_session(model_path)
# 同一会话以后再来,直接走 CPU
cls._session_replacements[session_id] = fallback_session
return fallback_session

@classmethod
def _resolve_session(cls, session) -> ort.InferenceSession:
current = session
visited = set()
while True:
current_id = id(current)
replacement = cls._session_replacements.get(current_id)
if replacement is None or current_id in visited:
return current
visited.add(current_id)
current = replacement

_session_replacementsid(session) 做 key,因为 ORT 的 session 本身不可哈希。visited 这个 set 是防 replacement 链出现环——理论上不该发生,但写一行不亏。

回退之后,这个会话后续所有请求都自动走 CPU,不再尝试加速后端。这个语义是故意的——一次失败可能是偶发,但加速后端一旦不稳,反复尝试只会拖死业务。要恢复加速?只能重启进程。

欢迎关注我的其它发布渠道