|
|
@@ -43,7 +43,7 @@ chainer.ModelFactory = class {
|
|
|
return host.require('./pickle').then((pickle) => {
|
|
|
try {
|
|
|
let modules = [];
|
|
|
- let map = new Map();
|
|
|
+ let modulesMap = new Map();
|
|
|
|
|
|
let functionTable = new Map();
|
|
|
let constructorTable = new Map();
|
|
|
@@ -169,32 +169,26 @@ chainer.ModelFactory = class {
|
|
|
if (!entry.name.endsWith('.npy')) {
|
|
|
throw new chainer.Error("Invalid file name '" + entry.name + "'.");
|
|
|
}
|
|
|
- const id = entry.name.split('/');
|
|
|
- if (id.length < 2) {
|
|
|
+ const id = entry.name.replace(/\.npy$/, '');
|
|
|
+ const parts = id.split('/');
|
|
|
+ if (parts.length < 2) {
|
|
|
throw new chainer.Error("Invalid parameter name '" + entry.name + "'.");
|
|
|
}
|
|
|
- let parameterName = id.pop();
|
|
|
- parameterName = parameterName.substring(0, parameterName.length - 4);
|
|
|
- const moduleName = id.join('/');
|
|
|
- let module = null;
|
|
|
- if (map.has(id[0])) {
|
|
|
- module = map.get(moduleName);
|
|
|
- }
|
|
|
- else {
|
|
|
- module = {
|
|
|
- name: moduleName,
|
|
|
- parameters: []
|
|
|
- };
|
|
|
- map.set(moduleName, module);
|
|
|
- modules.push(module);
|
|
|
+ const parameterName = parts.pop();
|
|
|
+ const moduleName = parts.join('/');
|
|
|
+ if (!modulesMap.has(moduleName)) {
|
|
|
+ const newModule = { name: moduleName, parameters: [] };
|
|
|
+ modules.push(newModule);
|
|
|
+ modulesMap.set(moduleName, newModule);
|
|
|
}
|
|
|
+ const module = modulesMap.get(moduleName);
|
|
|
let array = new numpy.Array(entry.data);
|
|
|
if (array.byteOrder === '|') {
|
|
|
if (array.dataType !== 'O') {
|
|
|
throw new chainer.Error("Invalid data type '" + array.dataType + "'.");
|
|
|
}
|
|
|
const unpickler = new pickle.Unpickler(array.data);
|
|
|
- let root = unpickler.load(function_call);
|
|
|
+ const root = unpickler.load(function_call);
|
|
|
array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
|
|
|
}
|
|
|
|
|
|
@@ -226,24 +220,18 @@ chainer.ModelFactory = class {
|
|
|
throw new chainer.Error('File format is not Chainer HDF5');
|
|
|
}
|
|
|
let modules = [];
|
|
|
- let map = new Map();
|
|
|
+ let modulesMap = new Map();
|
|
|
for (const moduleGroup of rootGroup.groups) {
|
|
|
if (Object.keys(moduleGroup.attributes).length !== 0 || moduleGroup.value !== null) {
|
|
|
throw new chainer.Error('Module group format is not Chainer HDF5');
|
|
|
}
|
|
|
- let moduleName = moduleGroup.name;
|
|
|
- let module = null;
|
|
|
- if (map.has(moduleName)) {
|
|
|
- module = map.get(moduleName);
|
|
|
- }
|
|
|
- else {
|
|
|
- module = {
|
|
|
- name: moduleName,
|
|
|
- parameters: []
|
|
|
- };
|
|
|
- map.set(moduleName, module);
|
|
|
- modules.push(module);
|
|
|
+ const moduleName = moduleGroup.name;
|
|
|
+ if (!modulesMap.has(moduleName)) {
|
|
|
+ const newModule = { name: moduleName, parameters: [] };
|
|
|
+ modulesMap.set(moduleName, newModule);
|
|
|
+ modules.push(newModule);
|
|
|
}
|
|
|
+ const module = modulesMap.get(moduleName);
|
|
|
for (const variableGroup of moduleGroup.groups) {
|
|
|
if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
|
|
|
throw new chainer.Error('Variable format is not Chainer HDF5');
|
|
|
@@ -457,16 +445,24 @@ chainer.Tensor = class {
|
|
|
}
|
|
|
switch (this._type.dataType) {
|
|
|
case 'float16':
|
|
|
+ context.itemSize = 2;
|
|
|
+ break;
|
|
|
case 'float32':
|
|
|
+ context.itemSize = 4;
|
|
|
+ break;
|
|
|
case 'float64':
|
|
|
- case 'int64':
|
|
|
- context.dataType = this._type.dataType;
|
|
|
+ context.itemSize = 8;
|
|
|
break;
|
|
|
- default:
|
|
|
- context.state = 'Tensor data type is not supported.';
|
|
|
+ case 'int64':
|
|
|
+ context.itemSize = 8;
|
|
|
break;
|
|
|
}
|
|
|
+ if (!context.itemSize) {
|
|
|
+ context.state = 'Tensor data type is not supported.';
|
|
|
+ return context;
|
|
|
+ }
|
|
|
context.dimensions = this._type.shape.dimensions;
|
|
|
+ context.dataType = this._type.dataType;
|
|
|
context.littleEndian = this._byteOrder == '<';
|
|
|
context.data = this._data;
|
|
|
context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
|
|
|
@@ -491,21 +487,18 @@ chainer.Tensor = class {
|
|
|
switch (context.dataType) {
|
|
|
case 'float16':
|
|
|
results.push(context.rawData.getFloat16(context.index, littleEndian));
|
|
|
- context.index += 2;
|
|
|
break;
|
|
|
case 'float32':
|
|
|
results.push(context.rawData.getFloat32(context.index, littleEndian));
|
|
|
- context.index += 4;
|
|
|
break;
|
|
|
case 'float64':
|
|
|
results.push(context.rawData.getFloat64(context.index, littleEndian));
|
|
|
- context.index += 8;
|
|
|
break;
|
|
|
case 'int64':
|
|
|
results.push(long.Long.fromBytes(context.data.subarray(context.index, context.index + 8), true, littleEndian));
|
|
|
- context.index += 8;
|
|
|
break;
|
|
|
}
|
|
|
+ context.index += context.itemSize;
|
|
|
context.count++;
|
|
|
}
|
|
|
}
|