ONNX Runtime CPU推理优化
介绍 ONNX Runtime 常用推理参数。
引言
平时推理用的最多是 ONNX Runtime,推理引擎的合适调配对推理性能有着至关重要的影响。但是有关于 ONNX Runtime 参数设置的资料却散落在各个地方,不能形成有效的指导意见。因此,决定在这一篇文章中来梳理一下相关的设置。
以下参数都是来自 SessionOptions 。相关测试代码可以前往 AI Studio 查看。
欢迎补充和指出不足之处。
推荐常用设置
| import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.log_severity_level = 4
sess_options.enable_cpu_mem_arena = False
# 其他参数,采用默认即可
|
参数介绍
作用:启用内存模式优化,减少碎片,默认为 true。Enable the memory pattern optimization. Default is true。
作用:启用内存重用,避免重复分配,默认为 true。Enable the memory reuse optimization. Default is true.
作用:启用 CPU 上的 memory arena。Arena 可能会为将来预先申请很多内存。如果不想使用它,可以设置为 enable_cpu_mem_area=False,默认是 True
个人基于下面模型测试结论:建议关闭。开启之后,程序占用内存会剧增(5618.3M >> 5.3M),且持续占用,无法释放。推理时间提升约 13%。
由于不同模型存在差异,建议用户根据自身模型的实际测试效果,评估该参数开启前后的效果,再决定开启与否。以下是我这里测试的环境、代码和结果。
测试环境:
- Python: 3.7.13
- ONNX Runtime: 1.14.1
测试代码(来自 issue 11627,enable_cpu_memory_area_example.zip)
| # pip install onnxruntime==1.14.1
# pip install memory_profiler
import numpy as np
import onnxruntime as ort
from memory_profiler import profile
@profile
def onnx_prediction(model_path, input_data):
ort_sess = ort.InferenceSession(model_path, sess_options=sess_options)
preds = ort_sess.run(output_names=["predictions"],
input_feed={"input_1": input_data})[0]
return preds
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
input_data = np.load('enable_cpu_memory_area_example/input.npy')
print(f'input_data shape: {input_data.shape}')
model_path = 'enable_cpu_memory_area_example/model.onnx'
onnx_prediction(model_path, input_data)
|
Windows | macOS | Linux 测试情况
| enable_cpu_mem_arena=True |
|---|
| (demo) PS G:> python .\test_enable_cpu_mem_arena.py
enable_cpu_mem_arena: True
input_data shape: (32, 200, 200, 1)
Filename: .\test_enable_cpu_mem_arena.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
7 69.1 MiB 69.1 MiB 1 @profile
8 def onnx_prediction(model_path, input_data):
9 77.2 MiB 8.1 MiB 1 ort_sess = ort.InferenceSession(model_path, sess_options=sess_options)
10 77.2 MiB 0.0 MiB 1 preds = ort_sess.run(output_names=["predictions"],
11 5695.5 MiB 5618.3 MiB 1 input_feed={"input_1": input_data})[0]
12 5695.5 MiB 0.0 MiB 1 return preds
|
| enable_cpu_mem_arena=False |
|---|
| (demo) PS G:> python .\test_enable_cpu_mem_arena.py
enable_cpu_mem_arena: False
input_data shape: (32, 200, 200, 1)
Filename: .\test_enable_cpu_mem_arena.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
7 69.1 MiB 69.1 MiB 1 @profile
8 def onnx_prediction(model_path, input_data):
9 76.9 MiB 7.8 MiB 1 ort_sess = ort.InferenceSession(model_path, sess_options=sess_options)
10 76.9 MiB 0.0 MiB 1 preds = ort_sess.run(output_names=["predictions"],
11 82.1 MiB 5.3 MiB 1 input_feed={"input_1": input_data})[0]
12 82.1 MiB 0.0 MiB 1 return preds
|
enable_profiling
作用:开启这个参数,在推理时,会生成一个类似 onnxruntime_profile__2023-05-07_09-02-15.json 的日志文件,包含详细的性能数据(线程、每个运算符的延迟等)。建议开启。
| import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.enable_profiling = True
|
execution_mode
设置运行模型的模式,包括 rt.ExecutionMode.ORT_SEQUENTIAL 和 rt.ExecutionMode.ORT_PARALLEL。一个序列执行,一个并行。默认是序列执行。
通常来说,当一个模型中有许多分支时,可以设置该参数为ORT_PARALLEL来达到更好的表现
当设置 sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL 时,可以设置 sess_options.inter_op_num_threads 来控制使用线程的数量,来并行化执行(模型中各个节点之间)
inter_op_num_threads
设置并行化执行图(跨节点)时,使用的线程数。默认是 0,交由 ONNX Runtime 自行决定。
| import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.inter_op_num_threads = 2
|
intra_op_num_threads
设置并行化执行图(内部节点)时,使用的线程数。默认是 0,交由 ONNX Runtime 自行决定,一般会选择使用设备上所有的核。
⚠️ 这个值并不是越大越好,具体参考 AI Studio 中的消融实验。
| import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.intra_op_num_threads = 2
|
运行图时,对图中算子的优化水平。默认是开启全部算子的优化。建议采用默认值即可。可选的枚举值有:ORT_DISABLE_ALL | ORT_ENABLE_BASIC | ORT_ENABLE_EXTENDED | ORT_ENABLE_ALL
| import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
|
参考资料