MXNet中MTCNN推理时动态输入与batch size相关问题咨询
刚好之前在部署MXNet版MTCNN时踩过这些坑,给你逐个梳理解决方案:
1. 可变输入尺寸/批次的冗余计算优化方案
你现在用超大固定shape绑定的方式确实能跑,但冗余计算太浪费资源了,其实MXNet的Module本身就支持动态调整输入形状,完全不用搞这种妥协方案:
PNet处理可变尺寸图像:不要提前绑定最大尺寸的shape,而是在每次处理不同尺寸的图像时,先调用
reshape方法设置当前图像的实际shape,再执行推理。如果是第一次推理,Module还未绑定,也可以直接根据实际shape完成绑定(记得加上后面提到的partial_shaping=True)。示例代码:# 假设当前待检测图像的张量形状是(1, 3, h, w) self.PNets.reshape(data_shapes=[('data', (1, 3, h, w))]) # 喂入数据并执行推理 self.PNets.forward(mx.io.DataBatch(data=[img_tensor]))RNet/ONet处理可变批次:每次根据PNet输出的候选人脸数量(比如N个),动态调整batch size为N,同样用
reshape修改输入shape:# 假设PNet输出了N个24x24的候选人脸张量,形状为(N, 3, 24, 24) self.RNet.reshape(data_shapes=[('data', (N, 3, 24, 24))]) # 喂入候选人脸张量后执行推理
这样每次都只分配实际需要的计算资源,彻底消除冗余计算。
2. reshape时的AssertionError解决方法
这个错误我当初也遇到过,本质是MXNet静态图Module的严格形状检查机制,和MTCNN网络的训练/推理模式不匹配:MTCNN训练时包含了label相关的节点(比如prob1_label),但你推理时设置了label_names=None,导致reshape时这些label节点的形状变化无法被验证。
解决方法非常直接,按照错误提示的建议,在bind Module的时候加上partial_shaping=True参数,允许部分输入输出的形状动态调整:
以PNet为例,修改后的绑定代码:
sym, arg_params, aux_params = mx.model.load_checkpoint('det1', 0) self.PNets = mx.mod.Module(symbol=sym, context=ctx, label_names=None) self.PNets.bind( data_shapes=[('data', (1, 3, max_img_w, max_img_h))], for_training=False, partial_shaping=True # 关键参数,允许动态调整形状 ) self.PNets.set_params(arg_params, aux_params)
RNet和ONet的绑定代码也同样加上这个参数即可。后续调用reshape调整输入形状时,就不会再触发这个断言错误了。
如果想要彻底清除这些无用的label节点,也可以修改加载后的symbol,去掉所有和label相关的输出,但设置partial_shaping=True是更简单快捷的方式,完全满足推理需求。
3. FeedForward vs Module的内存差异与弃用原因
内存差异的根源
mx.model.FeedForward.load加载模型时,只是加载了symbol结构和参数权重,并没有创建实际的计算图(executor),只有当你第一次调用predict方法时,才会根据输入形状创建executor并分配显存/内存,所以加载后初期内存占用极低。mx.mod.Module在调用bind方法时,就会根据指定的data_shapes创建executor并分配显存/内存,所以绑定后立刻会有内存占用。
FeedForward被弃用的原因
FeedForward是MXNet早期的高层API,设计比较单一,仅支持简单的单输入单输出的训练和推理场景,灵活性很差——比如不支持多设备并行、自定义训练流程、多输入输出网络等。而Module是MXNet后来推出的更模块化、更灵活的API,支持复杂网络结构、多设备部署、自定义回调等功能,完全覆盖了FeedForward的所有场景,并且扩展性更强。MXNet官方为了统一API生态,逐步弃用了FeedForward,推荐使用Module或者更现代的Gluon动态图API。
如果你想要类似FeedForward的延迟初始化效果,可以把Module的bind操作推迟到第一次推理时,也就是根据实际的输入形状来bind,而不是在初始化阶段就绑定固定shape,这样也能避免提前占用内存。
内容的提问来源于stack exchange,提问作者搴瑰畤缈�,mxnet




