Procházet zdrojové kódy

UFF DataType support (#511)

Lutz Roeder před 5 roky
rodič
revize
e449a5be69
2 změnil soubory, kde provedl 71 přidání a 18 odebrání
  1. 50 1
      source/uff-metadata.json
  2. 21 17
      source/uff.js

+ 50 - 1
source/uff-metadata.json

@@ -107,7 +107,10 @@
   {
     "name": "Concat",
     "schema": {
-      "category": "Tensor"
+      "category": "Tensor",
+      "inputs": [
+        { "name": "input", "list": true }
+      ]
     }
   },
   {
@@ -125,5 +128,51 @@
         { "name": "indices" }
       ]
     }
+  },
+  {
+    "name": "Stack",
+    "schema": {
+      "inputs": [
+        { "name": "input" },
+        { "name": "?" },
+        { "name": "?" },
+        { "name": "?" }
+      ]
+    }
+  },
+  {
+    "name": "Shape",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ]
+    }
+  },
+  {
+    "name": "_FlattenConcat_TRT",
+    "schema": {
+      "category": "Tensor",
+      "inputs": [
+        { "name": "inputs", "list": true }
+      ]
+    }
+  },
+  {
+    "name": "_NMS_TRT",
+    "schema": {
+      "inputs": [
+        { "name": "input" },
+        { "name": "?" },
+        { "name": "?" }
+      ]
+    }
+  },
+  {
+    "name": "_GridAnchor_TRT",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ]
+    }
   }
 ]

+ 21 - 17
source/uff.js

@@ -127,7 +127,7 @@ uff.Graph = class {
                     fields[field.key] = field.value;
                 }
                 if (fields.dtype && fields.shape && fields.values) {
-                    const tensor = new uff.Tensor(fields.dtype, fields.shape, fields.values);
+                    const tensor = new uff.Tensor(fields.dtype.dtype, fields.shape, fields.values);
                     args.set(node.id, new uff.Argument(node.id, tensor.type, tensor));
                     graph.nodes.splice(i, 1);
                 }
@@ -137,7 +137,7 @@ uff.Graph = class {
                 for (const field of node.fields) {
                     fields[field.key] = field.value;
                 }
-                const type = fields.dtype && fields.shape ? new uff.TensorType(fields.dtype, fields.shape) : null;
+                const type = fields.dtype && fields.shape ? new uff.TensorType(fields.dtype.dtype, fields.shape) : null;
                 args.set(node.id, new uff.Argument(node.id, type, null));
             }
         }
@@ -231,8 +231,8 @@ uff.Node = class {
             let inputIndex = 0;
             if (schema && schema.inputs) {
                 for (const inputSchema of schema.inputs) {
-                    if (inputIndex < node.inputs.length || inputSchema.option != 'optional') {
-                        const inputCount = (inputSchema.option == 'variadic') ? (node.input.length - inputIndex) : 1;
+                    if (inputIndex < node.inputs.length || inputSchema.optional !== true) {
+                        const inputCount = inputSchema.list ? (node.inputs.length - inputIndex) : 1;
                         const inputArguments = node.inputs.slice(inputIndex, inputIndex + inputCount).map((id) => {
                             return args.get(id);
                         });
@@ -286,18 +286,21 @@ uff.Attribute = class {
     constructor(metadata, name, value) {
         this._name = name;
         switch(value.type) {
-            case 's':  this._value = value.s; this._type = 'string'; break;
-            case 's_list':  this._value = value.s_list; this._type = 'string[]'; break;
-            case 'd':  this._value = value.d; this._type = 'float64'; break;
-            case 'd_list':  this._value = value.d_list.val; this._type = 'float64[]'; break;
-            case 'i':  this._value = value.i; this._type = 'int64'; break;
-            case 'i_list':  this._value = value.i_list.val; this._type = 'int64[]'; break;
-            case 'b':  this._value = value.b; this._type = 'boolean'; break;
+            case 's': this._value = value.s; this._type = 'string'; break;
+            case 's_list': this._value = value.s_list; this._type = 'string[]'; break;
+            case 'd': this._value = value.d; this._type = 'float64'; break;
+            case 'd_list': this._value = value.d_list.val; this._type = 'float64[]'; break;
+            case 'b': this._value = value.b; this._type = 'boolean'; break;
             case 'b_list': this._value = value.b_list; this._type = 'boolean[]'; break;
+            case 'i': this._value = value.i; this._type = 'int64'; break;
+            case 'i_list': this._value = value.i_list.val; this._type = 'int64[]'; break;
             case 'blob': this._value = value.blob; break;
-            case 'dtype': this._value = new uff.TensorType(value, null).dataType; break;
+            case 'ref': this._value = value.ref; this._type = 'ref'; break;
+            case 'dtype': this._value = new uff.TensorType(value.dtype, null).dataType; this._type = 'uff.DataType'; break;
+            case 'dtype_list': this._value = value.dtype_list.map((type) => new uff.TensorType(type, null).dataType); this._type = 'uff.DataType[]'; break;
+            case 'dim_orders': this._value = value.dim_orders; break;
             case 'dim_orders_list': this._value = value.dim_orders_list.val; break;
-            default: throw new uff.Error("Unknown attribute '" + name + "'format '" + JSON.stringify(value) + "'.");
+            default: throw new uff.Error("Unknown attribute value '" + JSON.stringify(value) + "'.");
         }
     }
 
@@ -312,6 +315,10 @@ uff.Attribute = class {
     get value() {
         return this._value;
     }
+
+    get visible() {
+        return true;
+    }
 };
 
 uff.Tensor = class {
@@ -447,10 +454,7 @@ uff.Tensor = class {
 uff.TensorType = class {
 
     constructor(dataType, shape) {
-        if (dataType.type !== 'dtype') {
-            throw new uff.Error("Unknown data type format '" + JSON.stringify(dataType.type) + "'.");
-        }
-        switch (dataType.dtype) {
+        switch (dataType) {
             case uff.proto.DataType.DT_INT8: this._dataType = 'int8'; break;
             case uff.proto.DataType.DT_INT16: this._dataType = 'int16'; break;
             case uff.proto.DataType.DT_INT32: this._dataType = 'int32'; break;