|
|
@@ -89,7 +89,7 @@ pytorch.Graph = class {
|
|
|
if (pytorch.Utility.isTensor(obj)) {
|
|
|
let parameter = obj;
|
|
|
parameter.__parent__ = module;
|
|
|
- if (!parameter.initializer) {
|
|
|
+ if (!parameter.initializer && parameter.storage) {
|
|
|
parameter.initializer = new pytorch.Tensor(parameter.name, parameter, true);
|
|
|
}
|
|
|
if (parameter.__outputs__ && parameter.__outputs__.length == 1) {
|
|
|
@@ -1463,17 +1463,17 @@ pytorch.Execution = class {
|
|
|
this._registerFunction('ops.prim.NumToTensor', function(value) {
|
|
|
return { __module__: 'torch', __name__: 'Tensor', value: value }; // TODO
|
|
|
});
|
|
|
- this._registerFunction('ops.prim.shape', function(/* value */) {
|
|
|
- return undefined; // TODO
|
|
|
+ this._registerFunction('ops.prim.shape', function(value) {
|
|
|
+ return value.size;
|
|
|
});
|
|
|
this._registerFunction('ops.quantized.conv_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
|
|
|
- return { __module__: 'torch', __name__: '__conv_prepack__' }; // TODO
|
|
|
+ return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv_prepack' }; // TODO
|
|
|
});
|
|
|
this._registerFunction('ops.quantized.conv2d_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
|
|
|
- return { __module__: 'torch', __name__: '__conv2d_prepack__' }; // TODO
|
|
|
+ return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv2d_prepack' }; // TODO
|
|
|
});
|
|
|
this._registerFunction('ops.quantized.linear_prepack', function(/* weight, bias */) {
|
|
|
- return { __module__: 'torch', __name__: '__linear_prepack__' }; // TODO
|
|
|
+ return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.linear_prepack' }; // TODO
|
|
|
});
|
|
|
|
|
|
this._registerFunction('ops.prim.RaiseException', function(message) {
|
|
|
@@ -1546,7 +1546,10 @@ pytorch.Execution = class {
|
|
|
this._registerFunction('torch.jit._pickle.build_tensorlist', function(data) {
|
|
|
return data;
|
|
|
});
|
|
|
- this._registerFunction('torch.len', function(/* value */) {
|
|
|
+ this._registerFunction('torch.len', function(value) {
|
|
|
+ if (value) {
|
|
|
+ return value.length;
|
|
|
+ }
|
|
|
return undefined;
|
|
|
});
|
|
|
this._registerFunction('torch.list_with_default', function(size /*, defaults */) {
|
|
|
@@ -1556,7 +1559,7 @@ pytorch.Execution = class {
|
|
|
if (typeof left === 'number' && typeof right === 'number') {
|
|
|
return left < right;
|
|
|
}
|
|
|
- throw new pytorch.Error('Unknown expression type.');
|
|
|
+ throw new pytorch.Error('Unknown torch.lt expression type.');
|
|
|
});
|
|
|
this._registerFunction('torch.mul', function(left, right) {
|
|
|
if (typeof left === 'number' && typeof right === 'number') {
|
|
|
@@ -1565,13 +1568,16 @@ pytorch.Execution = class {
|
|
|
if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) {
|
|
|
return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' };
|
|
|
}
|
|
|
- throw new pytorch.Error('Unknown expression type.');
|
|
|
+ throw new pytorch.Error('Unknown torch.mul expression type.');
|
|
|
});
|
|
|
this._registerFunction('torch.ne', function(left, right) {
|
|
|
if (typeof left === 'number' && typeof right === 'number') {
|
|
|
return left !== right;
|
|
|
}
|
|
|
- throw new pytorch.Error('Unknown expression type.');
|
|
|
+ if (left === undefined && typeof right === 'number') {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ throw new pytorch.Error('Unknown torch.ne expression type.');
|
|
|
});
|
|
|
this._registerFunction('torch.q_scale', function(/* tensor */) {
|
|
|
return -1; // TODO
|
|
|
@@ -2829,6 +2835,7 @@ pytorch.Container.Zip = class {
|
|
|
case 'torch.cat':
|
|
|
case 'torch.conv2d':
|
|
|
case 'torch.flatten':
|
|
|
+ case 'torch.quantize_per_tensor':
|
|
|
case 'torch.relu_':
|
|
|
case 'torch.dropout': {
|
|
|
parameter.size = [ undefined, undefined, undefined, undefined ];
|