Lutz Roeder 7 лет назад
Родитель
Сommit
7c4a830556
4 измененных файлов с 117 добавлено и 13 удалено
  1. 26 3
      test/app.js
  2. 45 1
      test/models.json
  3. 20 1
      tools/keras
  4. 26 8
      tools/keras-script.py

+ 26 - 3
test/app.js

@@ -3,6 +3,7 @@
 const fs = require('fs');
 const path = require('path');
 const process = require('process');
+const child_process = require('child_process');
 const vm = require('vm');
 const http = require('http');
 const https = require('https');
@@ -47,7 +48,7 @@ class TestHost {
     }
 
     exception(err, fatal) {
-        console.log("ERROR: " + err.toString());
+        console.log(err.toString());
     }
 }
 
@@ -260,6 +261,10 @@ function download(folder, targets, sources, completed, callback) {
         callback(null, completed);
         return;
     }
+    if (!sources) {
+        callback(new Error('Download source not specified.'), null);
+        return;
+    }
     var source = '';
     var sourceFiles = [];
     var startIndex = sources.indexOf('[');
@@ -288,7 +293,7 @@ function download(folder, targets, sources, completed, callback) {
     });
     request(source, [], (err, data) => {
         if (err) {
-            console.log("ERROR: " + err.toString());
+            callback(err, null);
             return;
         }
         if (sourceFiles.length > 0) {
@@ -305,7 +310,7 @@ function download(folder, targets, sources, completed, callback) {
                 process.stdout.write('  write ' + file + '\n');
                 var entry = archive.entries.filter((entry) => entry.name == file)[0];
                 if (!entry) {
-                    console.log("ERROR: Entry not found '" + file + '. Archive contains entries: ' + JSON.stringify(archive.entries.map((entry) => entry.name)) + " .");
+                    callback(new Error("Entry not found '" + file + '. Archive contains entries: ' + JSON.stringify(archive.entries.map((entry) => entry.name)) + " ."), null);
                 }
                 var target = targets.shift();
                 fs.writeFileSync(folder + '/' + target, entry.data, null);
@@ -350,6 +355,24 @@ function next() {
     process.stdout.write(targets[0] + '\n');
     var sources = item.source;
     download(folder, targets, sources, [], (err, completed) => {
+        if (err) {
+            if (item.status == 'script' && item.script) {
+                try { 
+                    var command = path.join(__dirname, item.script[0]) + ' ' + item.script[1];
+                    console.log('  ' + command);
+                    child_process.execSync(command, { stdio: [ 0, 1 , 2] });
+                    completed = targets;
+                }
+                catch (err) {
+                    console.log(err);
+                    return;
+                }
+            }
+            else {
+                console.log(err);
+                return;
+            }
+        }
         loadModel(folder + '/' + completed[0], item, (err, model) => {
             if (err) {
                 console.log(err);

+ 45 - 1
test/models.json

@@ -591,15 +591,47 @@
     "format": "Caffe2",
     "link":   "https://github.com/caffe2/models"
   },
+  {
+    "target": "keras/DenseNet121.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "keras/InceptionResNetV2.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "keras/InceptionV3.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "target": "keras/mimo.h5",
     "source": "https://github.com/lutzroeder/netron/files/2565761/mimo.h5.zip[mimo.h5]",
     "format": "Keras v2.2.0",
     "link":   "https://github.com/lutzroeder/netron/issues/138"
   },
+  {
+    "target": "keras/MobileNetV2.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "keras/NASNetMobile.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "target": "keras/tiny-yolo-voc.h5",
-    "source": "https://raw.githubusercontent.com/hollance/YOLO-CoreML-MPSNNGraph/master/Convert/yad2k/model_data/tiny-yolo-voc.h5"
+    "source": "https://raw.githubusercontent.com/hollance/YOLO-CoreML-MPSNNGraph/master/Convert/yad2k/model_data/tiny-yolo-voc.h5",
+    "format": "Keras v1.2.2",
+    "link":   "https://github.com/hollance/YOLO-CoreML-MPSNNGraph/tree/master/Convert/yad2k/model_data"
   },
   {
     "target": "keras/tiramisu_fc_dense103_model.json",
@@ -607,6 +639,18 @@
     "format": "Keras",
     "link":   "https://github.com/0bserver07/One-Hundred-Layers-Tiramisu"
   },
+  {
+    "target": "keras/VGG16.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "keras/VGG19.h5",
+    "link":   "https://keras.io/applications",
+    "script": [ "../tools/keras", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "target": "cntk/v1/AlexNet.model",
     "source": "https://www.cntk.ai/Models/AlexNet/AlexNet.model",

+ 20 - 1
tools/keras

@@ -4,6 +4,7 @@ set -e
 
 root=$(cd $(dirname ${0})/..; pwd)
 build=${root}/build
+test=${root}/test
 tools=${root}/tools
 third_party=${root}/third_party
 
@@ -30,7 +31,6 @@ sync() {
 }
 
 install() {
-    echo "Install Keras"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
     ${pip} install --quiet tensorflow
@@ -47,11 +47,30 @@ metadata() {
     deactivate
 }
 
+zoo() {
+    source ${virtualenv}/bin/activate
+    pushd ${tools} > /dev/null
+    ${python} keras-script.py zoo keras.applications.densenet.DenseNet121 ${test}/data/keras/DenseNet121.h5
+    ${python} keras-script.py zoo keras.applications.inception_resnet_v2.InceptionResNetV2 ${test}/data/keras/InceptionResNetV2.h5
+    ${python} keras-script.py zoo keras.applications.inception_v3.InceptionV3 ${test}/data/keras/InceptionV3.h5
+    ${python} keras-script.py zoo keras.applications.mobilenet_v2.MobileNetV2 ${test}/data/keras/MobileNetV2.h5
+    ${python} keras-script.py zoo keras.applications.nasnet.NASNetMobile ${test}/data/keras/NASNetMobile.h5
+    ${python} keras-script.py zoo keras.applications.resnet50.ResNet50 ${test}/data/keras/ResNet50.h5
+    ${python} keras-script.py zoo keras.applications.vgg16.VGG16 ${test}/data/keras/VGG16.h5
+    ${python} keras-script.py zoo keras.applications.vgg19.VGG19 ${test}/data/keras/VGG19.h5
+    ${python} keras-script.py zoo keras.applications.xception.Xception ${test}/data/keras/Xception.h5
+    rm -rf ~/.keras/models
+    popd > /dev/null
+    deactivate
+}
+
+
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
         "sync") sync;;
         "install") install;;
         "metadata") metadata;;
+        "zoo") zoo;;
     esac
 done

+ 26 - 8
tools/keras-script.py

@@ -4,10 +4,13 @@ from __future__ import print_function
 
 import io
 import json
+import os
 import pydoc
 import re
 import sys
 
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+
 def count_leading_spaces(s):
     ws = re.search(r'\S', s)
     if ws:
@@ -319,12 +322,27 @@ for entry in json_root:
         if len(headers) > 0:
             raise Exception('\'' + class_name + '.__doc__\' contains unprocessed headers.')
  
-with io.open(json_file, 'w', newline='') as fout:
-    json_data = json.dumps(json_root, sort_keys=True, indent=2)
-    for line in json_data.splitlines():
-        line = line.rstrip()
-        if sys.version_info[0] < 3:
-            line = unicode(line)
-        fout.write(line)
-        fout.write('\n')
+def metadata():
+    with io.open(json_file, 'w', newline='') as fout:
+        json_data = json.dumps(json_root, sort_keys=True, indent=2)
+        for line in json_data.splitlines():
+            line = line.rstrip()
+            if sys.version_info[0] < 3:
+                line = unicode(line)
+            fout.write(line)
+            fout.write('\n')
+
+def zoo():
+    from pydoc import locate
+    type = sys.argv[2];
+    file = sys.argv[3];
+    directory = os.path.dirname(file);
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    model = locate(type)()
+    model.save(file);
 
+if __name__ == '__main__':
+    command_table = { 'metadata': metadata, 'zoo': zoo }
+    command = sys.argv[1];
+    command_table[command]()