|
|
@@ -1,6 +1,8 @@
|
|
|
|
|
|
// Experimental
|
|
|
|
|
|
+import * as base from './base.js';
|
|
|
+
|
|
|
const dl4j = {};
|
|
|
|
|
|
dl4j.ModelFactory = class {
|
|
|
@@ -78,12 +80,14 @@ dl4j.Graph = class {
|
|
|
};
|
|
|
if (configuration.networkInputs) {
|
|
|
for (const input of configuration.networkInputs) {
|
|
|
- this.inputs.push(new dl4j.Argument(input, [ value(input) ]));
|
|
|
+ const argument = new dl4j.Argument(input, [ value(input) ]);
|
|
|
+ this.inputs.push(argument);
|
|
|
}
|
|
|
}
|
|
|
if (configuration.networkOutputs) {
|
|
|
for (const output of configuration.networkOutputs) {
|
|
|
- this.outputs.push(new dl4j.Argument(output, [ value(output) ]));
|
|
|
+ const argument = new dl4j.Argument(output, [ value(output) ]);
|
|
|
+ this.outputs.push(argument);
|
|
|
}
|
|
|
}
|
|
|
let inputs = null;
|
|
|
@@ -130,9 +134,12 @@ dl4j.Graph = class {
|
|
|
|
|
|
dl4j.Argument = class {
|
|
|
|
|
|
- constructor(name, value) {
|
|
|
+ constructor(name, value, visible) {
|
|
|
this.name = name;
|
|
|
this.value = value;
|
|
|
+ if (visible === false) {
|
|
|
+ this.visible = false;
|
|
|
+ }
|
|
|
}
|
|
|
};
|
|
|
|
|
|
@@ -214,7 +221,8 @@ dl4j.Node = class {
|
|
|
}
|
|
|
}
|
|
|
if (this.name) {
|
|
|
- this.outputs.push(new dl4j.Argument('output', [ value(this.name) ]));
|
|
|
+ const argument = new dl4j.Argument('output', [ value(this.name) ]);
|
|
|
+ this.outputs.push(argument);
|
|
|
}
|
|
|
let attributes = layer;
|
|
|
if (layer.activationFn) {
|
|
|
@@ -232,8 +240,8 @@ dl4j.Node = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- for (const key in attributes) {
|
|
|
- switch (key) {
|
|
|
+ for (const [name, value] of Object.entries(attributes)) {
|
|
|
+ switch (name) {
|
|
|
case '__type__':
|
|
|
case 'constraints':
|
|
|
case 'layerName':
|
|
|
@@ -244,7 +252,10 @@ dl4j.Node = class {
|
|
|
default:
|
|
|
break;
|
|
|
}
|
|
|
- this.attributes.push(new dl4j.Attribute(metadata.attribute(type, key), key, attributes[key]));
|
|
|
+ const definition = metadata.attribute(type, name);
|
|
|
+ const visible = definition && definition.visible === false ? false : true;
|
|
|
+ const attribute = new dl4j.Argument(name, value, visible);
|
|
|
+ this.attributes.push(attribute);
|
|
|
}
|
|
|
if (layer.idropout) {
|
|
|
const dropout = dl4j.Node._object(layer.idropout);
|
|
|
@@ -276,17 +287,6 @@ dl4j.Node = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-dl4j.Attribute = class {
|
|
|
-
|
|
|
- constructor(metadata, name, value) {
|
|
|
- this.name = name;
|
|
|
- this.value = value;
|
|
|
- if (metadata && metadata.visible === false) {
|
|
|
- this.visible = false;
|
|
|
- }
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
dl4j.Tensor = class {
|
|
|
|
|
|
constructor(dataType, shape) {
|
|
|
@@ -377,53 +377,18 @@ dl4j.NDArray = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-dl4j.BinaryReader = class {
|
|
|
+dl4j.BinaryReader = class extends base.BinaryReader {
|
|
|
|
|
|
constructor(buffer) {
|
|
|
- this._buffer = buffer;
|
|
|
- this._length = buffer.length;
|
|
|
- this._position = 0;
|
|
|
- this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
- }
|
|
|
-
|
|
|
- get length() {
|
|
|
- return this._length;
|
|
|
- }
|
|
|
-
|
|
|
- get position() {
|
|
|
- return this._position;
|
|
|
- }
|
|
|
-
|
|
|
- read(size) {
|
|
|
- const data = this._buffer.subarray(this._position, this._position + size);
|
|
|
- this._position += size;
|
|
|
- return data;
|
|
|
+ super(buffer, false);
|
|
|
}
|
|
|
|
|
|
string() {
|
|
|
- const size = this._buffer[this._position++] << 8 | this._buffer[this._position++];
|
|
|
+ const size = this.uint16();
|
|
|
const buffer = this.read(size);
|
|
|
this._decoder = this._decoder || new TextDecoder('ascii');
|
|
|
return this._decoder.decode(buffer);
|
|
|
}
|
|
|
-
|
|
|
- int32() {
|
|
|
- const position = this._position;
|
|
|
- this._position += 4;
|
|
|
- return this._view.getInt32(position, false);
|
|
|
- }
|
|
|
-
|
|
|
- int64() {
|
|
|
- const position = this._position;
|
|
|
- this._position += 4;
|
|
|
- return this._view.getInt64(position, false).toNumber();
|
|
|
- }
|
|
|
-
|
|
|
- float32() {
|
|
|
- const position = this._position;
|
|
|
- this._position += 4;
|
|
|
- return this._view.getFloat32(position, false);
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
dl4j.Error = class extends Error {
|