Lutz Roeder 4 лет назад
Родитель
Сommit
c53d226eac
1 измененных файлов с 112 добавлено и 1 удалено
  1. 112 1
      source/tf-proto.js

+ 112 - 1
source/tf-proto.js

@@ -850,6 +850,7 @@ $root.tensorflow.FunctionDefLibrary = class FunctionDefLibrary {
     constructor() {
         this["function"] = [];
         this.gradient = [];
+        this.registered_gradients = [];
     }
 
     static decode(reader, length) {
@@ -864,6 +865,9 @@ $root.tensorflow.FunctionDefLibrary = class FunctionDefLibrary {
                 case 2:
                     message.gradient.push($root.tensorflow.GradientDef.decode(reader, reader.uint32()));
                     break;
+                case 3:
+                    message.registered_gradients.push($root.tensorflow.RegisteredGradient.decode(reader, reader.uint32()));
+                    break;
                 default:
                     reader.skipType(tag & 7);
                     break;
@@ -884,6 +888,9 @@ $root.tensorflow.FunctionDefLibrary = class FunctionDefLibrary {
                 case "gradient":
                     message.gradient.push($root.tensorflow.GradientDef.decodeText(reader));
                     break;
+                case "registered_gradients":
+                    message.registered_gradients.push($root.tensorflow.RegisteredGradient.decodeText(reader));
+                    break;
                 default:
                     reader.field(tag, message);
                     break;
@@ -1067,6 +1074,55 @@ $root.tensorflow.GradientDef = class GradientDef {
 $root.tensorflow.GradientDef.prototype.function_name = "";
 $root.tensorflow.GradientDef.prototype.gradient_func = "";
 
+$root.tensorflow.RegisteredGradient = class RegisteredGradient {
+
+    constructor() {
+    }
+
+    static decode(reader, length) {
+        const message = new $root.tensorflow.RegisteredGradient();
+        const end = length !== undefined ? reader.position + length : reader.length;
+        while (reader.position < end) {
+            const tag = reader.uint32();
+            switch (tag >>> 3) {
+                case 1:
+                    message.gradient_func = reader.string();
+                    break;
+                case 2:
+                    message.registered_op_type = reader.string();
+                    break;
+                default:
+                    reader.skipType(tag & 7);
+                    break;
+            }
+        }
+        return message;
+    }
+
+    static decodeText(reader) {
+        const message = new $root.tensorflow.RegisteredGradient();
+        reader.start();
+        while (!reader.end()) {
+            const tag = reader.tag();
+            switch (tag) {
+                case "gradient_func":
+                    message.gradient_func = reader.string();
+                    break;
+                case "registered_op_type":
+                    message.registered_op_type = reader.string();
+                    break;
+                default:
+                    reader.field(tag, message);
+                    break;
+            }
+        }
+        return message;
+    }
+};
+
+$root.tensorflow.RegisteredGradient.prototype.gradient_func = "";
+$root.tensorflow.RegisteredGradient.prototype.registered_op_type = "";
+
 $root.tensorflow.AttrValue = class AttrValue {
 
     constructor() {
@@ -2510,7 +2566,7 @@ $root.tensorflow.SavedObject = class SavedObject {
     }
 
     get kind() {
-        $root.tensorflow.SavedObject.kindSet = $root.tensorflow.SavedObject.kindSet || new Set([ "user_object", "asset", "function", "variable", "bare_concrete_function", "constant", "resource"]);
+        $root.tensorflow.SavedObject.kindSet = $root.tensorflow.SavedObject.kindSet || new Set([ "user_object", "asset", "function", "variable", "bare_concrete_function", "constant", "resource", "captured_tensor"]);
         return Object.keys(this).find((key) => $root.tensorflow.SavedObject.kindSet.has(key) && this[key] != null);
     }
 
@@ -2547,6 +2603,9 @@ $root.tensorflow.SavedObject = class SavedObject {
                 case 10:
                     message.resource = $root.tensorflow.SavedResource.decode(reader, reader.uint32());
                     break;
+                case 12:
+                    message.captured_tensor = $root.tensorflow.CapturedTensor.decode(reader, reader.uint32());
+                    break;
                 case 11:
                     reader.entry(message.saveable_objects, () => reader.string(), () => $root.tensorflow.SaveableObject.decode(reader, reader.uint32()));
                     break;
@@ -2591,6 +2650,9 @@ $root.tensorflow.SavedObject = class SavedObject {
                 case "resource":
                     message.resource = $root.tensorflow.SavedResource.decodeText(reader);
                     break;
+                case "captured_tensor":
+                    message.captured_tensor = $root.tensorflow.CapturedTensor.decodeText(reader);
+                    break;
                 case "saveable_objects":
                     reader.entry(message.saveable_objects, () => reader.string(), () => $root.tensorflow.SaveableObject.decodeText(reader));
                     break;
@@ -2750,6 +2812,55 @@ $root.tensorflow.SavedFunction = class SavedFunction {
 
 $root.tensorflow.SavedFunction.prototype.function_spec = null;
 
+$root.tensorflow.CapturedTensor = class CapturedTensor {
+
+    constructor() {
+    }
+
+    static decode(reader, length) {
+        const message = new $root.tensorflow.CapturedTensor();
+        const end = length !== undefined ? reader.position + length : reader.length;
+        while (reader.position < end) {
+            const tag = reader.uint32();
+            switch (tag >>> 3) {
+                case 1:
+                    message.name = reader.string();
+                    break;
+                case 2:
+                    message.concrete_function = reader.string();
+                    break;
+                default:
+                    reader.skipType(tag & 7);
+                    break;
+            }
+        }
+        return message;
+    }
+
+    static decodeText(reader) {
+        const message = new $root.tensorflow.CapturedTensor();
+        reader.start();
+        while (!reader.end()) {
+            const tag = reader.tag();
+            switch (tag) {
+                case "name":
+                    message.name = reader.string();
+                    break;
+                case "concrete_function":
+                    message.concrete_function = reader.string();
+                    break;
+                default:
+                    reader.field(tag, message);
+                    break;
+            }
+        }
+        return message;
+    }
+};
+
+$root.tensorflow.CapturedTensor.prototype.name = "";
+$root.tensorflow.CapturedTensor.prototype.concrete_function = "";
+
 $root.tensorflow.SavedConcreteFunction = class SavedConcreteFunction {
 
     constructor() {