瀏覽代碼

[BERT/PyT] fix onnx export (#689)

Sharath T S 5 年之前
父節點
當前提交
482fe9ac8a
共有 1 個文件被更改,包括 6 次插入2 次删除
  1. 6 2
      PyTorch/LanguageModeling/BERT/modeling.py

+ 6 - 2
PyTorch/LanguageModeling/BERT/modeling.py

@@ -287,7 +287,9 @@ class BertNonFusedLayerNorm(nn.Module):
 
     def forward(self, x):
         u = x.mean(-1, keepdim=True)
-        s = (x - u).pow(2).mean(-1, keepdim=True)
+        s = (x - u)
+        s = s * s
+        s = s.mean(-1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
         return self.weight * x + self.bias
 
@@ -323,7 +325,9 @@ class BertLayerNorm(Module):
             x = self.fused_layer_norm(x)
         else:
             u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
+            s = (x - u)
+            s = s * s
+            s = s.mean(-1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
             x = self.weight * x + self.bias
         return x