Browse Source

Update Deeplearning4J detection (#303)

Lutz Roeder 4 years ago
parent
commit
8c07869eb6
2 changed files with 38 additions and 42 deletions
  1. 36 40
      source/dl4j.js
  2. 2 2
      source/view.js

+ 36 - 40
source/dl4j.js

@@ -7,14 +7,45 @@ var json = json || require('./json');
 dl4j.ModelFactory = class {
 
     match(context) {
-        const entries = context.entries('zip');
-        return dl4j.Container.open(entries);
+        switch (context.identifier) {
+            case 'configuration.json': {
+                const obj = context.open('json');
+                if (obj && (obj.confs || obj.vertices)) {
+                    return 'dl4j.configuration';
+                }
+                break;
+            }
+            case 'coefficients.bin': {
+                const signature = [ 0x00, 0x07, 0x4A, 0x41, 0x56, 0x41, 0x43, 0x50, 0x50 ];
+                const stream = context.stream;
+                if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
+                    return 'dl4j.coefficients';
+                }
+                break;
+            }
+        }
+        return undefined;
     }
 
     open(context, match) {
         return dl4j.Metadata.open(context).then((metadata) => {
-            const container = match;
-            return new dl4j.Model(metadata, container.configuration, container.coefficients);
+            switch (match) {
+                case 'dl4j.configuration': {
+                    const obj = context.open('json');
+                    return context.request('coefficients.bin', null).then((stream) => {
+                        return new dl4j.Model(metadata, obj, stream.peek());
+                    }).catch(() => {
+                        return new dl4j.Model(metadata, obj, null);
+                    });
+                }
+                case 'dl4j.coefficients': {
+                    return context.request('configuration.json', null).then((stream) => {
+                        const reader = json.TextReader.open(stream);
+                        const obj = reader.read();
+                        return new dl4j.Model(metadata, obj, context.stream.peek());
+                    });
+                }
+            }
         });
     }
 };
@@ -43,8 +74,7 @@ dl4j.Graph = class {
         this._outputs =[];
         this._nodes = [];
 
-        const reader = new dl4j.NDArrayReader(coefficients);
-        const dataType = reader.dataType;
+        const dataType = coefficients ? new dl4j.NDArrayReader(coefficients).dataType : '?';
 
         if (configuration.networkInputs) {
             for (const input of configuration.networkInputs) {
@@ -464,40 +494,6 @@ dl4j.Metadata = class {
     }
 };
 
-dl4j.Container = class {
-
-    static open(entries) {
-        const stream = entries.get('configuration.json');
-        const coefficients = entries.get('coefficients.bin');
-        if (stream) {
-            try {
-                const reader = json.TextReader.open(stream);
-                const configuration = reader.read();
-                if (configuration && (configuration.confs || configuration.vertices)) {
-                    return new dl4j.Container(configuration, coefficients ? coefficients.peek() : []);
-                }
-            }
-            catch (error) {
-                // continue regardless of error
-            }
-        }
-        return undefined;
-    }
-
-    constructor(configuration, coefficients) {
-        this._configuration = configuration;
-        this._coefficients = coefficients;
-    }
-
-    get configuration() {
-        return this._configuration;
-    }
-
-    get coefficients() {
-        return this._coefficients;
-    }
-};
-
 dl4j.NDArrayReader = class {
 
     constructor(buffer) {

+ 2 - 2
source/view.js

@@ -1542,9 +1542,9 @@ view.ModelFactoryService = class {
         this.register('./barracuda', [ '.nn' ]);
         this.register('./dnn', [ '.dnn' ]);
         this.register('./xmodel', [ '.xmodel' ]);
-        this.register('./openvino', [ '.xml', '.bin' ]);
         this.register('./flux', [ '.bson' ]);
-        this.register('./dl4j', [ '.zip' ]);
+        this.register('./dl4j', [ '.json', '.bin' ]);
+        this.register('./openvino', [ '.xml', '.bin' ]);
         this.register('./mlnet', [ '.zip' ]);
         this.register('./acuity', [ '.json' ]);
         this.register('./imgdnn', [ '.dnn', 'params', '.json' ]);