Lutz Roeder 4 лет назад
Родитель
Сommit
944d5aabc6
3 измененных файлов с 566 добавлено и 0 удалено
  1. 368 0
      source/flax.js
  2. 197 0
      source/msgpack.js
  3. 1 0
      source/view.js

+ 368 - 0
source/flax.js

@@ -0,0 +1,368 @@
+
+// Experimental
+
+var flax = flax || {};
+var python = python || require('./python');
+
+flax.ModelFactory = class {
+
+    match(context) {
+        const stream = context.stream;
+        if (stream.length > 4) {
+            const code = stream.peek(1)[0];
+            if (code === 0xDE || code === 0xDF || ((code & 0x80) === 0x80)) {
+                return 'msgpack.map';
+            }
+        }
+        return '';
+    }
+
+    open(context) {
+        return context.require('./msgpack').then((msgpack) => {
+            const stream = context.stream;
+            const buffer = stream.peek();
+            const execution = new python.Execution(null);
+            const reader = msgpack.BinaryReader.open(buffer, (code, data) => {
+                switch (code) {
+                    case 1: { // _MsgpackExtType.ndarray
+                        const reader = msgpack.BinaryReader.open(data);
+                        const tuple = reader.read();
+                        const dtype = execution.invoke('numpy.dtype', [ tuple[1] ]);
+                        dtype.byteorder = '<';
+                        return execution.invoke('numpy.ndarray', [ tuple[0], dtype, tuple[2] ]);
+                    }
+                    default:
+                        throw new flax.Error("Unknown MessagePack extension '" + code + "'.");
+                }
+            });
+            const obj = reader.read();
+            return new flax.Model(obj);
+        });
+    }
+};
+
+flax.Model = class {
+
+    constructor(obj) {
+        this._graphs = [ new flax.Graph(obj) ];
+    }
+
+    get format() {
+        return 'Flax';
+    }
+
+    get graphs() {
+        return this._graphs;
+    }
+};
+
+flax.Graph = class {
+
+    constructor(obj) {
+        const layers = new Map();
+        const flatten = (path, obj) => {
+            if (Object.entries(obj).every((entry) => entry[1].__class__ && entry[1].__class__.__module__ === 'numpy' && entry[1].__class__.__name__ === 'ndarray')) {
+                layers.set(path.join('.'), obj);
+            }
+            else {
+                for (const pair of Object.entries(obj)) {
+                    flatten(path.concat(pair[0]), pair[1]);
+                }
+            }
+        };
+        flatten([], obj);
+        this._nodes = Array.from(layers).map((entry) => new flax.Node(entry[0], entry[1]));
+    }
+
+    get inputs() {
+        return [];
+    }
+
+    get outputs() {
+        return [];
+    }
+
+    get nodes() {
+        return this._nodes;
+    }
+};
+
+flax.Parameter = class {
+
+    constructor(name, args) {
+        this._name = name;
+        this._arguments = args;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get visible() {
+        return true;
+    }
+
+    get arguments() {
+        return this._arguments;
+    }
+};
+
+flax.Argument = class {
+
+    constructor(name, initializer) {
+        if (typeof name !== 'string') {
+            throw new flax.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
+        }
+        this._name = name;
+        this._initializer = initializer || null;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get type() {
+        return this._initializer.type;
+    }
+
+    get initializer() {
+        return this._initializer;
+    }
+};
+
+flax.Node = class {
+
+    constructor(name, weights) {
+        this._name = name;
+        this._type = { name: 'Module' };
+        this._inputs = [];
+        for (const entry of Object.entries(weights)) {
+            const name = entry[0];
+            const tensor = new flax.Tensor(entry[1]);
+            const argument = new flax.Argument(this._name + '.' + name, tensor);
+            const parameter = new flax.Parameter(name, [ argument ]);
+            this._inputs.push(parameter);
+        }
+    }
+
+    get type() {
+        return this._type;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get inputs() {
+        return this._inputs;
+    }
+
+    get outputs() {
+        return [];
+    }
+
+    get attributes() {
+        return [];
+    }
+};
+
+flax.TensorType = class {
+
+    constructor(dataType, shape) {
+        this._dataType = dataType;
+        this._shape = shape;
+    }
+
+    get dataType() {
+        return this._dataType || '?';
+    }
+
+    get shape() {
+        return this._shape;
+    }
+
+    toString() {
+        return this.dataType + this._shape.toString();
+    }
+};
+
+flax.TensorShape = class {
+
+    constructor(dimensions) {
+        this._dimensions = dimensions;
+    }
+
+    get dimensions() {
+        return this._dimensions;
+    }
+
+    toString() {
+        if (!this._dimensions || this._dimensions.length == 0) {
+            return '';
+        }
+        return '[' + this._dimensions.join(',') + ']';
+    }
+};
+
+flax.Tensor = class {
+
+    constructor(array) {
+        this._type = new flax.TensorType(array.dtype.name, new flax.TensorShape(array.shape));
+        this._data = array.tobytes();
+        this._byteorder = array.dtype.byteorder;
+        this._itemsize = array.dtype.itemsize;
+    }
+
+    get type() {
+        return this._type;
+    }
+
+    get state() {
+        return this._context().state;
+    }
+
+    get value() {
+        const context = this._context();
+        if (context.state) {
+            return null;
+        }
+        context.limit = Number.MAX_SAFE_INTEGER;
+        return this._decode(context, 0);
+    }
+
+    toString() {
+        const context = this._context();
+        if (context.state) {
+            return '';
+        }
+        context.limit = 10000;
+        const value = this._decode(context, 0);
+        return flax.Tensor._stringify(value, '', '    ');
+    }
+
+    _context() {
+        const context = {};
+        context.index = 0;
+        context.count = 0;
+        context.state = null;
+        if (this._byteorder !== '<' && this._byteorder !== '>' && this._type.dataType !== 'uint8' && this._type.dataType !== 'int8') {
+            context.state = 'Tensor byte order is not supported.';
+            return context;
+        }
+        if (!this._data || this._data.length == 0) {
+            context.state = 'Tensor data is empty.';
+            return context;
+        }
+        context.itemSize = this._itemsize;
+        context.dimensions = this._type.shape.dimensions;
+        context.dataType = this._type.dataType;
+        context.littleEndian = this._byteorder == '<';
+        context.data = this._data;
+        context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
+        return context;
+    }
+
+    _decode(context, dimension) {
+        const littleEndian = context.littleEndian;
+        const shape = context.dimensions.length == 0 ? [ 1 ] : context.dimensions;
+        const results = [];
+        const size = shape[dimension];
+        if (dimension == shape.length - 1) {
+            for (let i = 0; i < size; i++) {
+                if (context.count > context.limit) {
+                    results.push('...');
+                    return results;
+                }
+                if (context.rawData) {
+                    switch (context.dataType) {
+                        case 'float16':
+                            results.push(context.rawData.getFloat16(context.index, littleEndian));
+                            break;
+                        case 'float32':
+                            results.push(context.rawData.getFloat32(context.index, littleEndian));
+                            break;
+                        case 'float64':
+                            results.push(context.rawData.getFloat64(context.index, littleEndian));
+                            break;
+                        case 'int8':
+                            results.push(context.rawData.getInt8(context.index, littleEndian));
+                            break;
+                        case 'int16':
+                            results.push(context.rawData.getInt16(context.index, littleEndian));
+                            break;
+                        case 'int32':
+                            results.push(context.rawData.getInt32(context.index, littleEndian));
+                            break;
+                        case 'int64':
+                            results.push(context.rawData.getInt64(context.index, littleEndian));
+                            break;
+                        case 'uint8':
+                            results.push(context.rawData.getUint8(context.index, littleEndian));
+                            break;
+                        case 'uint16':
+                            results.push(context.rawData.getUint16(context.index, littleEndian));
+                            break;
+                        case 'uint32':
+                            results.push(context.rawData.getUint32(context.index, littleEndian));
+                            break;
+                    }
+                    context.index += context.itemSize;
+                    context.count++;
+                }
+            }
+        }
+        else {
+            for (let j = 0; j < size; j++) {
+                if (context.count > context.limit) {
+                    results.push('...');
+                    return results;
+                }
+                results.push(this._decode(context, dimension + 1));
+            }
+        }
+        if (context.dimensions.length == 0) {
+            return results[0];
+        }
+        return results;
+    }
+
+    static _stringify(value, indentation, indent) {
+        if (Array.isArray(value)) {
+            const result = [];
+            result.push(indentation + '[');
+            const items = value.map((item) => flax.Tensor._stringify(item, indentation + indent, indent));
+            if (items.length > 0) {
+                result.push(items.join(',\n'));
+            }
+            result.push(indentation + ']');
+            return result.join('\n');
+        }
+        if (typeof value == 'string') {
+            return indentation + value;
+        }
+        if (value == Infinity) {
+            return indentation + 'Infinity';
+        }
+        if (value == -Infinity) {
+            return indentation + '-Infinity';
+        }
+        if (isNaN(value)) {
+            return indentation + 'NaN';
+        }
+        return indentation + value.toString();
+    }
+};
+
+flax.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading Flax model.';
+    }
+};
+
+if (typeof module !== 'undefined' && typeof module.exports === 'object') {
+    module.exports.ModelFactory = flax.ModelFactory;
+}
+
+

+ 197 - 0
source/msgpack.js

@@ -0,0 +1,197 @@
+
+var msgpack = msgpack || {};
+
+// https://github.com/msgpack/msgpack-javascript/blob/master/src/Decoder.ts
+
+msgpack.BinaryReader = class {
+
+    static open(data, callback) {
+        return new msgpack.BinaryReader(data, callback);
+    }
+
+    constructor(buffer, callback) {
+        this._buffer = buffer;
+        this._callback = callback;
+        this._position = 0;
+        this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+        this._decoder = new TextDecoder('utf8');
+    }
+
+    read() {
+        const value = this.value();
+        return value;
+    }
+
+    value() {
+        const c = this.byte();
+        if (c >= 0xe0) {
+            return c - 0x100;
+        }
+        else if (c < 0xC0) {
+            if (c < 0x80) {
+                return c;
+            }
+            else if (c < 0x90) {
+                return this.map(c - 0x80);
+            }
+            else if (c < 0xa0) {
+                return this.array(c - 0x90);
+            }
+            return this.string(c - 0xa0);
+        }
+        else {
+            switch (c) {
+                case 0xC0: return null;
+                case 0xC2: return false;
+                case 0xC3: return true;
+                case 0xC4: return this.bytes(this.byte());
+                case 0xC5: return this.bytes(this.uint16());
+                case 0xC6: return this.bytes(this.uint32());
+                case 0xC7: return this.extension(this.byte());
+                case 0xC8: return this.extension(this.uint16());
+                case 0xC9: return this.extension(this.uint32());
+                case 0xCA: return this.float32();
+                case 0xCB: return this.float64();
+                case 0xCC: return this.byte();
+                case 0xCD: return this.uint16();
+                case 0xCE: return this.uint32();
+                case 0xCF: return this.uint64();
+                case 0xD0: return this.int8();
+                case 0xD1: return this.int16();
+                case 0xD2: return this.int32();
+                case 0xD3: return this.int64();
+                case 0xD4: return this.extension(1);
+                case 0xD5: return this.extension(2);
+                case 0xD6: return this.extension(4);
+                case 0xD7: return this.extension(8);
+                case 0xD8: return this.extension(16);
+                case 0xD9: return this.string(this.byte());
+                case 0xDA: return this.string(this.uint16());
+                case 0xDB: return this.string(this.uint32());
+                case 0xDC: return this.array(this.uint16());
+                case 0xDD: return this.array(this.uint32());
+                case 0xDE: return this.map(this.uint16());
+                case 0xDF: return this.map(this.uint32());
+            }
+        }
+        throw new msgpack.Error("Invalid code '" + c + "'.");
+    }
+
+    map(size) {
+        const map = {};
+        for (let i = 0; i < size; i++) {
+            const key = this.value();
+            const value = this.value();
+            map[key] = value;
+        }
+        return map;
+    }
+
+    array(size) {
+        const array = new Array(size);
+        for (let i = 0; i < size; i++) {
+            array[i] = this.value();
+        }
+        return array;
+    }
+
+    extension(size) {
+        const code = this.byte();
+        const data = this.bytes(size);
+        return this._callback(code, data);
+    }
+
+    seek(position) {
+        this._position = position;
+    }
+
+    skip(offset) {
+        this._position += offset;
+        if (this._position > this._buffer.length) {
+            throw new msgpack.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
+    }
+
+    bytes(size) {
+        const data = this._buffer.subarray(this._position, this._position + size);
+        this._position += size;
+        return data;
+    }
+
+    byte() {
+        const position = this._position;
+        this.skip(1);
+        return this._buffer[position];
+    }
+
+    uint16() {
+        const position = this._position;
+        this.skip(2);
+        return this._view.getUint16(position);
+    }
+
+    uint32() {
+        const position = this._position;
+        this.skip(4);
+        return this._view.getUint32(position);
+    }
+
+    uint64() {
+        const position = this._position;
+        this.skip(8);
+        return this._view.getUint64(position);
+    }
+
+    int8() {
+        const position = this._position;
+        this.skip(1);
+        return this._view.getInt8(position);
+    }
+
+    int16() {
+        const position = this._position;
+        this.skip(2);
+        return this._view.getInt16(position);
+    }
+
+    int32() {
+        const position = this._position;
+        this.skip(4);
+        return this._view.getInt32(position);
+    }
+
+    int64() {
+        const position = this._position;
+        this.skip(8);
+        return this._view.getInt64(position);
+    }
+
+    float32() {
+        const position = this._position;
+        this.skip(4);
+        return this._view.getFloat32(position);
+    }
+
+    float64() {
+        const position = this._position;
+        this.skip(8);
+        return this._view.getFloat64(position);
+    }
+
+    string(size) {
+        const buffer = this.bytes(size);
+        return this._decoder.decode(buffer);
+    }
+};
+
+msgpack.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'MessagePack Error';
+    }
+};
+
+if (typeof module !== 'undefined' && typeof module.exports === 'object') {
+    module.exports.BinaryReader = msgpack.BinaryReader;
+}

+ 1 - 0
source/view.js

@@ -1561,6 +1561,7 @@ view.ModelFactoryService = class {
         this.register('./mlnet', [ '.zip' ]);
         this.register('./acuity', [ '.json' ]);
         this.register('./imgdnn', [ '.dnn', 'params', '.json' ]);
+        this.register('./flax', [ '.msgpack' ]);
         this.register('./om', [ '.om', '.onnx', '.pb', '.engine' ]);
     }