|
|
@@ -27,32 +27,34 @@ cntk.ModelFactory = class {
|
|
|
return cntk.Metadata.open(context).then((metadata) => {
|
|
|
switch (match) {
|
|
|
case 'cntk.v1': {
|
|
|
+ let obj = null;
|
|
|
try {
|
|
|
const stream = context.stream;
|
|
|
const buffer = stream.peek();
|
|
|
- const obj = new cntk_v1.ComputationNetwork(buffer);
|
|
|
- return new cntk.Model(metadata, 1, obj);
|
|
|
+ obj = new cntk_v1.ComputationNetwork(buffer);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new cntk.Error('File format is not CNTK v1 (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
+ return new cntk.Model(metadata, 1, obj);
|
|
|
}
|
|
|
case 'cntk.v2': {
|
|
|
return context.require('./cntk-proto').then(() => {
|
|
|
+ let obj = null;
|
|
|
try {
|
|
|
cntk_v2 = protobuf.get('cntk').CNTK.proto;
|
|
|
cntk_v2.PoolingType = { 0: 'Max', 1: 'Average' };
|
|
|
const stream = context.stream;
|
|
|
const reader = protobuf.BinaryReader.open(stream);
|
|
|
const dictionary = cntk_v2.Dictionary.decode(reader);
|
|
|
- const obj = cntk.ModelFactory._convertDictionary(dictionary);
|
|
|
- return new cntk.Model(metadata, 2, obj);
|
|
|
+ obj = cntk.ModelFactory._convertDictionary(dictionary);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new cntk.Error('File format is not cntk.Dictionary (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
+ return new cntk.Model(metadata, 2, obj);
|
|
|
});
|
|
|
}
|
|
|
default: {
|