|
|
@@ -287,7 +287,9 @@ class BertNonFusedLayerNorm(nn.Module):
|
|
|
|
|
|
def forward(self, x):
|
|
|
u = x.mean(-1, keepdim=True)
|
|
|
- s = (x - u).pow(2).mean(-1, keepdim=True)
|
|
|
+ s = (x - u)
|
|
|
+ s = s * s
|
|
|
+ s = s.mean(-1, keepdim=True)
|
|
|
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
|
|
return self.weight * x + self.bias
|
|
|
|
|
|
@@ -323,7 +325,9 @@ class BertLayerNorm(Module):
|
|
|
x = self.fused_layer_norm(x)
|
|
|
else:
|
|
|
u = x.mean(-1, keepdim=True)
|
|
|
- s = (x - u).pow(2).mean(-1, keepdim=True)
|
|
|
+ s = (x - u)
|
|
|
+ s = s * s
|
|
|
+ s = s.mean(-1, keepdim=True)
|
|
|
x = (x - u) / torch.sqrt(s + self.eps)
|
|
|
x = self.weight * x + self.bias
|
|
|
return x
|