Cannot import keras models via onnx

Hi,

I am trying to import a model I trained in keras into C++ TensorRT using onnx as an intermediate format. The nvonnxparser::IParser always fails on converted keras models. An onnx file downloaded from the onnx model zoo is parsed just fine.

I have converted different keras models using different versions of several libraries (keras2onnx, tf2onnx, onnxmltools), and with different opsets. I get different error messages, in particular assertions on onnx_padding and failed checks on isValidDims(dims).
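
For reference, the conversion calls I use look roughly like this (the opset number, file names and tensor names below are just examples, not the exact values from every run):

import keras2onnx
import onnxmltools
from keras.applications.resnet50 import ResNet50

model = ResNet50()

# keras2onnx with an explicit target opset
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=9)
keras2onnx.save_model(onnx_model, "resnet50_keras2onnx.onnx")

# onnxmltools, same idea
onnx_model = onnxmltools.convert_keras(model, target_opset=9)
onnxmltools.utils.save_model(onnx_model, "resnet50_onnxmltools.onnx")

# tf2onnx works on a frozen graph instead, e.g.:
#   python3 -m tf2onnx.convert --input frozen_resnet50.pb --inputs input_1:0 \
#       --outputs fc1000/Softmax:0 --opset 9 --output resnet50_tf2onnx.onnx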

I have been able to run inference on the generated onnx files with the python onnxruntime library (CPU), so the generated onnx is not completely broken (I also verified this on simple models), but it may not be fully conformant.
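
The onnxruntime check is essentially the following (the model file name and the 224x224x3 input shape are placeholders for whatever model was exported):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("resnet50_keras2onnx.onnx")
input_name = sess.get_inputs()[0].name
x = np.random.rand(1, 224, 224, 3).astype(np.float32)  # keras/TF is channels-last
outputs = sess.run(None, {input_name: x})
print(outputs[0].shape)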

I reproduced the problem in Python (more convenient; I assume a fix there translates directly to C++).

Any ideas on how to get my keras model into C++/TensorRT are welcome. But in particular: what is wrong with onnx parsing?

Platform
Ubuntu 18.04
GeForce RTX 2070
Driver Version: 418.67
CUDA Version: 10.1
cuDNN Version: 7.3.1
Python 3.6.8
tensorflow 1.14.0
tensorrt 5.0.2.6

Note: more version info in the output of the script below

Script to reproduce
In the script below I try to parse 3 onnx models with TensorRT:

  1. Model from the onnx model zoo -- parses just fine
  2. Pretrained ResNet50 that comes with keras -- converts to onnx, but TensorRT fails to parse it
  3. Tiny untrained ConvNet -- converts to onnx, but TensorRT fails to parse it

import os
import sys
import traceback
import urllib.request

import keras
from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.layers import Input, Conv2D
import tensorflow as tf
import keras2onnx
import tensorrt as trt

def GiB(val):
    return val * 1 << 30

def parse_onnx(fpath):
    with open(fpath, "rb") as f:
        return parse_serialized_onnx(f.read())

def parse_serialized_onnx(serialized):
    """Feed a serialized onnx model to TensorRT's OnnxParser and print any parser errors."""
    print(f"    Parsing serialized onnx...")
    try:
        builder = trt.Builder(TRT_LOGGER)
        builder.max_workspace_size = GiB(1)
        network = builder.create_network()
        parser = trt.OnnxParser(network, TRT_LOGGER)
        res = parser.parse(serialized)
        print(f"        parsing returned: {res}")
        for i in range(parser.num_errors):
            print(f"        Error #{i+1}/{parser.num_errors}")
            err = parser.get_error(i)
            print(f"            {err.desc()}")
        return
    except Exception:
        traceback.print_exc()

def t1():
    print(f"{'-'*80}")
    print(f"| t1(): Onnx downloaded from model zoo | Works fine")
    print(f"{'-'*80}")
    onnxUrl = "https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx"
    bn = os.path.basename(onnxUrl)
    if not os.path.exists(bn):
        url = urllib.request.urlopen(onnxUrl)
        with open(bn, "wb") as f:
            f.write(url.read())
    
    parse_onnx(bn)
    return
    
def t2():
    print(f"{'-'*80}")
    print(f"| t2(): Convert keras application to onnx using keras2onnx | Problem")
    print(f"{'-'*80}")
    model = ResNet50()
    onnxModel = keras2onnx.convert_keras(model, model.name)

    serialized = onnxModel.SerializeToString()
    parse_serialized_onnx(serialized)
    return
    


def t3():
    """A very simple ConvNet"""
    print(f"{'-'*80}")
    print(f"| t3(): Create extremely simple ConvNet -- no training | Problem")
    print(f"{'-'*80}")
    inputs = Input((None, None, 1), name="inputs")
    outputs = Conv2D(1, (1,1), activation="sigmoid")(inputs)
    model = Model(inputs=inputs, outputs=outputs, name="ConvModel")

    onnxModel = keras2onnx.convert_keras(model, model.name)

    serialized = onnxModel.SerializeToString()
    parse_serialized_onnx(serialized)
    return
    

os.system("cat /usr/local/cuda/version.txt")
os.system('cat /usr/include/cudnn.h  | grep "#define CUDNN_MAJOR\|#define CUDNN_MINOR\|#define CUDNN_PATCHLEVEL"')
print(f"python: {sys.version}")
print(f"keras: {keras.__version__}")
print(f"tensorflow: {tf.__version__}")
print(f"tensorrt: {trt.__version__}")
print(f"keras2onnx: {keras2onnx.__version__}\n\n")

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

t1()
t2()
t3()

Output on my machine

$ python3 kerasOnnxProblems.py 
Using TensorFlow backend.
CUDA Version 10.0.130
#define CUDNN_MAJOR 7
#define CUDNN_MINOR 3
#define CUDNN_PATCHLEVEL 1
python: 3.6.8 (default, Jan 14 2019, 11:02:34) 
[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]
keras: 2.2.4
tensorflow: 1.14.0
tensorrt: 5.0.2.6
keras2onnx: 1.5.0


--------------------------------------------------------------------------------
| t1(): Onnx downloaded from model zoo | Works fine
--------------------------------------------------------------------------------
    Parsing serialized onnx...
        parsing returned: True
--------------------------------------------------------------------------------
| t2(): Convert keras application to onnx using keras2onnx | Problem
--------------------------------------------------------------------------------
WARNING: Logging before flag parsing goes to stderr.
W0625 14:56:34.103247 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0625 14:56:34.111482 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0625 14:56:34.114243 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4185: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

W0625 14:56:34.127869 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0625 14:56:34.128011 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

2019-06-25 14:56:34.128182: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-06-25 14:56:34.131916: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcuda.so.1
2019-06-25 14:56:34.132192: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.132783: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0xe660be0 executing computations on platform CUDA. Devices:
2019-06-25 14:56:34.132794: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): GeForce RTX 2070, Compute Capability 7.5
2019-06-25 14:56:34.152747: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3192000000 Hz
2019-06-25 14:56:34.153649: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0xeccbfb0 executing computations on platform Host. Devices:
2019-06-25 14:56:34.153661: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
2019-06-25 14:56:34.153790: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.154283: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties: 
name: GeForce RTX 2070 major: 7 minor: 5 memoryClockRate(GHz): 1.62
pciBusID: 0000:01:00.0
2019-06-25 14:56:34.154309: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2019-06-25 14:56:34.154328: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
2019-06-25 14:56:34.155044: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcufft.so.10.0
2019-06-25 14:56:34.155222: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcurand.so.10.0
2019-06-25 14:56:34.156193: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusolver.so.10.0
2019-06-25 14:56:34.156973: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcusparse.so.10.0
2019-06-25 14:56:34.156998: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudnn.so.7
2019-06-25 14:56:34.157044: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.157572: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.158051: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
2019-06-25 14:56:34.158070: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2019-06-25 14:56:34.158088: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1181] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-06-25 14:56:34.158096: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1187]      0 
2019-06-25 14:56:34.158101: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1200] 0:   N 
2019-06-25 14:56:34.158211: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.158714: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1005] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-06-25 14:56:34.159203: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1326] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 6822 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2070, pci bus id: 0000:01:00.0, compute capability: 7.5)
W0625 14:56:34.556359 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

W0625 14:56:34.594521 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

W0625 14:56:39.179351 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras2onnx/common/utils.py:38: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W0625 14:56:39.179473 139773652903744 deprecation_wrapper.py:119] From /home/jurjen/.local/lib/python3.6/site-packages/keras2onnx/common/utils.py:38: The name tf.logging.WARN is deprecated. Please use tf.compat.v1.logging.WARN instead.

W0625 14:56:39.230884 139773652903744 deprecation.py:323] From /home/jurjen/.local/lib/python3.6/site-packages/keras2onnx/subgraph.py:124: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
    Parsing serialized onnx...
        parsing returned: False
        Error #1/1
            Assertion failed: onnx_padding[0] == 0 && onnx_padding[1] == 0 && onnx_padding[4] == 0 && onnx_padding[5] == 0
--------------------------------------------------------------------------------
| t3(): Create extremely simple ConvNet -- no training | Problem
--------------------------------------------------------------------------------
    Parsing serialized onnx...
[TensorRT] ERROR: Parameter check failed at: ../builder/Network.cpp::addInput::406, condition: isValidDims(dims)
        parsing returned: False
        Error #1/1
            Assertion failed: *tensor = importer_ctx->network()->addInput( input.name().c_str(), trt_dtype, trt_dims)
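
In case it helps with diagnosing the two failures above (the onnx_padding assertion in t2() and the isValidDims check in t3()), the exported graphs can be inspected with the onnx python package roughly like this (the file name is a placeholder):

import onnx

m = onnx.load("resnet50_keras2onnx.onnx")
onnx.checker.check_model(m)  # basic conformance check

# Declared input shapes; presumably the (None, None, 1) input of t3() exports
# as unknown dims, which TensorRT's addInput() then rejects.
for inp in m.graph.input:
    dims = [d.dim_value if d.HasField("dim_value") else "?"
            for d in inp.type.tensor_type.shape.dim]
    print("input:", inp.name, dims)

# Pads attribute of every Pad node; the onnx_padding assertion in t2() points
# at padding values, so these seem worth a look.
for node in m.graph.node:
    if node.op_type == "Pad":
        pads = [list(a.ints) for a in node.attribute if a.name == "pads"]
        print("Pad node:", node.name, pads)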