|
@@ -1644,12 +1644,22 @@ python.Execution = class {
|
|
|
const numpy = this.register('numpy');
|
|
const numpy = this.register('numpy');
|
|
|
this.register('pickle');
|
|
this.register('pickle');
|
|
|
this.register('sklearn');
|
|
this.register('sklearn');
|
|
|
|
|
+ const torch = this.register('torch');
|
|
|
|
|
+ const torch_storage = this.register('torch.storage');
|
|
|
|
|
+ const torch_nn_parameter = this.register('torch.nn.parameter');
|
|
|
|
|
+ this.register('torch.ops');
|
|
|
|
|
+ this.register('torch.ops.torchvision');
|
|
|
|
|
+ this.register('torch.ops.torchaudio');
|
|
|
|
|
+ this.register('torch.ops._caffe2');
|
|
|
|
|
+ this.register('torchvision');
|
|
|
|
|
+ this.register('__torch__');
|
|
|
this.register('sys').modules = this._modules;
|
|
this.register('sys').modules = this._modules;
|
|
|
this.register('xgboost');
|
|
this.register('xgboost');
|
|
|
this.registerType('builtins.function', class {});
|
|
this.registerType('builtins.function', class {});
|
|
|
this.registerType('builtins.method', class {});
|
|
this.registerType('builtins.method', class {});
|
|
|
this.registerType('builtins.dict', dict);
|
|
this.registerType('builtins.dict', dict);
|
|
|
this.registerType('builtins.list', class {});
|
|
this.registerType('builtins.list', class {});
|
|
|
|
|
+ this.registerType('builtins.number', class {});
|
|
|
this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
|
|
this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
|
|
|
return execution.__import__(name, globals, locals, fromlist, level);
|
|
return execution.__import__(name, globals, locals, fromlist, level);
|
|
|
});
|
|
});
|
|
@@ -2148,7 +2158,7 @@ python.Execution = class {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
else if (this.data.length != size) {
|
|
else if (this.data.length != size) {
|
|
|
- // throw new pytorch.Error('Invalid array data size.');
|
|
|
|
|
|
|
+ // throw new python.Error('Invalid array data size.');
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -3465,6 +3475,1182 @@ python.Execution = class {
|
|
|
this.registerFunction('theano.tensor.type.values_eq_approx_remove_nan', function() {
|
|
this.registerFunction('theano.tensor.type.values_eq_approx_remove_nan', function() {
|
|
|
throw new python.Error('Function not implemented.');
|
|
throw new python.Error('Function not implemented.');
|
|
|
});
|
|
});
|
|
|
|
|
+ this.registerType('torch.ao.quantization.observer._PartialWrapper', class {});
|
|
|
|
|
+ this.registerType('torch.ao.quantization.qconfig.QConfig', class {});
|
|
|
|
|
+ this.registerType('torch.ao.quantization.stubs.DeQuantStub', class {});
|
|
|
|
|
+ this.registerType('torch.ao.quantization.stubs.QuantStub', class {});
|
|
|
|
|
+ this.registerType('torch.autograd.variable.Variable', class {});
|
|
|
|
|
+ this.registerType('torch.backends.cudnn.rnn.Unserializable', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.bernoulli.Bernoulli', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.constraints._LowerCholesky', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.constraints._Real', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.multivariate_normal.MultivariateNormal', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.normal.Normal', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.transforms.LowerCholeskyTransform', class {});
|
|
|
|
|
+ this.registerType('torch.distributions.uniform.Uniform', class {});
|
|
|
|
|
+ this.registerType('torch.nn.backends.thnn._get_thnn_function_backend', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.modules.fused.ConvBnReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.modules.fused.ConvReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.modules.fused.BNReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.qat.modules.conv_fused.ConvReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.intrinsic.quantized.modules.linear_relu.LinearReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.CELU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.ELU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.GELU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.GLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Hardtanh', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Hardswish', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Hardsigmoid', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.LeakyReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.LogSigmoid', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.LogSoftmax', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Mish', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.MultiheadAttention', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.ReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.ReLU6', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.PReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.RReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.SELU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Sigmoid', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.SiLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Softmax', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Softmax2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Softplus', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Tanh', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Tanhshrink', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.activation.Threshold', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.batchnorm.BatchNorm1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.batchnorm.BatchNorm2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.batchnorm.BatchNorm3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.batchnorm.LazyBatchNorm1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.batchnorm.SyncBatchNorm', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.container.ModuleDict', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.container.ModuleList', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.container.ParameterDict', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.container.ParameterList', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.container.Sequential', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.Conv1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.Conv2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.Conv3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.ConvTranspose1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.ConvTranspose2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.conv.ConvTranspose3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.distance.CosineSimilarity', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.dropout.AlphaDropout', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.dropout.Dropout', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.dropout.Dropout2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.dropout.Dropout3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.fold.Fold', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.fold.Unfold', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.flatten.Flatten', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.flatten.Unflatten', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.instancenorm.InstanceNorm1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.instancenorm.InstanceNorm2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.instancenorm.InstanceNorm3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear._LinearWithBias', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear.Bilinear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear.Identity', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear.LazyLinear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear.Linear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.linear.NonDynamicallyQuantizableLinear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.BCELoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.BCEWithLogitsLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.CrossEntropyLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.CTCLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.KLDivLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.L1Loss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.MarginRankingLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.MSELoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.NLLLoss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.NLLLoss2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.loss.SmoothL1Loss', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.module._IncompatibleKeys', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.module.Module', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.module.PatchForward', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.normalization.CrossMapLRN2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.normalization.GroupNorm', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.normalization.LayerNorm', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.normalization.LocalResponseNorm', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ReflectionPad1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ReflectionPad2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ReplicationPad1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ReplicationPad2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ReplicationPad3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ZeroPad2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ConstantPad1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ConstantPad2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.padding.ConstantPad3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pixelshuffle.PixelShuffle', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pixelshuffle.PixelUnshuffle', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AvgPool1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AvgPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.AvgPool3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.FractionalMaxPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.LPPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxPool1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxPool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxPool3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.GRU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.GRUCell', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.LSTM', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.LSTMCell', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.RNN', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.rnn.RNNCell', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.sparse.Embedding', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.sparse.EmbeddingBag', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.transformer.Transformer', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.transformer.TransformerDecoder', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.transformer.TransformerDecoderLayer', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.transformer.TransformerEncoder', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.transformer.TransformerEncoderLayer', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.upsampling.Upsample', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.upsampling.UpsamplingBilinear2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.modules.upsampling.UpsamplingNearest2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.parallel.data_parallel.DataParallel', class {});
|
|
|
|
|
+ this.registerType('torch.nn.parallel.distributed._DDPUnevenInputsConfig', class {});
|
|
|
|
|
+ this.registerType('torch.nn.parallel.distributed.DistributedDataParallel', class {});
|
|
|
|
|
+ this.registerType('torch.nn.qat.modules.conv.Conv2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.qat.modules.linear.Linear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.activation.ReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.activation.LeakyReLU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.dynamic.modules.linear.Linear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.dynamic.modules.rnn.GRU', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.dynamic.modules.rnn.LSTM', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.dynamic.modules.rnn.LSTMCell', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.dynamic.modules.rnn.PackedParameter', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.activation.ReLU6', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.batchnorm.BatchNorm2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.conv.Conv1d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.conv.Conv2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.conv.ConvTranspose2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.DeQuantize', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.dropout.Dropout', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.functional_modules.FloatFunctional', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.functional_modules.QFunctional', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.linear.Linear', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.linear.LinearPackedParams', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.normalization.InstanceNorm2d', class {});
|
|
|
|
|
+ this.registerType('torch.nn.quantized.modules.Quantize', class {});
|
|
|
|
|
+ this.registerType('torch.nn.utils.prune.L1Unstructured', class {});
|
|
|
|
|
+ this.registerType('torch.nn.utils.spectral_norm.SpectralNorm', class {});
|
|
|
|
|
+ this.registerType('torch.nn.utils.spectral_norm.SpectralNormStateDictHook', class {});
|
|
|
|
|
+ this.registerType('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', class {});
|
|
|
|
|
+ this.registerType('torch.nn.utils.weight_norm.WeightNorm', class {});
|
|
|
|
|
+ this.registerType('torch.torch_version.TorchVersion', class extends String {});
|
|
|
|
|
+ this.registerType('torch.optim.adam.Adam', class {});
|
|
|
|
|
+ this.register('torch.optim').Adam = this._registry.get('torch.optim.adam').Adam;
|
|
|
|
|
+ this.registerType('torch.optim.adamw.AdamW', class {});
|
|
|
|
|
+ this.registerType('torch.optim.adagrad.Adagrad', class {});
|
|
|
|
|
+ this.registerType('torch.optim.adadelta.Adadelta', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.CosineAnnealingLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.CyclicLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.ExponentialLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.LambdaLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.MultiStepLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.OneCycleLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.ReduceLROnPlateau', class {});
|
|
|
|
|
+ this.registerType('torch.optim.lr_scheduler.StepLR', class {});
|
|
|
|
|
+ this.registerType('torch.optim.optimizer._RequiredParameter', class {});
|
|
|
|
|
+ this.registerType('torch.optim.rmsprop.RMSprop', class {});
|
|
|
|
|
+ this.registerType('torch.optim.sgd.SGD', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.fake_quantize.FakeQuantize', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.observer._PartialWrapper', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.observer.MinMaxObserver', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.observer.MovingAverageMinMaxObserver', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.observer.MovingAveragePerChannelMinMaxObserver', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.qconfig.QConfig', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.stubs.DeQuantStub', class {});
|
|
|
|
|
+ this.registerType('torch.quantization.stubs.QuantStub', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.dataloader._MultiProcessingDataLoaderIter', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.dataloader.DataLoader', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.dataset.Subset', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.dataset.ConcatDataset', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.dataset.TensorDataset', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.sampler.BatchSampler', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.sampler.RandomSampler', class {});
|
|
|
|
|
+ this.registerType('torch.utils.data.sampler.SequentialSampler', class {});
|
|
|
|
|
+ this.registerType('torchvision.datasets.folder.ImageFolder', class {});
|
|
|
|
|
+ this.registerType('torchvision.datasets.mnist.MNIST', class {});
|
|
|
|
|
+ this.registerType('torchvision.datasets.vision.StandardTransform', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.alexnet.AlexNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.densenet.DenseNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.densenet._DenseBlock', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.densenet._DenseLayer', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.densenet._Transition', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection._utils.BoxCoder', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection._utils.Matcher', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection._utils.SSDMatcher', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.anchor_utils.AnchorGenerator', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.anchor_utils.DefaultBoxGenerator', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.backbone_utils.BackboneWithFPN', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.faster_rcnn.FasterRCNN', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.faster_rcnn.FastRCNNPredictor', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.faster_rcnn.TwoMLPHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNN', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNNHeads', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNNPredictor', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNN', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNNHeads', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNNPredictor', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.retinanet.RetinaNetClassificationHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.retinanet.RetinaNetHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.retinanet.RetinaNetRegressionHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.roi_heads.RoIHeads', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.rpn.AnchorGenerator', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.rpn.RegionProposalNetwork', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.rpn.RPNHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.ssd.SSD', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.ssdlite.SSDLiteClassificationHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.ssdlite.SSDLiteFeatureExtractorMobileNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.ssdlite.SSDLiteHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.ssdlite.SSDLiteRegressionHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.detection.transform.GeneralizedRCNNTransform', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.efficientnet.EfficientNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.efficientnet.MBConv', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.googlenet.BasicConv2d', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.googlenet.GoogLeNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.googlenet.Inception', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.googlenet.InceptionAux', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.BasicConv2d', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.Inception3', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionAux', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionA', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionB', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionC', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionD', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.inception.InceptionE', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mnasnet._InvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mnasnet.MNASNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenet.ConvBNReLU', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenet.MobileNetV2', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenet.InvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv2.ConvBNActivation', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv2.InvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv2.MobileNetV2', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv3.InvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv3.MobileNetV3', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.mobilenetv3.SqueezeExcitation', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.resnet.Bottleneck', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.resnet.BasicBlock', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.mobilenet.QuantizableInvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.mobilenet.QuantizableMobileNetV2', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.mobilenetv2.QuantizableInvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.mobilenetv2.QuantizableMobileNetV2', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.resnet.QuantizableBasicBlock', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.resnet.QuantizableBottleneck', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.quantization.resnet.QuantizableResNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.deeplabv3.ASPP', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.deeplabv3.ASPPConv', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.deeplabv3.ASPPPooling', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.deeplabv3.DeepLabHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.deeplabv3.DeepLabV3', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.fcn.FCN', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.segmentation.fcn.FCNHead', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.shufflenetv2.ShuffleNetV2', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.shufflenetv2.InvertedResidual', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.squeezenet.Fire', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.squeezenet.SqueezeNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.resnet.ResNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.vgg.VGG', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.BasicBlock', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.BasicStem', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.Conv2Plus1D', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.Conv3DNoTemporal', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.Conv3DSimple', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.R2Plus1dStem', class {});
|
|
|
|
|
+ this.registerType('torchvision.models.video.resnet.VideoResNet', class {});
|
|
|
|
|
+ this.registerType('torchvision.models._utils.IntermediateLayerGetter', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.deform_conv.DeformConv2d', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.feature_pyramid_network.LastLevelP6P7', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.misc.ConvNormActivation', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.misc.ConvTranspose2d', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.misc.FrozenBatchNorm2d', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.misc.SqueezeExcitation', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.poolers.LevelMapper', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.poolers.MultiScaleRoIAlign', class {});
|
|
|
|
|
+ this.registerType('torchvision.ops.stochastic_depth.StochasticDepth', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.functional.InterpolationMode', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.Compose', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.CenterCrop', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.Grayscale', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.Normalize', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.RandomAffine', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.RandomCrop', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.RandomHorizontalFlip', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.Resize', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.Scale', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.ToPILImage', class {});
|
|
|
|
|
+ this.registerType('torchvision.transforms.transforms.ToTensor', class {});
|
|
|
|
|
+ this.registerFunction('torchvision.models.resnet.resnet34', function() {});
|
|
|
|
|
+ this.registerFunction('builtins.annotate', function(type, value) {
|
|
|
|
|
+ if (type === self._builtins.int) {
|
|
|
|
|
+ return Number.isInteger(value) ? value : NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (type === self._builtins.float) {
|
|
|
|
|
+ return typeof value === 'number' ? value : NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (type === self._builtins.number) {
|
|
|
|
|
+ // if (pytorch.Utility.isTensor(value)) {
|
|
|
|
|
+ // value.resize_([]);
|
|
|
|
|
+ // }
|
|
|
|
|
+ }
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('builtins.unchecked_cast', function(type, value) {
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('builtins.uninitialized', function(/* type */) {
|
|
|
|
|
+ return undefined;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.data', function(tensor) {
|
|
|
|
|
+ return tensor;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.device', function(tensor) {
|
|
|
|
|
+ return tensor.device;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.dtype', function(tensor) {
|
|
|
|
|
+ return tensor.dtype.scalar_type();
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.is_quantized', function(tensor) {
|
|
|
|
|
+ return tensor && tensor.__quantized__ === true;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.NumToTensor', function(value) {
|
|
|
|
|
+ const tensor = self.invoke('torch.Tensor', []);
|
|
|
|
|
+ tensor.value = value; // TODO
|
|
|
|
|
+ return tensor;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.min', function(value) {
|
|
|
|
|
+ if (Array.isArray(value)) {
|
|
|
|
|
+ return Math.min.apply(null, value);
|
|
|
|
|
+ }
|
|
|
|
|
+ return Math.min.apply(null, arguments);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.max', function(value) {
|
|
|
|
|
+ if (Array.isArray(value)) {
|
|
|
|
|
+ return Math.max.apply(null, value);
|
|
|
|
|
+ }
|
|
|
|
|
+ return Math.max.apply(null, arguments);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('ops.prim.shape', function(tensor) {
|
|
|
|
|
+ return tensor && tensor.size ? tensor.size() : undefined;
|
|
|
|
|
+ });
|
|
|
|
|
// Quantized prepack stubs: rather than packing weights into an opaque native
// handle (as PyTorch does), build a params holder object that keeps every
// argument accessible for later inspection by the model reader.
const packParams = (type, attributes) => {
    const params = self.invoke(type, []);
    Object.assign(params, attributes);
    return params;
};
this.registerFunction('ops.quantized.conv_prepack', function(weight, bias, stride, padding, dilation, groups) {
    return packParams('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', { weight, bias, stride, padding, dilation, groups });
});
this.registerFunction('ops.quantized.conv1d_prepack', function(weight, bias, stride, padding, dilation, groups) {
    // conv1d reuses the Conv2d packed-params holder.
    return packParams('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', { weight, bias, stride, padding, dilation, groups });
});
this.registerFunction('ops.quantized.conv2d_prepack', function(weight, bias, stride, padding, dilation, groups) {
    return packParams('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', { weight, bias, stride, padding, dilation, groups });
});
this.registerFunction('ops.quantized.conv3d_prepack', function(weight, bias, stride, padding, dilation, groups) {
    return packParams('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', { weight, bias, stride, padding, dilation, groups });
});
this.registerFunction('ops.quantized.conv_transpose2d_prepack', function(weight, bias, stride, padding, output_padding, dilation, groups) {
    return packParams('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', { weight, bias, stride, padding, output_padding, dilation, groups });
});
this.registerFunction('ops.quantized.linear_prepack', function(weight, bias) {
    return packParams('__torch__.torch.classes.quantized.LinearPackedParamsBase', { weight, bias });
});
|
|
|
|
|
+ this.registerFunction('ops.prim.RaiseException', function(message) {
|
|
|
|
|
+ throw new python.Error(message);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('builtins.range', function(start, stop, step) {
|
|
|
|
|
+ if (stop === undefined && step === undefined) {
|
|
|
|
|
+ if (Number.isInteger(start)) {
|
|
|
|
|
+ return Array(start).keys();
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(start)) {
|
|
|
|
|
+ return [];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
|
|
|
|
|
+ });
|
|
|
|
|
// torch._utils._rebuild_tensor: pickle reconstruction hook used by torch.save/load.
// Derives the tensor class name from the storage class name
// (e.g. torch.FloatStorage -> torch.FloatTensor) and restores its geometry.
this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
    const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
    const tensor = self.invoke(name, []);
    // state layout: [ storage, storage_offset, shape, stride ]
    tensor.__setstate__([ storage, storage_offset, size, stride ]);
    return tensor;
});
// torch._utils._rebuild_tensor_v2: newer variant that additionally restores
// autograd metadata (requires_grad, backward_hooks).
this.registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
    const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
    const tensor = self.invoke(name, []);
    tensor.__setstate__([ storage, storage_offset, size, stride ]);
    tensor.requires_grad = requires_grad;
    tensor.backward_hooks = backward_hooks;
    return tensor;
});
|
|
|
|
|
+ this.registerFunction('torch._utils._rebuild_parameter', function(data, requires_grad, backward_hooks) {
|
|
|
|
|
+ const obj = self.invoke('torch.nn.parameter.Parameter', [ data, requires_grad ]);
|
|
|
|
|
+ obj.backward_hooks = backward_hooks;
|
|
|
|
|
+ return obj;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
|
|
|
|
|
+ const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
|
|
|
|
|
+ const tensor = self.invoke(name, []);
|
|
|
|
|
+ tensor.__setstate__([ storage, storage_offset, size, stride ]);
|
|
|
|
|
+ tensor.quantizer_params = quantizer_params;
|
|
|
|
|
+ tensor.requires_grad = requires_grad;
|
|
|
|
|
+ tensor.backward_hooks = backward_hooks;
|
|
|
|
|
+ return tensor;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch._set_item', function(dict, key, value) {
|
|
|
|
|
+ dict[key] = value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__and__', function(left, right) {
|
|
|
|
|
+ return left && right;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__contains__', function(dict, key) {
|
|
|
|
|
+ return dict[key] !== undefined;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__derive_index', function(index, start, step) {
|
|
|
|
|
+ return start + index * step;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__is__', function(left, right) {
|
|
|
|
|
+ if (left === null && right === null) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if ((left !== null && right === null) || (left === null && right !== null)) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.__is__' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__isnot__', function(left, right) {
|
|
|
|
|
+ if (left === null && right === null) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ if ((left !== null && right === null) || (left === null && right !== null)) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.__isnot__' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__not__', function(value) {
|
|
|
|
|
+ if (typeof value === 'boolean') {
|
|
|
|
|
+ return !value;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.__not__' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.__range_length', function(lo, hi, step) {
|
|
|
|
|
+ if (step === 0) {
|
|
|
|
|
+ throw new python.Error('range() arg 3 must not be zero');
|
|
|
|
|
+ }
|
|
|
|
|
+ if (step > 0 && lo < hi) {
|
|
|
|
|
+ return 1 + (hi - 1 - lo) / step;
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (step < 0 && lo > hi) {
|
|
|
|
|
+ return 1 + (lo - 1 - hi) / (0 - step);
|
|
|
|
|
+ }
|
|
|
|
|
+ return 0;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch._unwrap_optional', function(value) {
|
|
|
|
|
+ return value; // TODO
|
|
|
|
|
+ });
|
|
|
|
|
// torch.add: Python '+' semantics for the value types the interpreter models —
// numeric addition, list concatenation, and string concatenation.
// Throws python.Error for any unsupported operand combination.
this.registerFunction('torch.add', function(left, right) {
    if (typeof left === 'number' && typeof right === 'number') {
        // BUGFIX: previously returned left * right (multiplication) here.
        return left + right;
    }
    if (Array.isArray(left) && Array.isArray(right)) {
        return left.concat(right);
    }
    if (typeof left === 'string' && typeof right === 'string') {
        return left + right;
    }
    throw new python.Error('Unsupported torch.add expression type.');
});
|
|
|
|
|
+ this.registerFunction('torch.append', function(list, value) {
|
|
|
|
|
+ list.push(value);
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.extend', function(list, value) {
|
|
|
|
|
+ list.push(...value);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.insert', function(list, index, value) {
|
|
|
|
|
+ list.splice(index, 0, value);
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.clear', function(value) {
|
|
|
|
|
+ if (Object(value) === value) {
|
|
|
|
|
+ for (const key of Object.keys(value)) {
|
|
|
|
|
+ delete value[key];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.replace', function(value) {
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.dict', function(args) {
|
|
|
|
|
+ const obj = {};
|
|
|
|
|
+ if (args) {
|
|
|
|
|
+ if (Array.isArray(args)) {
|
|
|
|
|
+ for (const pair of args) {
|
|
|
|
|
+ const key = pair[0];
|
|
|
|
|
+ const value = pair[1];
|
|
|
|
|
+ obj[key] = value;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ else {
|
|
|
|
|
+ throw new python.Error("'torch.dict' arguments not supported.");
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return obj;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.dim', function(tensor) {
|
|
|
|
|
+ if (tensor && tensor.size) {
|
|
|
|
|
+ const size = tensor.size();
|
|
|
|
|
+ if (size) {
|
|
|
|
|
+ return size.length;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return NaN; // TODO
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.numel', function(tensor) {
|
|
|
|
|
+ if (tensor && tensor.size) {
|
|
|
|
|
+ const size = tensor.size();
|
|
|
|
|
+ if (size) {
|
|
|
|
|
+ return size.reduce((a, b) => a * b, 1);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.eq', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'string' && typeof right === 'string') {
|
|
|
|
|
+ return left === right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ if (isNaN(left) && isNaN(right)) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ return left === right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (left === undefined || right === undefined) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (Array.isArray(left) && Array.isArray(right)) {
|
|
|
|
|
+ return left.length === right.length && left.every((item, index) => item === right[index]);
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.eq' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.floor', function(value) {
|
|
|
|
|
+ return Math.floor(value);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.ceil', function(value) {
|
|
|
|
|
+ return Math.ceil(value);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.floordiv', function(left, right) {
|
|
|
|
|
+ return Math.floor(left / right);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.format', function() {
|
|
|
|
|
+ const args = Array.from(arguments);
|
|
|
|
|
+ const list = args.shift().split(/({}D?)/);
|
|
|
|
|
+ return list.map((text) => {
|
|
|
|
|
+ if (text === '{}' || text === '{}D') {
|
|
|
|
|
+ const arg = args.shift();
|
|
|
|
|
+ return Array.isArray(arg) ? '[' + arg.map((item) => item.toString()).join(', ') + ']' : arg ? arg.toString() : '?';
|
|
|
|
|
+ }
|
|
|
|
|
+ return text;
|
|
|
|
|
+ }).join('');
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.gt', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ if (!isNaN(left) && !isNaN(right)) {
|
|
|
|
|
+ return left > right;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(left) && !isNaN(right)) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.gt' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
// torch.ge: greater-or-equal comparison for numbers.
// NaN handling mirrors the sibling torch.gt helper: NaN on the left compares
// true against any real number; other NaN combinations fall through and throw.
this.registerFunction('torch.ge', function(left, right) {
    if (typeof left === 'number' && typeof right === 'number') {
        if (!isNaN(left) && !isNaN(right)) {
            // BUGFIX: 'ge' is greater-or-EQUAL; previously used strict '>'.
            return left >= right;
        }
    }
    if (isNaN(left) && !isNaN(right)) {
        return true;
    }
    throw new python.Error("Unsupported 'torch.ge' expression type.");
});
|
|
|
|
|
+ this.registerFunction('torch.is_floating_point', function(tensor) {
|
|
|
|
|
+ const type = tensor.dtype.scalar_type();
|
|
|
|
|
+ return (type === 5 || type === 6 || type === 7);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.is_grad_enabled', function() {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.set_grad_enabled', function(/* value */) {
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.build_boollist', function(data) {
|
|
|
|
|
+ return data;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.build_doublelist', function(data) {
|
|
|
|
|
+ return data;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.build_intlist', function(data) {
|
|
|
|
|
+ return data;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.build_tensorlist', function(data) {
|
|
|
|
|
+ return data;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.build_tensor_from_id', function(data) {
|
|
|
|
|
+ return self.builtins.CONSTANTS['c' + data.toString()];
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.jit._pickle.restore_type_tag', function(value /*, type_str */) {
|
|
|
|
|
+ return value;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.keys', function(dict) {
|
|
|
|
|
+ return Object.keys(dict);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.len', function(value) {
|
|
|
|
|
+ if (Array.isArray(value)) {
|
|
|
|
|
+ return value.length;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (value && value.shape && value.__len__) {
|
|
|
|
|
+ return value.__len__();
|
|
|
|
|
+ }
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.le', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ if (isNaN(left) || isNaN(right)) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ return left <= right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (left === undefined || right === undefined) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.le' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.list', function(args) {
|
|
|
|
|
+ return args;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.list_with_default', function(size /*, defaults */) {
|
|
|
|
|
+ return size;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.lt', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ return left < right;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.lt' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.mul', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ return left * right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(left) || isNaN(right)) {
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (Array.isArray(left) && left.every((value) => typeof value === 'number') && typeof right === 'number') {
|
|
|
|
|
+ return left.map((value) => value * right);
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.mul' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.div', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ return left / right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(left) || isNaN(right)) {
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.div' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.round', function(value) {
|
|
|
|
|
+ if (typeof value === 'number') {
|
|
|
|
|
+ return Math.round(value);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(value)) {
|
|
|
|
|
+ return value;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.round' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.remainder', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ return left % right;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (isNaN(left) || isNaN(right)) {
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.remainder' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
// torch.ne: Python '!=' semantics for booleans, numbers, equal-length lists,
// and strings. NaN operands compare not-unequal (returns false), and a missing
// (undefined) operand is always unequal, mirroring the permissive behavior of
// the surrounding comparison helpers.
this.registerFunction('torch.ne', function(left, right) {
    if (typeof left === 'boolean' && typeof right === 'boolean') {
        return left !== right;
    }
    if (typeof left === 'number' && typeof right === 'number') {
        if (isNaN(left) || isNaN(right)) {
            return false;
        }
        return left !== right;
    }
    if (Array.isArray(left) && Array.isArray(right) && left.length === right.length) {
        // BUGFIX: compare element-wise; previously any equal-length pair of
        // lists was unconditionally reported equal regardless of contents.
        return !left.every((item, index) => item === right[index]);
    }
    if (typeof left === 'string' && typeof right === 'string') {
        return left !== right;
    }
    if (left === undefined || right === undefined) {
        return true;
    }
    throw new python.Error("Unsupported 'torch.ne' expression type.");
});
|
|
|
|
|
+ this.registerFunction('torch.neg', function(value) {
|
|
|
|
|
+ if (typeof value === 'number') {
|
|
|
|
|
+ return -value;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.neg' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.q_scale', function(/* tensor */) {
|
|
|
|
|
+ return -1; // TODO
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.t', function(tensor) {
|
|
|
|
|
+ return tensor;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.size', function(tensor, dim) {
|
|
|
|
|
+ if (tensor && tensor.size) {
|
|
|
|
|
+ const size = tensor.size();
|
|
|
|
|
+ if (Array.isArray(size)) {
|
|
|
|
|
+ if (dim === undefined) {
|
|
|
|
|
+ return size;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (Number.isInteger(dim)) {
|
|
|
|
|
+ if (dim >= 0 && dim < size.length) {
|
|
|
|
|
+ return size[dim];
|
|
|
|
|
+ }
|
|
|
|
|
+ if (dim < 0 && -dim < size.length) {
|
|
|
|
|
+ return size[size.length + dim];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error('Dimension out of range (expected to be in range of ' + JSON.stringify(size) + ', but got ' + JSON.stringify(dim) + ').');
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if (Number.isInteger(dim)) {
|
|
|
|
|
+ return NaN;
|
|
|
|
|
+ }
|
|
|
|
|
+ return [];
|
|
|
|
|
+ });
|
|
|
|
|
// torch.slice: Python-style list slicing l[start:end:step] for arrays.
// Supports negative start/end indices; only step === 1 is implemented.
// Missing bounds default to the full list, as in Python.
this.registerFunction('torch.slice', function(l, start, end, step) {
    if (!Array.isArray(l)) {
        throw new python.Error('Slicing expected array');
    }
    step = step === undefined || step === null ? 1 : step;
    if (step !== 1) {
        throw new python.Error('Slicing only supports step=1');
    }
    // BUGFIX: defaults are now explicit. Previously 'end || MAX_SAFE_INTEGER'
    // treated a legitimate end === 0 as "no end" and returned the whole list,
    // and an undefined start only worked by accident via Math.max(0, NaN).
    start = start === undefined || start === null ? 0 : start;
    end = end === undefined || end === null ? l.length : end;
    start = Math.max(0, start >= 0 ? start : l.length + start);
    end = Math.min(l.length, end >= 0 ? end : l.length + end);
    return l.slice(start, end);
});
|
|
|
|
|
+ this.registerFunction('torch.sub', function(left, right) {
|
|
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
|
|
+ return left - right;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new python.Error("Unsupported 'torch.sub' expression type.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.nn.functional.gelu', function(/* input */) {
|
|
|
|
|
+ throw new python.Error("Function not implemented.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.nn.functional.leaky_relu', function(/* input */) {
|
|
|
|
|
+ throw new python.Error("Function not implemented.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.nn.functional.relu', function(/* input */) {
|
|
|
|
|
+ throw new python.Error("Function not implemented.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.nn.functional.tanh', function(/* input */) {
|
|
|
|
|
+ throw new python.Error("Function not implemented.");
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.values', function(dict) {
|
|
|
|
|
+ return Object.keys(dict).map((key) => dict[key]);
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.warn', function() {
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerFunction('torch.fx.graph_module.reduce_graph_module', function(body /*, import_block */) {
|
|
|
|
|
+ // https://github.com/pytorch/pytorch/blob/master/torch/fx/graph_module.py
|
|
|
|
|
+ return body;
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.device', class {
|
|
|
|
|
+ constructor(type, index) {
|
|
|
|
|
+ this.type = type;
|
|
|
|
|
+ if (index) {
|
|
|
|
|
+ this.index = index;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.dtype', class {
|
|
|
|
|
+ constructor(data) {
|
|
|
|
|
+ this._type = data.type;
|
|
|
|
|
+ this._data = data;
|
|
|
|
|
+ }
|
|
|
|
|
+ scalar_type() {
|
|
|
|
|
+ return this._type;
|
|
|
|
|
+ }
|
|
|
|
|
+ itemsize() {
|
|
|
|
|
+ return this._data.itemsize;
|
|
|
|
|
+ }
|
|
|
|
|
+ __reduce__() {
|
|
|
|
|
+ return this._data.name;
|
|
|
|
|
+ }
|
|
|
|
|
+ __str__() {
|
|
|
|
|
+ return 'torch.' + this._data.name;
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.qscheme', class {});
|
|
|
|
|
+ this.registerType('torch.utils.hooks.RemovableHandle', class {
|
|
|
|
|
+ __setstate__(state) {
|
|
|
|
|
+ this.hooks_dict_ref = state[0] || new Map();
|
|
|
|
|
+ this.id = state[1];
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
// Base class for torch storage objects: a flat, typed element buffer that
// backs tensors during offline deserialization.
this.registerType('torch.storage._StorageBase', class {
    constructor(size, dtype) {
        this._size = size;   // element count (not bytes)
        this._dtype = dtype; // torch.dtype describing the element type
        this._device = null; // device is not tracked during deserialization
    }
    get device() {
        return this._device;
    }
    get dtype() {
        return this._dtype;
    }
    element_size() {
        // NOTE(review): reads '_dtype.element_size', but torch.dtype in this
        // file exposes 'itemsize()' — confirm this property exists on dtype.
        return this._dtype.element_size;
    }
    size() {
        return this._size;
    }
    get data() {
        // raw byte buffer attached via _set_cdata
        return this._cdata;
    }
    // Attach the raw byte buffer, validating it against size * itemsize.
    _set_cdata(data) {
        const length = this.size() * this.dtype.itemsize();
        if (length !== data.length) {
            throw new python.Error('Storage data size mismatch.');
        }
        this._cdata = data;
    }
    // Restore storage contents from a legacy torch.save stream:
    // an 8-byte little-endian element count followed by the raw element data.
    _set_from_file(unpickler) {
        const buffer = unpickler.read(8);
        // decode the 64-bit little-endian length
        const size = buffer.reverse().reduce((a, b) => (a*256)+b, 0);
        if (size !== this.size()) {
            throw new python.Error('Storage size mismatch.');
        }
        const itemsize = this.dtype.itemsize();
        const data = unpickler.stream(itemsize * size);
        this._set_cdata(data);
    }
    // Construct a storage of the stream-declared size and fill it from the stream.
    static _new_with_file(unpickler) {
        const buffer = unpickler.read(8);
        const size = buffer.reverse().reduce((a, b) => (a*256)+b, 0);
        const storage = new this(size);
        const itemsize = storage.dtype.itemsize();
        const data = unpickler.stream(itemsize * size);
        storage._set_cdata(data);
        return storage;
    }
});
|
|
|
|
|
+ this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor() {
|
|
|
|
|
+ super();
|
|
|
|
|
+ throw new python.Error('_UntypedStorage not implemented.');
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.storage._TypedStorage', class {
|
|
|
|
|
+ constructor() {
|
|
|
|
|
+ throw new python.Error('_TypedStorage not implemented.');
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
|
|
|
|
|
+ constructor() {
|
|
|
|
|
+ super();
|
|
|
|
|
+ throw new python.Error('_LegacyStorage not implemented.');
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.BoolStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.bool);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.ByteStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.uint8);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.CharStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.int8);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.ShortStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.int16);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.IntStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.int32);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.LongStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.int64);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.HalfStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.float16);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.FloatStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.float32);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.DoubleStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.float64);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.ComplexHalfStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.complex32);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.ComplexFloatStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.complex64);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.ComplexDoubleStorage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.complex128);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.QInt8Storage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.qint8);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.QUInt8Storage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.quint8);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.QInt32Storage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.qint32);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.BFloat16Storage', class extends torch_storage._StorageBase {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size, torch.bfloat16);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.Size', class extends Array {
|
|
|
|
|
+ constructor(size) {
|
|
|
|
|
+ super(size.length);
|
|
|
|
|
+ for (let i = 0; i < size.length; i++) {
|
|
|
|
|
+ this[i] = size[i];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ __len__() {
|
|
|
|
|
+ return this.length;
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
// Minimal torch.Tensor stand-in used during deserialization: holds a storage
// reference plus geometry (shape, stride, storage offset). No tensor math.
this.registerType('torch.Tensor', class {
    constructor() {
    }
    get device() {
        return this.storage().device;
    }
    get dtype() {
        return this.storage().dtype;
    }
    get shape() {
        return this._shape;
    }
    size() {
        return this._shape;
    }
    storage() {
        // Lazily create an empty storage whose type mirrors the tensor class
        // name (e.g. LongTensor -> LongStorage); plain Tensor defaults to float.
        if (!this._storage) {
            // BUGFIX: derive the storage name from this.__class__ — the previous
            // 'this.__storage__.__name__' referenced a property that is never
            // assigned, so every Tensor subclass threw here.
            const name = this.__class__.__name__ === 'Tensor' ? 'FloatStorage' : this.__class__.__name__.replace('Tensor', 'Storage');
            this._storage = self.invoke(this.__class__.__module__ + '.' + name, []);
        }
        return this._storage;
    }
    storage_offset() {
        return this._storage_offset;
    }
    stride() {
        return this._stride;
    }
    resize_(shape) {
        this._shape = shape;
    }
    __len__() {
        return this._shape[0];
    }
    // state layout matches torch._utils._rebuild_tensor:
    // [ storage, storage_offset, shape, stride ]
    __setstate__(state) {
        this._storage = state[0];
        this._storage_offset = state[1];
        this._shape = state[2];
        this._stride = state[3];
    }
    __bool__() {
        return true;
    }
    // Scalar extraction: only a single little-endian int64 element is supported.
    __int__() {
        const storage = this.storage();
        if (storage && storage.dtype.__reduce__() === 'int64' && storage.data.length === 8) {
            const buffer = storage.data;
            const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
            // NOTE(review): standard DataView has getBigInt64, not getInt64 —
            // presumably a project-level DataView extension provides this; confirm.
            return view.getInt64(0, true);
        }
        return NaN;
    }
    // Scalar extraction: only a single little-endian float32 element is supported.
    __float__() {
        const storage = this.storage();
        if (storage && storage.dtype.__reduce__() === 'float32') {
            if (storage.size() !== undefined && storage.data.length === 4) {
                const buffer = storage.data;
                const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
                return view.getFloat32(0, true);
            }
        }
        return NaN;
    }
    __str__() {
        return 'tensor(...)';
    }
});
|
|
|
|
|
+ this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor {
|
|
|
|
|
+ constructor(data, requires_grad) {
|
|
|
|
|
+ super();
|
|
|
|
|
+ if (!data) {
|
|
|
|
|
+ data = self.invoke('torch.Tensor', [[]]);
|
|
|
|
|
+ }
|
|
|
|
|
+ this.data = data;
|
|
|
|
|
+ this.requires_grad = requires_grad !== undefined ? requires_grad : true;
|
|
|
|
|
+ }
|
|
|
|
|
+ __setstate__(state) {
|
|
|
|
|
+ switch (state.length) {
|
|
|
|
|
+ case 3:
|
|
|
|
|
+ this.data = null;
|
|
|
|
|
+ break;
|
|
|
|
|
+ case 4:
|
|
|
|
|
+ this.data = state[0];
|
|
|
|
|
+ break;
|
|
|
|
|
+ case 5:
|
|
|
|
|
+ this.data = state[0];
|
|
|
|
|
+ break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ throw new python.Error("Unsupported parameter state length '" + state.length + "'.");
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.nn.parameter.UninitializedParameter', class extends torch_nn_parameter.Parameter {
|
|
|
|
|
+ constructor(requires_grad /*, device, dtype */) {
|
|
|
|
|
+ super(undefined, requires_grad);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this.registerType('torch.BoolTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.ByteTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.CharTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.ShortTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.IntTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.LongTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.HalfTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.FloatTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.DoubleTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.ComplexFloatTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.ComplexDoubleTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.BFloat16Tensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.cuda.FloatTensor', class extends torch.Tensor {});
|
|
|
|
|
+ this.registerType('torch.cuda.DoubleTensor', class extends torch.Tensor {});
|
|
|
|
|
+ torch.uint8 = torch.ByteStorage.dtype = new torch.dtype({ type: 0, name: 'uint8', itemsize: 1 });
|
|
|
|
|
+ torch.int8 = torch.CharStorage.dtype = new torch.dtype({ type: 1, name: 'int8', itemsize: 1 });
|
|
|
|
|
+ torch.int16 = torch.ShortStorage.dtype = new torch.dtype({ type: 2, name: 'int16', itemsize: 2 });
|
|
|
|
|
+ torch.int32 = torch.IntStorage.dtype = new torch.dtype({ type: 3, name: 'int32', itemsize: 4 });
|
|
|
|
|
+ torch.int64 = torch.LongStorage.dtype = new torch.dtype({ type: 4, name: 'int64', itemsize: 8 });
|
|
|
|
|
+ torch.float16 = torch.HalfStorage.dtype = new torch.dtype({ type: 5, name: 'float16', itemsize: 2 });
|
|
|
|
|
+ torch.float32 = torch.FloatStorage.dtype = new torch.dtype({ type: 6, name: 'float32', itemsize: 4 });
|
|
|
|
|
+ torch.float64 = torch.DoubleStorage.dtype = new torch.dtype({ type: 7, name: 'float64', itemsize: 8 });
|
|
|
|
|
+ torch.complex32 = torch.ComplexHalfStorage.dtype = new torch.dtype({ type: 8, name: 'complex32', itemsize: 4 });
|
|
|
|
|
+ torch.complex64 = torch.ComplexFloatStorage.dtype = new torch.dtype({ type: 9, name: 'complex64', itemsize: 8 });
|
|
|
|
|
+ torch.complex128 = torch.ComplexDoubleStorage.dtype = new torch.dtype({ type: 10, name: 'complex128', itemsize: 16 });
|
|
|
|
|
+ torch.bool = torch.BoolStorage.dtype = new torch.dtype({ type: 11, name: 'boolean', itemsize: 1 });
|
|
|
|
|
+ torch.qint8 = torch.QInt8Storage.dtype = new torch.dtype({ type: 12, name: 'qint8', itemsize: 1 });
|
|
|
|
|
+ torch.quint8 = torch.QUInt8Storage.dtype = new torch.dtype({ type: 13, name: 'quint8', itemsize: 1 });
|
|
|
|
|
+ torch.qint32 = torch.QInt32Storage.dtype = new torch.dtype({ type: 14, name: 'qint32', itemsize: 4 });
|
|
|
|
|
+ torch.bfloat16 = torch.BFloat16Storage.dtype = new torch.dtype({ type: 15, name: 'bfloat16', itemsize: 2 });
|
|
|
|
|
+ torch.quint4x2 = new torch.dtype({ type: 16, name: 'quint4x2' });
|
|
|
|
|
+ torch.per_tensor_affine = new torch.qscheme();
|
|
|
|
|
+ torch.per_channel_affine = new torch.qscheme();
|
|
|
|
|
+ torch.per_tensor_symmetric = new torch.qscheme();
|
|
|
|
|
+ torch.per_channel_symmetric = new torch.qscheme();
|
|
|
|
|
+ torch.per_channel_affine_float_qparams = new torch.qscheme();
|
|
|
|
|
+ torch.inf = this.register('math').inf;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get builtins() {
|
|
get builtins() {
|