Browse Source

Update dl4j.js

Lutz Roeder 2 years ago
parent
commit
e5d1ce5715
1 changed files with 21 additions and 56 deletions
  1. 21 56
      source/dl4j.js

+ 21 - 56
source/dl4j.js

@@ -1,6 +1,8 @@
 
 // Experimental
 
+import * as base from './base.js';
+
 const dl4j = {};
 
 dl4j.ModelFactory = class {
@@ -78,12 +80,14 @@ dl4j.Graph = class {
         };
         if (configuration.networkInputs) {
             for (const input of configuration.networkInputs) {
-                this.inputs.push(new dl4j.Argument(input, [ value(input) ]));
+                const argument = new dl4j.Argument(input, [ value(input) ]);
+                this.inputs.push(argument);
             }
         }
         if (configuration.networkOutputs) {
             for (const output of configuration.networkOutputs) {
-                this.outputs.push(new dl4j.Argument(output, [ value(output) ]));
+                const argument = new dl4j.Argument(output, [ value(output) ]);
+                this.outputs.push(argument);
             }
         }
         let inputs = null;
@@ -130,9 +134,12 @@ dl4j.Graph = class {
 
 dl4j.Argument = class {
 
-    constructor(name, value) {
+    constructor(name, value, visible) {
         this.name = name;
         this.value = value;
+        if (visible === false) {
+            this.visible = false;
+        }
     }
 };
 
@@ -214,7 +221,8 @@ dl4j.Node = class {
             }
         }
         if (this.name) {
-            this.outputs.push(new dl4j.Argument('output', [ value(this.name) ]));
+            const argument = new dl4j.Argument('output', [ value(this.name) ]);
+            this.outputs.push(argument);
         }
         let attributes = layer;
         if (layer.activationFn) {
@@ -232,8 +240,8 @@ dl4j.Node = class {
                 }
             }
         }
-        for (const key in attributes) {
-            switch (key) {
+        for (const [name, value] of Object.entries(attributes)) {
+            switch (name) {
                 case '__type__':
                 case 'constraints':
                 case 'layerName':
@@ -244,7 +252,10 @@ dl4j.Node = class {
                 default:
                     break;
             }
-            this.attributes.push(new dl4j.Attribute(metadata.attribute(type, key), key, attributes[key]));
+            const definition = metadata.attribute(type, name);
+            const visible = definition && definition.visible === false ? false : true;
+            const attribute = new dl4j.Argument(name, value, visible);
+            this.attributes.push(attribute);
         }
         if (layer.idropout) {
             const dropout = dl4j.Node._object(layer.idropout);
@@ -276,17 +287,6 @@ dl4j.Node = class {
     }
 };
 
-dl4j.Attribute = class {
-
-    constructor(metadata, name, value) {
-        this.name = name;
-        this.value = value;
-        if (metadata && metadata.visible === false) {
-            this.visible = false;
-        }
-    }
-};
-
 dl4j.Tensor = class {
 
     constructor(dataType, shape) {
@@ -377,53 +377,18 @@ dl4j.NDArray = class {
     }
 };
 
-dl4j.BinaryReader = class {
+dl4j.BinaryReader = class extends base.BinaryReader {
 
     constructor(buffer) {
-        this._buffer = buffer;
-        this._length = buffer.length;
-        this._position = 0;
-        this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-    }
-
-    get length() {
-        return this._length;
-    }
-
-    get position() {
-        return this._position;
-    }
-
-    read(size) {
-        const data = this._buffer.subarray(this._position, this._position + size);
-        this._position += size;
-        return data;
+        super(buffer, false);
     }
 
     string() {
-        const size = this._buffer[this._position++] << 8 | this._buffer[this._position++];
+        const size = this.uint16();
         const buffer = this.read(size);
         this._decoder = this._decoder || new TextDecoder('ascii');
         return this._decoder.decode(buffer);
     }
-
-    int32() {
-        const position = this._position;
-        this._position += 4;
-        return this._view.getInt32(position, false);
-    }
-
-    int64() {
-        const position = this._position;
-        this._position += 4;
-        return this._view.getInt64(position, false).toNumber();
-    }
-
-    float32() {
-        const position = this._position;
-        this._position += 4;
-        return this._view.getFloat32(position, false);
-    }
 };
 
 dl4j.Error = class extends Error {