Răsfoiți Sursa

Fix TensorFlow s.filter issue

Lutz Roeder 6 ani în urmă
părinte
comite
1ec92207e3
1 a modificat fișierele cu 15 adăugiri și 9 ștergeri
  1. 15 9
      src/tf.js

+ 15 - 9
src/tf.js

@@ -234,8 +234,7 @@ tf.ModelFactory = class {
                         return tf.TensorBundle.open(buffer, identifier, context, host).then((bundle) => {
                             return tf.ModelFactory._openModel(identifier, host, metadata, saved_model, format, producer, bundle);
                         });
-                    }).catch((error) => {
-                        host.exception(error, false);
+                    }).catch(() => {
                         return tf.ModelFactory._openModel(identifier, host, metadata, saved_model, format, producer, null);
                     })
                 }
@@ -944,11 +943,11 @@ tf.Attribute = class {
             this._value = new tf.TensorShape(value.shape);
         }
         else if (Object.prototype.hasOwnProperty.call(value, 's')) {
-            if (typeof value.s === 'string'){
+            if (typeof value.s === 'string') {
                 this._value = value.s;
             }
-            else if (value.s.filter(c => c <= 32 && c >= 128).length == 0) {
-                this._value = tf.Metadata.textDecoder.decode(value.s);
+            else if (ArrayBuffer.isView(value.s)) {
+                this._value = (value.s.length === 0) ? '' : (value.s.filter(c => c <= 32 && c >= 128).length === 0) ? tf.Metadata.textDecoder.decode(value.s) : Array.from(value.s);
             }
             else {
                 this._value = value.s;
@@ -956,13 +955,17 @@ tf.Attribute = class {
         }
         else if (Object.prototype.hasOwnProperty.call(value, 'list')) {
             let list = value.list;
-            this._value = [];
             if (list.s && list.s.length > 0) {
                 this._value = list.s.map((s) => {
-                    if (s.filter(c => c <= 32 && c >= 128).length == 0) {
-                        return tf.Metadata.textDecoder.decode(value.s);
+                    if (typeof s === 'string') {
+                        return s;
+                    }
+                    else if (ArrayBuffer.isView(s)) {
+                        return (s.length === 0) ? '' : (s.filter(c => c <= 32 && c >= 128).length === 0) ? tf.Metadata.textDecoder.decode(s) : Array.from(s);
+                    }
+                    else {
+                        return s;
                     }
-                    return s.map(v => v.toString()).join(', ');
                 });
             }
             else if (list.i && list.i.length > 0) {
@@ -979,6 +982,9 @@ tf.Attribute = class {
                 this._type = 'shape[]';
                 this._value = list.shape.map((shape) => new tf.TensorShape(shape));
             }
+            else {
+                this._value = [];
+            }
         }
         else if (Object.prototype.hasOwnProperty.call(value, 'func')) {
             const func = value.func;