Răsfoiți Sursa

Modular onnx script

Lutz Roeder 7 ani în urmă
părinte
comite
9fac78fb22
4 a modificat fișierele cu 120 adăugiri și 79 ștergeri
  1. 0 75
      tools/metadata/onnx-update
  2. 2 2
      tools/metadata/update
  3. 78 0
      tools/onnx
  4. 40 2
      tools/onnx-script.py

+ 0 - 75
tools/metadata/onnx-update

@@ -1,75 +0,0 @@
-#!/bin/bash
-
-set -e
-
-if [ "$#" == 0 ]; then
-    __sync=true
-    __build=true
-    __update=true
-else
-    while test $# -gt 0
-    do
-        case "$1" in
-            sync) __sync=true;;
-            build) __build=true;;
-            update) __update=true;;
-        esac
-        shift
-    done
-fi
-
-root=$(cd $(dirname ${0})/../..; pwd)
-build=${root}/build
-node_modules=${root}/node_modules
-src=${root}/src
-third_party=${root}/third_party
-tools=${root}/tools
-
-python=${python:-python}
-pip=${pip:-pip}
-
-identifier=onnx
-
-if [ ${__sync} ]; then
-    repository=https://github.com/onnx/${identifier}.git
-    mkdir -p ${third_party}
-    if [ -d "${third_party}/${identifier}" ]; then
-        git -C "${third_party}/${identifier}" fetch -p
-        git -C "${third_party}/${identifier}" reset --hard origin/master
-    else
-        echo "Clone ${repository}..."
-        git -C "${third_party}" clone --recursive ${repository}
-    fi
-fi
-
-echo "Install ONNX"
-virtualenv=${build}/virtualenv/${identifier}
-if [ ${__build} ]; then
-    virtualenv -p ${python} ${virtualenv}
-fi
-if [ -f ${virtualenv}/bin/activate ]; then
-    source ${virtualenv}/bin/activate
-fi
-if [ ${__build} ]; then
-    export ONNX_ML=1
-    export ONNX_NAMESPACE=onnx
-    ${pip} install ${third_party}/${identifier}
-fi
-
-if [ ${__update} ]; then
-    echo "Generate 'onnx-metadata.json'"
-    pushd ${tools}/metadata > /dev/null
-    ${python} onnx-metadata.py
-    popd > /dev/null
-fi
-
-if [ -f ${virtualenv}/bin/activate ]; then
-    deactivate
-fi
-
-if [ ${__update} ]; then
-    echo "Generate 'onnx.js'"
-    ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r onnx -o ${src}/onnx-proto.js ${third_party}/${identifier}/onnx/onnx-ml.proto ${third_party}/${identifier}/onnx/onnx-operators-ml.proto
-    node ${tools}/metadata/update_pbjs.js array ${src}/onnx-proto.js float_data float 1
-    node ${tools}/metadata/update_pbjs.js array ${src}/onnx-proto.js double_data double 1
-fi

+ 2 - 2
tools/metadata/update

@@ -26,10 +26,10 @@ echo "Update MXNet"
 ${tools}/metadata/mxnet-update
 
 echo "Update ONNX"
-${tools}/metadata/onnx-update
+${tools}/onnx sync build schema metadata
 
 echo "Update TensorFlow Lite"
-${tools}/metadata/tflite-update
+${tools}/tflite-update
 
 echo "Update scikit-learn"
 ${tools}/metadata/sklearn-update

+ 78 - 0
tools/onnx

@@ -0,0 +1,78 @@
+#!/bin/bash
+
+set -e
+
+root=$(cd $(dirname ${0})/..; pwd)
+build=${root}/build
+node_modules=${root}/node_modules
+src=${root}/src
+third_party=${root}/third_party
+tools=${root}/tools
+
+identifier=onnx
+virtualenv=${build}/virtualenv/${identifier}
+
+python=${python:-python}
+pip=${pip:-pip}
+
+git_sync () {
+    mkdir -p "${third_party}"
+    if [ -d "${third_party}/${1}" ]; then
+        git -C "${third_party}/${1}" fetch -p --quiet
+        git -C "${third_party}/${1}" reset --quiet --hard origin/master
+    else
+        echo "Clone ${2}..."
+        git -C "${third_party}" clone --recursive ${2}
+    fi
+}
+
+sync() {
+    git_sync onnx https://github.com/onnx/onnx.git
+    git_sync onnxmltools https://github.com/onnx/onnxmltools.git
+}
+
+build() {
+    echo "Build ONNX"
+    virtualenv --quiet -p ${python} ${virtualenv}
+    source ${virtualenv}/bin/activate
+    export ONNX_ML=1
+    export ONNX_NAMESPACE=onnx
+    ${pip} install --quiet ${third_party}/onnx
+    deactivate
+}
+
+schema() {
+    echo "Generate 'onnx-proto.js'"
+    source ${virtualenv}/bin/activate
+    ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r onnx -o ${src}/onnx-proto.js ${third_party}/${identifier}/onnx/onnx-ml.proto ${third_party}/${identifier}/onnx/onnx-operators-ml.proto
+    node ${tools}/metadata/update_pbjs.js array ${src}/onnx-proto.js float_data float 1
+    node ${tools}/metadata/update_pbjs.js array ${src}/onnx-proto.js double_data double 1
+    deactivate
+}
+
+metadata() {
+    echo "Generate 'onnx-metadata.json'"
+    source ${virtualenv}/bin/activate
+    pushd ${tools} > /dev/null
+    ${python} onnx-script.py metadata
+    popd > /dev/null
+    deactivate
+}
+
+convert() {
+    source ${virtualenv}/bin/activate
+    ${pip} install --quiet ${third_party}/onnxmltools
+
+    deactivate
+}
+
+while [ "$#" != 0 ]; do
+    command="$1"
+    shift
+    case "${command}" in
+        "sync") sync;;
+        "build") build;;
+        "schema") schema;;
+        "metadata") metadata;;
+    esac
+done

+ 40 - 2
tools/metadata/onnx-metadata.py → tools/onnx-script.py

@@ -1,6 +1,9 @@
 
 from __future__ import unicode_literals
 
+import onnx
+print(onnx.__file__)
+
 import json
 import io
 import sys
@@ -190,7 +193,42 @@ def generate_json(schemas, json_file):
             fout.write(line)
             fout.write('\n')
 
-if __name__ == '__main__':
+def metadata():
     schemas = defs.get_all_schemas_with_history()
     schemas = sorted(schemas, key=lambda schema: schema.name)
-    generate_json(schemas, '../../src/onnx-metadata.json')
+    generate_json(schemas, '../src/onnx-metadata.json')
+
+def convert():
+    file = sys.argv[2];
+    base, extension = os.path.splitext(file)
+    if extension == '.mlmodel':
+        import coremltools
+        import onnxmltools
+        coreml_model = coremltools.utils.load_spec(file)
+        onnx_model = onnxmltools.convert.convert_coreml(coreml_model)
+        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
+    elif extension == '.h5':
+        import keras
+        import onnxmltools
+        keras_model = keras.models.load_model(file)
+        onnx_model = onnxmltools.convert.convert_keras(keras_model)
+        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
+    elif extension == '.pkl':
+        from sklearn.externals import joblib
+        import onnxmltools
+        sklearn_model = joblib.load(file)
+        onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model)
+        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
+    base, extension = os.path.splitext(file)
+    if extension == '.onnx':
+        import onnx
+        from google.protobuf import text_format
+        onnx_model = onnx.load(file)
+        text = text_format.MessageToString(onnx_model)
+        with open(base + '.pbtxt', 'w') as text_file:
+            text_file.write(text)
+
+if __name__ == '__main__':
+    command_table = { 'metadata': metadata }
+    command = sys.argv[1];
+    command_table[command]()