Sfoglia il codice sorgente

Fix Chainer module name check

Lutz Roeder 5 anni fa
parent
commit
1819fd80c7
1 ha cambiato i file con 32 aggiunte e 39 eliminazioni
  1. 32 39
      src/chainer.js

+ 32 - 39
src/chainer.js

@@ -43,7 +43,7 @@ chainer.ModelFactory = class {
             return host.require('./pickle').then((pickle) => {
                 try {
                     let modules = [];
-                    let map = new Map();
+                    let modulesMap = new Map();
 
                     let functionTable = new Map();
                     let constructorTable = new Map();
@@ -169,32 +169,26 @@ chainer.ModelFactory = class {
                         if (!entry.name.endsWith('.npy')) {
                             throw new chainer.Error("Invalid file name '" + entry.name + "'.");
                         }
-                        const id = entry.name.split('/');
-                        if (id.length < 2) {
+                        const id = entry.name.replace(/\.npy$/, '');
+                        const parts = id.split('/');
+                        if (parts.length < 2) {
                             throw new chainer.Error("Invalid parameter name '" + entry.name + "'.");
                         }
-                        let parameterName = id.pop();
-                        parameterName = parameterName.substring(0, parameterName.length - 4);
-                        const moduleName = id.join('/');
-                        let module = null;
-                        if (map.has(id[0])) {
-                            module = map.get(moduleName);
-                        }
-                        else {
-                            module = { 
-                                name: moduleName,
-                                parameters: []
-                            };
-                            map.set(moduleName, module);
-                            modules.push(module);
+                        const parameterName = parts.pop();
+                        const moduleName = parts.join('/');
+                        if (!modulesMap.has(moduleName)) {
+                            const newModule = { name: moduleName, parameters: [] };
+                            modules.push(newModule);
+                            modulesMap.set(moduleName, newModule);
                         }
+                        const module = modulesMap.get(moduleName);
                         let array = new numpy.Array(entry.data);
                         if (array.byteOrder === '|') {
                             if (array.dataType !== 'O') {
                                 throw new chainer.Error("Invalid data type '" + array.dataType + "'.");
                             }
                             const unpickler = new pickle.Unpickler(array.data);
-                            let root = unpickler.load(function_call);
+                            const root = unpickler.load(function_call);
                             array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
                         }
 
@@ -226,24 +220,18 @@ chainer.ModelFactory = class {
                     throw new chainer.Error('File format is not Chainer HDF5');
                 }
                 let modules = [];
-                let map = new Map();
+                let modulesMap = new Map();
                 for (const moduleGroup of rootGroup.groups) {
                     if (Object.keys(moduleGroup.attributes).length !== 0 || moduleGroup.value !== null) {
                         throw new chainer.Error('Module group format is not Chainer HDF5');
                     }
-                    let moduleName = moduleGroup.name;
-                    let module = null;
-                    if (map.has(moduleName)) {
-                        module = map.get(moduleName);
-                    }
-                    else {
-                        module = { 
-                            name: moduleName,
-                            parameters: []
-                        };
-                        map.set(moduleName, module);
-                        modules.push(module);
+                    const moduleName = moduleGroup.name;
+                    if (!modulesMap.has(moduleName)) {
+                        const newModule = { name: moduleName, parameters: [] };
+                        modulesMap.set(moduleName, newModule);
+                        modules.push(newModule);
                     }
+                    const module = modulesMap.get(moduleName);
                     for (const variableGroup of moduleGroup.groups) {
                         if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
                             throw new chainer.Error('Variable format is not Chainer HDF5');
@@ -457,16 +445,24 @@ chainer.Tensor = class  {
         }
         switch (this._type.dataType) {
             case 'float16':
+                context.itemSize = 2;
+                break;
             case 'float32':
+                context.itemSize = 4;
+                break;
             case 'float64':
-            case 'int64':
-                context.dataType = this._type.dataType;
+                context.itemSize = 8;
                 break;
-            default:
-                context.state = 'Tensor data type is not supported.';
+            case 'int64':
+                context.itemSize = 8;
                 break;
         }
+        if (!context.itemSize) {
+            context.state = 'Tensor data type is not supported.';
+            return context;
+        }
         context.dimensions = this._type.shape.dimensions;
+        context.dataType = this._type.dataType;
         context.littleEndian = this._byteOrder == '<';
         context.data = this._data;
         context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
@@ -491,21 +487,18 @@ chainer.Tensor = class  {
                     switch (context.dataType) {
                         case 'float16':
                             results.push(context.rawData.getFloat16(context.index, littleEndian));
-                            context.index += 2;
                             break;
                         case 'float32':
                             results.push(context.rawData.getFloat32(context.index, littleEndian));
-                            context.index += 4;
                             break;
                         case 'float64':
                             results.push(context.rawData.getFloat64(context.index, littleEndian));
-                            context.index += 8;
                             break;
                         case 'int64':
                             results.push(long.Long.fromBytes(context.data.subarray(context.index, context.index + 8), true, littleEndian));
-                            context.index += 8;
                             break;
                     }
+                    context.index += context.itemSize;
                     context.count++;
                 }
             }