소스 검색

Update pytorch.js

Lutz Roeder 3 년 전
부모
커밋
c4e84d0bc8
1개의 변경된 파일: 311개의 추가 그리고 194개의 삭제
  1. 311 194
      source/pytorch.js

+ 311 - 194
source/pytorch.js

@@ -2133,6 +2133,7 @@ pytorch.Container.Tar = class {
 
     constructor(entries) {
         this._entries = entries;
+        this._graphs = [ this ];
     }
 
     set metadata(value) {
@@ -2147,6 +2148,10 @@ pytorch.Container.Tar = class {
         return 'PyTorch v0.1.1';
     }
 
+    get graphs() {
+        return this._graphs;
+    }
+
     get type() {
         this._unpickle();
         return this._type;
@@ -2267,6 +2272,7 @@ pytorch.Container.Pickle = class {
 
     constructor(stream) {
         this._stream = stream;
+        this._graphs = [ this ];
     }
 
     set metadata(value) {
@@ -2281,6 +2287,10 @@ pytorch.Container.Pickle = class {
         return 'PyTorch v0.1.10';
     }
 
+    get graphs() {
+        return this._graphs;
+    }
+
     get type() {
         this._unpickle();
         return this._type;
@@ -2394,35 +2404,49 @@ pytorch.Container.Pickle = class {
 pytorch.Container.Zip = class {
 
     static open(entries) {
-        const name = Array.from(entries.keys()).find((name) => name == 'model.json' || name == 'data.pkl' || name.endsWith('/model.json') || name.endsWith('/data.pkl'));
-        if (!name) {
-            return null;
-        }
-        let model = null;
-        if (name.endsWith('.json')) {
-            try {
-                const stream = entries.get(name);
-                const buffer = stream.peek();
-                const decoder = new TextDecoder('utf-8');
-                const content = decoder.decode(buffer);
-                model = JSON.parse(content);
-                if (!model.mainModule) {
-                    return null;
+        if (entries.size > 0) {
+            let prefix = [];
+            const paths = Array.from(entries.keys()).map((path) => path.split('/').reverse());
+            for (;;) {
+                const set = new Set(paths.map((path) => path.length > 0 ? path.pop() : null));
+                if (set.size !== 1 || set.keys().next().value === null) {
+                    break;
+                }
+                prefix.push(set.keys().next().value);
+            }
+            prefix = prefix.join('/');
+            prefix = prefix.length > 0 ? prefix + '/' : prefix;
+            entries = new Map(Array.from(entries).map((entry) => [ entry[0].substring(prefix.length), entry[1] ]));
+            if (entries.has('model.json')) {
+                try {
+                    const stream = entries.get('model.json');
+                    const buffer = stream.peek();
+                    const decoder = new TextDecoder('utf-8');
+                    const content = decoder.decode(buffer);
+                    const model = JSON.parse(content);
+                    if (model.mainModule) {
+                        return new pytorch.Container.Zip.Json(entries, model);
+                    }
+                }
+                catch (error) {
+                    // continue regardless of error
                 }
             }
-            catch (error) {
-                return null;
+            if (entries.has('data.pkl')) {
+                return new pytorch.Container.Zip.Pickle(entries);
+            }
+            if (Array.from(entries.keys()).find((name) => name.startsWith('.data/'))) {
+                return new pytorch.Container.Zip.Package(entries);
             }
         }
-        return new pytorch.Container.Zip(entries, name, model);
+        return null;
     }
 
-    constructor(entries, name, model) {
-        this._entries = entries;
+    constructor(entries) {
         // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
-        this._model = model;
-        const lastIndex = name.lastIndexOf('/');
-        this._prefix = lastIndex === -1 ? '' : name.substring(0, lastIndex + 1);
+        this._entries = entries;
+        this._producer = '';
+        this._graphs = [ this ];
     }
 
     set metadata(value) {
@@ -2433,48 +2457,12 @@ pytorch.Container.Zip = class {
         this._exceptionCallback = value;
     }
 
-    get format() {
-        if (this._format === undefined) {
-            if (this._entry('model.json')) {
-                this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
-            }
-            else if (this._entry('data.pkl')) {
-                // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
-                // kProducedFileFormatVersion
-                const versions = new Map([
-                    [ '1', 'v1.3'  ],
-                    [ '2', 'v1.5'  ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
-                    [ '3', 'v1.6'  ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
-                    [ '4', 'v1.6'  ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
-                    [ '5', 'v1.7'  ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
-                    [ '6', 'v1.9'  ], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
-                    [ '7', 'v1.10' ]  // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
-                ]);
-                const value = this.version;
-                if (!versions.has(value)) {
-                    this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
-                }
-                const version = versions.get(value);
-                const constants = this._entry('constants.pkl');
-                this._format = (constants ? 'TorchScript' : 'PyTorch') + ' ' + (version || 'v-' + value.toString() );
-            }
-        }
-        return this._format;
-    }
-
-    get version() {
-        const stream = this._entry('version');
-        if (stream) {
-            const decoder = new TextDecoder('utf-8');
-            const buffer = stream.peek();
-            const value = decoder.decode(buffer);
-            return value.split('\n').shift();
-        }
-        return '';
+    get producer() {
+        return this._producer;
     }
 
-    get producer() {
-        return this.data ? this._producer : '';
+    get graphs() {
+        return this.graphs;
     }
 
     get name() {
@@ -2486,19 +2474,19 @@ pytorch.Container.Zip = class {
     }
 
     get type() {
-        this._load();
+        this.read();
         return this._type;
     }
 
     get data() {
-        this._load();
+        this.read();
         return this._data;
     }
 
     get constants() {
         if (this._constants === undefined) {
             this._constants = [];
-            const stream = this._entry('constants.pkl');
+            const stream = this._entries.get('constants.pkl');
             if (stream) {
                 const buffer = stream.peek();
                 this._constants = this._unpickle(buffer, this._storage('constants'));
@@ -2540,8 +2528,8 @@ pytorch.Container.Zip = class {
             const sources = new Map();
             for (const entry of this._entries) {
                 const name = entry[0];
-                if (name.startsWith(this._prefix + 'code')) {
-                    const file = name.substring(this._prefix.length);
+                if (name.startsWith('code')) {
+                    const file = name;
                     if (sources.has(file)) {
                         throw new pytorch.Error("Duplicate source file '" + file + "'.");
                     }
@@ -2560,136 +2548,50 @@ pytorch.Container.Zip = class {
         return this._execution;
     }
 
-    _entry(name) {
-        return this._entries.get(this._prefix + name);
+    // Reads the numeric serialization-format version stored in the given zip
+    // entry (e.g. 'version' or '.data/version') and maps it to a PyTorch
+    // release label such as 'v1.6'. Returns '' when the entry is absent.
+    version(name) {
+        const stream = this._entries.get(name);
+        if (stream) {
+            const decoder = new TextDecoder('utf-8');
+            const buffer = stream.peek();
+            const text = decoder.decode(buffer);
+            // Only the first line carries the version number.
+            const value = text.split('\n').shift();
+            // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
+            // kProducedFileFormatVersion
+            const versions = new Map([
+                [ '1', 'v1.3'  ],
+                [ '2', 'v1.5'  ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
+                [ '3', 'v1.6'  ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
+                [ '4', 'v1.6'  ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
+                [ '5', 'v1.7'  ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
+                [ '6', 'v1.9'  ], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
+                [ '7', 'v1.10' ]  // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
+            ]);
+            // Unknown (newer) versions are reported via the callback but are
+            // not fatal; fall through to a generic 'v-<n>' label.
+            if (!versions.has(value)) {
+                this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
+            }
+            return versions.get(value) || 'v-' + value.toString();
+        }
+        return '';
     }
 
-    _load() {
-        if (this._data === undefined) {
-            this._data = null;
-            const stream = this._entry('data.pkl');
-            if (stream) {
-                const buffer = stream.peek();
-                this._data = this._unpickle(buffer, this._storage('data'));
-            }
-            else if (this._model) {
-                this._producer = this._model.producerName + (this._model.producerVersion ? ' v' + this._model.producerVersion : '');
-                this._data = this._model.mainModule || {};
-                this._name = this._data.name || '';
-                if (this._data.torchscriptArena) {
-                    this._torchscriptArena = this._data.torchscriptArena.key;
-                }
-                const queue = [ this._data ];
-                const entries = new Map();
-                for (const entry of this._entries) {
-                    const name = entry[0];
-                    const stream = entry[1];
-                    const buffer = stream.peek();
-                    entries.set(name, buffer);
-                }
-                const tensorTypeMap = new Map([
-                    [ 'FLOAT', 'Float' ],
-                    [ 'FLOAT16', 'Half' ],
-                    [ 'DOUBLE', 'Double' ],
-                    [ 'INT8', 'Char' ],
-                    [ 'INT32', 'Int' ],
-                    [ 'INT64', 'Long' ]
-                ]);
-                const constants = this._model.tensors || [];
-                this._constants = constants.map((constant) => {
-                    const key = this._prefix + constant.data.key;
-                    if (!tensorTypeMap.has(constant.dataType)) {
-                        throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
-                    }
-                    const type = tensorTypeMap.get(constant.dataType);
-                    const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
-                    const storage_type = this.execution.type('torch.' + type + 'Storage');
-                    const size = (shape || []).reduce((a, b) => a * b, 1);
-                    const offset = parseInt(constant.offset, 10) || 0;
-                    const storage = new storage_type([ size ]);
-                    const itemsize = storage.dtype.itemsize();
-                    const buffer = entries.get(key);
-                    const length = size * itemsize;
-                    const data = buffer.slice(offset, offset + length);
-                    storage._set_cdata(data);
-                    const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
-                    tensor.name = constant.data.key;
-                    return tensor;
-                });
-                this._attributes = [];
-                const stream = this._entry('attributes.pkl');
-                if (stream) {
-                    const buffer = stream.peek();
-                    const unpickler = python.Unpickler.open(buffer);
-                    this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
-                }
-                while (queue.length > 0) {
-                    const module = queue.shift();
-                    if (!module.__class__) {
-                        module.__class__ = {
-                            __module__: 'torch.nn.modules.module',
-                            __name__: 'Module'
-                        };
-                    }
-                    if (module.name) {
-                        module.__id__ = module.name;
-                    }
-                    if (module.submodules) {
-                        for (const submodule of module.submodules) {
-                            module[submodule.name] = submodule;
-                            submodule.__parent__ = module;
-                            queue.push(submodule);
-                        }
-                        delete module.submodules;
-                    }
-                    const attributes = [];
-                    if (module.attributes) {
-                        attributes.push(...module.attributes);
-                        delete module.attributes;
-                    }
-                    const parameters = [];
-                    if (module.parameters) {
-                        parameters.push(...module.parameters);
-                        delete module.parameters;
-                    }
-                    if (module.arguments) {
-                        parameters.push(...module.arguments);
-                        delete module.arguments;
-                    }
-                    for (const parameter of parameters) {
-                        const tensor = this._constants[parameter.tensorId];
-                        module[parameter.name] = tensor;
-                        if (!parameter.__class__) {
-                            parameter.__class__ = {
-                                __module__: 'torch',
-                                __name__: 'Tensor'
-                            };
-                        }
-                    }
-                    for (const attribute of attributes) {
-                        module[attribute.name] = this._attributes[attribute.id];
-                    }
-                }
-                delete this._model;
-            }
-            if (this.format.startsWith('TorchScript ') && (this._torchscriptArena || this._data.forward)) {
-                this._type = 'script';
-                return;
-            }
-            const root = pytorch.Utility.findModule(this._data);
-            if (root) {
-                this._type = 'module';
-                this._data = root;
+    read() {
+        if (this.format.startsWith('TorchScript ') && (this._torchscriptArena || this._data.forward)) {
+            this._type = 'script';
+            return;
+        }
+        const root = pytorch.Utility.findModule(this._data);
+        if (root) {
+            this._type = 'module';
+            this._data = root;
+        }
+        else {
+            const weights = pytorch.Utility.findWeights(this._data);
+            if (weights) {
+                this._type = 'weights';
+                this._data = weights;
             }
             else {
-                const weights = pytorch.Utility.findWeights(this._data);
-                if (weights) {
-                    this._type = 'weights';
-                    this._data = weights;
-                }
-                else {
-                    throw new pytorch.Error('File does not contain root module or state dictionary.');
-                }
+                throw new pytorch.Error('File does not contain root module or state dictionary.');
             }
         }
     }
@@ -2742,7 +2644,7 @@ pytorch.Container.Zip = class {
 
     _storage(dirname) {
         const map = new Map();
-        const prefix = this._prefix + dirname + '/';
+        const prefix = dirname + '/';
         for (const entry of this._entries) {
             if (entry[0].startsWith(prefix)) {
                 const key = entry[0].substring(prefix.length);
@@ -2884,6 +2786,221 @@ pytorch.Container.Zip = class {
     }
 };
 
+// Legacy TorchScript zip container driven by a 'model.json' manifest
+// (TorchScript v1.0 / v1.1).
+pytorch.Container.Zip.Json = class extends pytorch.Container.Zip {
+
+    // entries: Map of zip path -> stream (prefix already stripped by open()).
+    // model: parsed content of 'model.json'; must have a 'mainModule'.
+    constructor(entries, model) {
+        super(entries);
+        // Producer label such as 'pytorch v1.0' when the manifest carries one.
+        this._producer = model && model.producerName ? model.producerName + (model.producerVersion ? ' v' + model.producerVersion : '') : '';
+        this._model = model;
+    }
+
+    get format() {
+        // 'attributes.pkl' only appears in the later v1.1 layout.
+        return this._entries.get('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
+    }
+
+    // Converts the JSON manifest into the in-memory module tree the base
+    // class expects, then delegates to Zip.read() to classify the result.
+    // Idempotent: work happens only on the first call (while !this._data).
+    read() {
+        if (!this._data) {
+            this._data = this._model.mainModule || {};
+            this._name = this._data.name || '';
+            if (this._data.torchscriptArena) {
+                this._torchscriptArena = this._data.torchscriptArena.key;
+            }
+            const queue = [ this._data ];
+            // Materialize every zip entry as a buffer so tensor data can be
+            // sliced by key below.
+            const entries = new Map();
+            for (const entry of this._entries) {
+                const name = entry[0];
+                const stream = entry[1];
+                const buffer = stream.peek();
+                entries.set(name, buffer);
+            }
+            // Manifest dtype name -> torch storage class name suffix.
+            const tensorTypeMap = new Map([
+                [ 'FLOAT', 'Float' ],
+                [ 'FLOAT16', 'Half' ],
+                [ 'DOUBLE', 'Double' ],
+                [ 'INT8', 'Char' ],
+                [ 'INT32', 'Int' ],
+                [ 'INT64', 'Long' ]
+            ]);
+            // Rebuild each manifest tensor from its raw data slice.
+            const constants = this._model.tensors || [];
+            this._constants = constants.map((constant) => {
+                const key = constant.data.key;
+                if (!tensorTypeMap.has(constant.dataType)) {
+                    throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
+                }
+                const type = tensorTypeMap.get(constant.dataType);
+                const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
+                const storage_type = this.execution.type('torch.' + type + 'Storage');
+                const size = (shape || []).reduce((a, b) => a * b, 1);
+                const offset = parseInt(constant.offset, 10) || 0;
+                const storage = new storage_type([ size ]);
+                const itemsize = storage.dtype.itemsize();
+                const buffer = entries.get(key);
+                const length = size * itemsize;
+                const data = buffer.slice(offset, offset + length);
+                storage._set_cdata(data);
+                const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
+                tensor.name = constant.data.key;
+                return tensor;
+            });
+            // TorchScript v1.1 stores attribute values in 'attributes.pkl',
+            // indexed by position (attribute.id below).
+            this._attributes = [];
+            const stream = this._entries.get('attributes.pkl');
+            if (stream) {
+                const buffer = stream.peek();
+                const unpickler = python.Unpickler.open(buffer);
+                this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
+            }
+            // Walk the module tree breadth-first: attach submodules by name,
+            // resolve parameter tensor ids and attribute ids onto each module.
+            while (queue.length > 0) {
+                const module = queue.shift();
+                if (!module.__class__) {
+                    module.__class__ = {
+                        __module__: 'torch.nn.modules.module',
+                        __name__: 'Module'
+                    };
+                }
+                if (module.name) {
+                    module.__id__ = module.name;
+                }
+                if (module.submodules) {
+                    for (const submodule of module.submodules) {
+                        module[submodule.name] = submodule;
+                        submodule.__parent__ = module;
+                        queue.push(submodule);
+                    }
+                    delete module.submodules;
+                }
+                const attributes = [];
+                if (module.attributes) {
+                    attributes.push(...module.attributes);
+                    delete module.attributes;
+                }
+                const parameters = [];
+                if (module.parameters) {
+                    parameters.push(...module.parameters);
+                    delete module.parameters;
+                }
+                if (module.arguments) {
+                    parameters.push(...module.arguments);
+                    delete module.arguments;
+                }
+                for (const parameter of parameters) {
+                    const tensor = this._constants[parameter.tensorId];
+                    module[parameter.name] = tensor;
+                    if (!parameter.__class__) {
+                        parameter.__class__ = {
+                            __module__: 'torch',
+                            __name__: 'Tensor'
+                        };
+                    }
+                }
+                for (const attribute of attributes) {
+                    module[attribute.name] = this._attributes[attribute.id];
+                }
+            }
+            // The manifest is fully consumed; drop it so read() is a no-op next time.
+            delete this._model;
+            super.read();
+        }
+    }
+};
+
+pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip {
+
+    constructor(entries) {
+        super(entries);
+    }
+
+    get format() {
+        return (this._entries.get('constants.pkl') ? 'TorchScript' : 'PyTorch') + ' ' + this.version('version');
+    }
+
+    read() {
+        if (!this._data) {
+            const stream = this._entries.get('data.pkl');
+            const buffer = stream.peek();
+            this._data = this._unpickle(buffer, this._storage('data'));
+            super.read();
+        }
+    }
+};
+
+pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
+
+    constructor(entries) {
+        super(entries);
+    }
+
+    get format() {
+        return 'PyTorch Package' + ' ' + this.version('.data/version');
+    }
+
+    read() {
+        const entries = Array.from(this._entries).filter((entry) => !entry[0].startsWith('.data/') && !entry[0].endsWith('py'));
+        for (const entry of entries) {
+            /* const name = */ entry[0];
+            const stream = entry[1];
+            const loaded_reduces = new Map();
+            // const loaded_storages = new Map();
+            const persistent_load = (saved_id) => {
+                const typename = saved_id.shift();
+                switch (typename) {
+                    case 'storage': {
+                        /*
+                        const storage_type = saved_id[0];
+                        const key = saved_id[1];
+                        const location = saved_id[2];
+                        const size = saved_id[3];
+                        dtype = storage_type.dtype
+                        if key not in loaded_storages:
+                            load_tensor(
+                                dtype,
+                                size,
+                                key,
+                                _maybe_decode_ascii(location),
+                                restore_location,
+                            )
+                        storage = loaded_storages[key]
+                        # TODO: Once we decide to break serialization FC, we can
+                        # stop wrapping with _TypedStorage
+                        return torch.storage._TypedStorage(
+                            wrap_storage=storage._untyped(), dtype=dtype
+                        )
+                        */
+                        throw new pytorch.Error('');
+                    }
+                    case 'reduce_package': {
+                        if (saved_id.left === 2) {
+                            const func = saved_id[0];
+                            const args = saved_id[1];
+                            return execution.invoke(func, args);
+                        }
+                        const reduce_id = saved_id[0];
+                        const func = saved_id[1];
+                        const args = saved_id[2];
+                        if (!loaded_reduces.has(reduce_id)) {
+                            const value = execution.invoke(func, args);
+                            loaded_reduces.set(reduce_id, value);
+                        }
+                        return loaded_reduces.get(reduce_id);
+                    }
+                    default: {
+                        throw new python.Error("Unknown package typename '" + typename + "'.");
+                    }
+                }
+            };
+            const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
+            execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) {
+                /* const data = */ '.data/ts_code' + script_module_id + 'data.pkl';
+                // const constants = '.data/ts_code' + script_module_id + 'constants.pkl.pkl';
+                return { __TODO__: 'unpackage_script_module' };
+            });
+            const unpickler = python.Unpickler.open(stream);
+            /* const obj = */ unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
+        }
+    }
+};
+
+// Empty placeholder; no members yet. Presumably reserved for script modules
+// unpacked from torch.package archives — TODO confirm intended use.
+pytorch.Container.Zip.Script = class {
+};
+
 pytorch.Container.Zip.Execution = class extends pytorch.Execution {
 
     constructor(sources, exceptionCallback, metadata) {