Lutz Roeder 5 rokov pred
rodič
commit
1c0d8bd6e2
3 zmenil súbory, kde vykonal 257 pridanie a 0 odobranie
  1. 1 0
      source/view.js
  2. 250 0
      source/weka.js
  3. 6 0
      test/models.json

+ 1 - 0
source/view.js

@@ -1245,6 +1245,7 @@ view.ModelFactoryService = class {
         this.register('./npz', [ '.npz', '.h5', '.hd5', '.hdf5' ]);
         this.register('./dl4j', [ '.zip' ]);
         this.register('./mlnet', [ '.zip' ]);
+        this.register('./weka', [ '.model' ]);
     }
 
     register(id, extensions) {

+ 250 - 0
source/weka.js

@@ -0,0 +1,250 @@
+/* jshint esversion: 6 */
+
+// Experimental
+
+var weka = weka || {};
+var json = json || require('./json');
+var java = {};
+
+weka.ModelFactory = class {
+
+    match(context) {
+        try {
+            const reader = new java.io.InputObjectStream(context.buffer);
+            const obj = reader.read();
+            if (obj && obj.$class && obj.$class.name) {
+                return true;
+            }
+        }
+        catch (err) {
+            // continue regardless of error
+        }
+        return false;
+    }
+
+    open(context, host) {
+        return Promise.resolve().then(() => {
+            const reader = new java.io.InputObjectStream(context.buffer);
+            const obj = reader.read();
+            throw new weka.Error("Unsupported type '" + obj.$class.name + "'.");
+        });
+    }
+};
+
+weka.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading Weka model.';
+    }
+};
+
+java.io = {};
+
+java.io.InputObjectStream = class {
+
+    constructor(buffer) {
+        // Object Serialization Stream Protocol
+        // https://www.cis.upenn.edu/~bcpierce/courses/629/jdkdocs/guide/serialization/spec/protocol.doc.html
+        this._reader = new java.io.InputObjectStream.BinaryReader(buffer);
+        this._references = [];
+        if (buffer.length < 5) {
+            throw new java.io.Error('Invalid stream size');
+        }
+        const signature = this._reader.uint16();
+        if (signature !== 0xACED) {
+            throw new java.io.Error('Invalid signature.');
+        }
+        const version = this._reader.uint16();
+        if (version !== 0x0005) {
+            throw new java.io.Error("Unsupported version '" + version + "'.");
+        }
+    }
+
+    read() {
+        return this._object();
+    }
+
+    _object() {
+        const code = this._reader.byte();
+        switch (code) {
+            case 0x73: { // TC_OBJECT
+                const obj = {};
+                obj.$class = this._classDesc();
+                this._newHandle(obj);
+                this._classData(obj);
+                return obj;
+            }
+            case 0x74: { // TC_STRING
+                return this._newString(false);
+            }
+        }
+        throw new java.io.Error("Unsupported code '" + code + "'.");
+    }
+
+    _classDesc() {
+        const code = this._reader.byte();
+        switch (code) {
+            case 0x72: // TC_CLASSDESC
+                this._reader.skip(-1);
+                return this._newClassDesc();
+            case 0x71: // TC_REFERENCE
+                return this._references[this._reader.uint32() - 0x7e0000];
+            case 0x70: // TC_NULL
+                this._reader.byte();
+                return null;
+        }
+        throw new java.io.Error("Unsupported code '" + code + "'.");
+    }
+
+    _newClassDesc() {
+        const code = this._reader.byte();
+        switch (code) {
+            case 0x72: { // TC_CLASSDESC
+                const classDesc = {};
+                classDesc.name = this._reader.string(),
+                classDesc.id = this._reader.uint64().toString();
+                this._newHandle(classDesc);
+                classDesc.flags = this._reader.byte();
+                classDesc.fields = [];
+                const count = this._reader.uint16();
+                for (let i = 0; i < count; i++) {
+                    const field = {};
+                    field.type = String.fromCharCode(this._reader.byte());
+                    field.name = this._reader.string();
+                    if (field.type === '[' || field.type === 'L') {
+                        field.classname = this._object();
+                    }
+                    classDesc.fields.push(field);
+                }
+                if (this._reader.byte() !== 0x78) {
+                    throw new java.io.Error('Expected TC_ENDBLOCKDATA.');
+                }
+                classDesc.superClass = this._classDesc();
+                return classDesc;
+            }
+            case 0x7D: // TC_PROXYCLASSDESC
+                break;
+        }
+        throw new java.io.Error("Unsupported code '" + code + "'.");
+    }
+
+    _classData(obj) {
+        /*
+        const classname = obj.$class.name;
+        let flags = obj.$class.flags;
+        let superClass = obj.$class.superClass;
+        while (superClass) {
+            flags |= superClass.flags;
+            superClass = superClass.superClass;
+        }
+        if (flags & 0x02) { // SC_SERIALIZABLE
+            debugger;
+            var customObject = objects[classname];
+            var hasReadObjectMethod = customObject && customObject.readObject;
+            if (flags & 0x01) { // SC_WRITE_METHOD
+                if (!hasReadObjectMethod) {
+                    throw new Error('Class "'+ classname + '" dose not implement readObject()');
+                }
+                customObject.readObject(this, obj);
+                if (this._reader.byte() !== 0x78) { // TC_ENDBLOCKDATA
+                    throw new java.io.Error('Expected TC_ENDBLOCKDATA.');
+                }
+            }
+            else {
+                if (hasReadObjectMethod) {
+                    customObject.readObject(this, obj);
+                    if (this._reader.byte() !== 0x78) { // TC_ENDBLOCKDATA
+                        throw new java.io.Error('Expected TC_ENDBLOCKDATA.');
+                    }
+                }
+                else {
+                    this._nowrclass(obj);
+                }
+            }
+        }
+        else if (flags & 0x04) { // SC_EXTERNALIZABLE
+            if (flags & 0x08) { // SC_BLOCK_DATA
+                this._objectAnnotation(obj);
+            }
+            else {
+                this._externalContents();
+            }
+        }
+        else {
+            throw new Error('Illegal flags: ' + flags);
+        }
+        */
+    }
+
+    _newString(long) {
+        const value = this._reader.string(long);
+        this._newHandle(value);
+        return value;
+    }
+
+    _newHandle(obj) {
+        this._references.push(obj);
+    }
+};
+
+java.io.InputObjectStream.BinaryReader = class {
+
+    constructor(buffer) {
+        this._buffer = buffer;
+        this._position = 0;
+        this._length = buffer.length;
+        this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+        this._decoder = new TextDecoder('utf-8');
+    }
+
+    skip(offset) {
+        this._position += offset;
+        if (this._position > this._end) {
+            throw new java.io.Error('Expected ' + (this._position - this._end) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
+    }
+
+    byte() {
+        const position = this._position;
+        this.skip(1);
+        return this._buffer[position];
+    }
+
+    uint16() {
+        const position = this._position;
+        this.skip(2);
+        return this._view.getUint16(position, false);
+    }
+
+    uint32() {
+        const position = this._position;
+        this.skip(4);
+        return this._view.getUint32(position, false);
+    }
+
+    uint64() {
+        const position = this._position;
+        this.skip(8);
+        return this._view.getUint64(position, false);
+    }
+
+    string(long) {
+        const size = long ? this.uint64().toNumber() : this.uint16();
+        const position = this._position;
+        this.skip(size);
+        return this._decoder.decode(this._buffer.subarray(position, this._position));
+    }
+};
+
+java.io.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading Object Serialization Stream Protocol.';
+    }
+};
+
+if (typeof module !== 'undefined' && typeof module.exports === 'object') {
+    module.exports.ModelFactory = weka.ModelFactory;
+}

+ 6 - 0
test/models.json

@@ -5680,5 +5680,11 @@
     "target": "sample_unpruned_mobilenet_v2.tar.gz",
     "source": "https://nvidia.box.com/shared/static/8oqvmd79llr6lq1fr43s4fu1ph37v8nt.gz",
     "format": "UFF v1"
+  },
+  {
+    "type":   "weka",
+    "target": "j48model.model",
+    "source": "https://raw.githubusercontent.com/PTaati/wekaTree2python/master/j48model.model",
+    "error":  "Unsupported type 'weka.classifiers.trees.J48' in 'j48model.model'."
   }
 ]