ONNX Runtime 跨平台是真的能跨——同一份模型文件,Windows、macOS、Linux 都能跑。但加速后端就完全不通用了:Windows 用 DirectML,macOS 用 CoreML,Linux 大多数情况只剩 CPU 一条路。
要是代码里直接写死 CPUExecutionProvider——在 Mac 上能跑,但 Apple Silicon 的 ANE 完全没用上,浪费;在 Windows 上又用不到核显加速,亏。
下面是这个项目里管理 ORT 会话的几个关键点:自动探测、CoreML 缓存、动态维度覆盖、运行时回退。
优先级列表 按平台从高到低排个序,找到一个能用的就行:
1 2 3 4 5 6 7 import onnxruntime as ort PREFERRED_EXECUTION_PROVIDERS = ( "DmlExecutionProvider" , "CoreMLExecutionProvider" , "CPUExecutionProvider" , )
get_available_providers() 返回的是当前 ORT 编译时带的所有 provider 名字,但不代表全部都能真正初始化成功 ——这点很多人会忽略。所以要先用列表过滤一遍,再实际尝试创建会话:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 class DMLManager : @classmethod def get_session_providers (cls ) -> list [str ]: available_providers = set (ort.get_available_providers()) providers = [ provider for provider in cls._preferred_execution_providers if provider in available_providers ] if "CPUExecutionProvider" not in providers: providers.append("CPUExecutionProvider" ) return providers @classmethod def create_session (cls, model_path: str ) -> ort.InferenceSession: providers = cls.get_session_providers() so = cls._build_session_options() try : return ort.InferenceSession(model_path, sess_options=so, providers=providers) except Exception as exc: logger.warning(f"加速后端失败,回退到 CPU: {exc} " ) return ort.InferenceSession( model_path, sess_options=so, providers=["CPUExecutionProvider" ], )
CoreML 缓存 CoreML 第一次加载 ONNX 模型时,会做一次”编译”——把模型转成 CoreML 内部的表示。这一编译可不得了,一个 YOLO nano 模型在 M1 上能搞十几秒,大模型甚至跑好几分钟。用户那边看到的就是:软件启动条爬得跟便秘似的。
好在只要模型文件没变,编译产物就能复用。所以加了个 cache,用模型文件的哈希做 key:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 import hashlibfrom pathlib import Pathclass DMLManager : @classmethod def _build_model_cache_key (cls, model_path: str , provider_options: dict ) -> str : model_file = Path(model_path).resolve() digest = hashlib.sha256() for file_path in cls._iter_model_fingerprint_files(model_file): with open (file_path, "rb" ) as f: for chunk in iter (lambda : f.read(1024 * 1024 ), b"" ): digest.update(chunk) digest.update(repr (sorted (provider_options.items())).encode()) if cls._free_dimension_overrides: digest.update(repr (cls._free_dimension_overrides).encode()) return digest.hexdigest()[:24 ] @classmethod def _build_provider_config (cls, model_path: str ) -> list : providers = [] available = cls.get_session_providers() if "DmlExecutionProvider" in available: providers.append("DmlExecutionProvider" ) if "CoreMLExecutionProvider" in available: cache_root = cls._get_cache_root() cache_dir = cache_root / "coreml-cache" provider_options = { "ModelFormat" : "MLProgram" , "MLComputeUnits" : "ALL" , "RequireStaticInputShapes" : "0" , } model_cache_key = cls._build_model_cache_key(model_path, provider_options) cache_subdir = cache_dir / model_cache_key cache_subdir.mkdir(parents=True , exist_ok=True ) cls._prune_stale_coreml_cache(cache_dir, model_cache_key) providers.append(( "CoreMLExecutionProvider" , {**provider_options, "ModelCacheDirectory" : str (cache_subdir)}, )) if "CPUExecutionProvider" in available: providers.append("CPUExecutionProvider" ) return providers
哈希里得带上 provider_options 和 free_dimension_overrides——这两个变了,缓存的编译产物就废了。要是不带,会出现”明明改了配置但行为没变”的灵异事件。
缓存损坏要能重试 CoreML 缓存目录被意外搞坏的情况,是真的会发生:手动删了一半文件、磁盘满了写到一半、用户在版本之间倒退……最后都会变成 ORT 初始化时报错。
处理方式是:捕获异常 → 清掉缓存目录 → 再试一次 → 还失败就回 CPU:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 @classmethod def create_session_with_retry (cls, model_path: str ) -> ort.InferenceSession: providers = cls._build_provider_config(model_path) so = cls._build_session_options() try : return ort.InferenceSession(model_path, sess_options=so, providers=providers) except Exception as exc: coreml_cache_dir = cls._extract_coreml_cache_dir(providers) if coreml_cache_dir is not None : cls._clear_coreml_cache_dir(coreml_cache_dir) logger.warning(f"清除 CoreML 缓存后重试: {exc} " ) providers = cls._build_provider_config(model_path) try : return ort.InferenceSession(model_path, sess_options=so, providers=providers) except Exception as retry_exc: logger.warning(f"重试失败,回退到 CPU: {retry_exc} " ) return ort.InferenceSession( model_path, sess_options=so, providers=["CPUExecutionProvider" ], )
动态维度覆盖 CoreML 不支持动态 batch/H/W——这是个老问题。
YOLO 模型导出时通常会保留 batch 维度为符号”batch”或者”batch_size”,输入 H/W 也可能是动态的。直接拿去给 CoreML 加载,立马报错给你看。
办法是用 add_free_dimension_override_by_name,在 session option 上把符号维度钉死:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class DMLManager : _free_dimension_overrides = [ ("batch" , 1 ), ("batch_size" , 1 ), ] @classmethod def _build_session_options (cls, extra_overrides: dict = None ) -> ort.SessionOptions: so = ort.SessionOptions() so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL so.intra_op_num_threads = 1 so.inter_op_num_threads = 1 for name, value in cls._free_dimension_overrides: so.add_free_dimension_override_by_name(name, value) if extra_overrides: for name, value in extra_overrides.items(): so.add_free_dimension_override_by_name(name, value) return so @classmethod def _build_dim_overrides_for_model (cls, model_path: str ) -> dict : overrides = {} imgsz = cls._read_model_imgsz(model_path) if imgsz is not None : overrides["height" ] = imgsz[0 ] overrides["width" ] = imgsz[1 ] return overrides
H 和 W 的具体值不能写死——会从模型的 meta(YOLO 导出时会把 imgsz 写进 ONNX 的 metadata)里读。不同 YOLO 模型可能是 640、960、1280,统一一个常量肯定行不通。
运行时回退 加速后端在初始化时能跑起来,不代表运行时不会挂——见过 DirectML 在显存吃紧时直接抛 RuntimeException 的情况,那一刻服务整个就停摆了。
这种时候不应该让推理失败,得把会话静默换成 CPU 重跑一次:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 class DMLManager : _session_replacements: dict [int , ort.InferenceSession] = {} @classmethod def run (cls, session: ort.InferenceSession, feeds: dict ) -> list : with cls._lock: active_session = cls._resolve_session(session) try : return active_session.run(None , feeds) except Exception as exc: fallback_session = cls._fallback_to_cpu_session(active_session, exc) if fallback_session is None : raise return fallback_session.run(None , feeds) @classmethod def _fallback_to_cpu_session (cls, session, exc ) -> ort.InferenceSession | None : session_id = id (session) model_path = cls._session_model_paths.get(session_id) provider_names = cls._session_provider_names.get(session_id) if not model_path or not provider_names: return None if tuple (provider_names) == ("CPUExecutionProvider" ,): return None logger.warning(f"ONNX 推理失败,回退到 CPU: {exc} " ) fallback_session = cls._create_cpu_session(model_path) cls._session_replacements[session_id] = fallback_session return fallback_session @classmethod def _resolve_session (cls, session ) -> ort.InferenceSession: current = session visited = set () while True : current_id = id (current) replacement = cls._session_replacements.get(current_id) if replacement is None or current_id in visited: return current visited.add(current_id) current = replacement
_session_replacements 用 id(session) 做 key,因为 ORT 的 session 本身不可哈希。visited 这个 set 是防 replacement 链出现环——理论上不该发生,但写一行不亏。
回退之后,这个会话后续所有请求都自动走 CPU,不再尝试加速后端。这个语义是故意的——一次失败可能是偶发,但加速后端一旦不稳,反复尝试只会拖死业务。要恢复加速?只能重启进程。