|
|
@@ -6,44 +6,62 @@ var base = base || require('./base');
|
|
|
darknet.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
- try {
|
|
|
- const reader = base.TextReader.create(context.buffer);
|
|
|
- for (;;) {
|
|
|
- const line = reader.read();
|
|
|
- if (line === undefined) {
|
|
|
- break;
|
|
|
+ const identifier = context.identifier;
|
|
|
+ const extension = identifier.split('.').pop().toLowerCase();
|
|
|
+ switch (extension) {
|
|
|
+ case 'weights':
|
|
|
+ if (darknet.Weights.open(context.buffer)) {
|
|
|
+ return true;
|
|
|
}
|
|
|
- const text = line.trim();
|
|
|
- if (text.length === 0 || text.startsWith('#')) {
|
|
|
- continue;
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ try {
|
|
|
+ const reader = base.TextReader.create(context.buffer);
|
|
|
+ for (;;) {
|
|
|
+ const line = reader.read();
|
|
|
+ if (line === undefined) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const text = line.trim();
|
|
|
+ if (text.length === 0 || text.startsWith('#')) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (text.startsWith('[') && text.endsWith(']')) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
- if (text.startsWith('[') && text.endsWith(']')) {
|
|
|
- return true;
|
|
|
+ catch (err) {
|
|
|
+ // continue regardless of error
|
|
|
}
|
|
|
- }
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- // continue regardless of error
|
|
|
+ break;
|
|
|
}
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
open(context, host) {
|
|
|
return darknet.Metadata.open(host).then((metadata) => {
|
|
|
+ const open = (metadata, cfg, weights) => {
|
|
|
+ return new darknet.Model(metadata, cfg, darknet.Weights.open(weights));
|
|
|
+ };
|
|
|
const identifier = context.identifier;
|
|
|
const parts = identifier.split('.');
|
|
|
- parts.pop();
|
|
|
+ const extension = parts.pop().toLowerCase();
|
|
|
const basename = parts.join('.');
|
|
|
- return context.request(basename + '.weights', null).then((weights) => {
|
|
|
- return this._openModel(metadata, identifier, context.buffer, weights);
|
|
|
- }).catch(() => {
|
|
|
- return this._openModel(metadata, identifier, context.buffer, null);
|
|
|
- });
|
|
|
+ switch (extension) {
|
|
|
+ case 'weights':
|
|
|
+ return context.request(basename + '.cfg', null).then((cfg) => {
|
|
|
+ return open(metadata, cfg, context.buffer);
|
|
|
+ });
|
|
|
+ default:
|
|
|
+ return context.request(basename + '.weights', null).then((weights) => {
|
|
|
+ return open(metadata, context.buffer, weights);
|
|
|
+ }).catch(() => {
|
|
|
+ return open(metadata, context.buffer, null);
|
|
|
+ });
|
|
|
+ }
|
|
|
});
|
|
|
}
|
|
|
- _openModel( metadata, identifier, cfg, weights) {
|
|
|
- return new darknet.Model(metadata, cfg, weights ? new darknet.Weights(weights) : null);
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
darknet.Model = class {
|
|
|
@@ -141,7 +159,7 @@ darknet.Graph = class {
|
|
|
};
|
|
|
|
|
|
const load_weights = (name, shape, visible) => {
|
|
|
- const data = weights ? weights.bytes(4 * shape.reduce((a, b) => a * b)) : null;
|
|
|
+ const data = weights ? weights.read(4 * shape.reduce((a, b) => a * b)) : null;
|
|
|
const type = new darknet.TensorType('float32', make_shape(shape, 'load_weights'));
|
|
|
const initializer = new darknet.Tensor(type, data);
|
|
|
const argument = new darknet.Argument('', null, initializer);
|
|
|
@@ -1087,18 +1105,49 @@ darknet.TensorShape = class {
|
|
|
|
|
|
darknet.Weights = class {
|
|
|
|
|
|
+ static open(buffer) {
|
|
|
+ if (buffer) {
|
|
|
+ const reader = new darknet.Weights.BinaryReader(buffer);
|
|
|
+ const major = reader.int32();
|
|
|
+ const minor = reader.int32();
|
|
|
+ const revision = reader.int32();
|
|
|
+ const seen = ((major * 10 + minor) >= 2) ? reader.int64() : reader.int32();
|
|
|
+ const transpose = (major > 1000) || (minor > 1000);
|
|
|
+ // if (transpose) {
|
|
|
+ // throw new darknet.Error("Unsupported transpose weights file version '" + [ major, minor, revision ].join('.') + "'.");
|
|
|
+ // }
|
|
|
+ if (!transpose) {
|
|
|
+ return new darknet.Weights(reader);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ constructor(reader) {
|
|
|
+ this._reader = reader;
|
|
|
+ }
|
|
|
+
|
|
|
+ read(size) {
|
|
|
+ return this._reader.bytes(size);
|
|
|
+ }
|
|
|
+
|
|
|
+ validate() {
|
|
|
+ if (!this._reader.end()) {
|
|
|
+ throw new darknet.Error('Invalid weights size.');
|
|
|
+ }
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+darknet.Weights.BinaryReader = class {
|
|
|
+
|
|
|
constructor(buffer) {
|
|
|
this._buffer = buffer;
|
|
|
this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
this._position = 0;
|
|
|
- const major = this.int32();
|
|
|
- const minor = this.int32();
|
|
|
- const revision = this.int32();
|
|
|
- this._seen = ((major * 10 + minor) >= 2) ? this.int64() : this.int32();
|
|
|
- const transpose = (major > 1000) || (minor > 1000);
|
|
|
- if (transpose) {
|
|
|
- throw new darknet.Error("Unsupported transpose weights file version '" + [ major, minor, revision ].join('.') + "'.");
|
|
|
- }
|
|
|
+ }
|
|
|
+
|
|
|
+ end() {
|
|
|
+ return this._position === this._buffer.length;
|
|
|
}
|
|
|
|
|
|
int32() {
|
|
|
@@ -1125,12 +1174,6 @@ darknet.Weights = class {
|
|
|
throw new darknet.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- validate() {
|
|
|
- if (this._position !== this._buffer.length) {
|
|
|
- throw new darknet.Error('Invalid weights size.');
|
|
|
- }
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
darknet.Metadata = class {
|