|
|
@@ -39,6 +39,7 @@ namespace traph
|
|
|
virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
|
|
|
virtual std::shared_ptr<TensorInterface> mean() const = 0;
|
|
|
virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
|
|
|
+ virtual idx_type ndimension() const = 0;
|
|
|
virtual void neg_() = 0;
|
|
|
virtual idx_type offset() const = 0;
|
|
|
virtual layout_type order() const = 0;
|
|
|
@@ -96,6 +97,7 @@ namespace traph
|
|
|
virtual TensorInterfacePtr mean() const = 0;
|
|
|
virtual void mul_(T value) = 0;
|
|
|
virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
|
|
|
+ virtual idx_type ndimension() const = 0;
|
|
|
virtual void neg_() = 0;
|
|
|
virtual idx_type offset() const = 0;
|
|
|
virtual layout_type order() const = 0;
|
|
|
@@ -154,6 +156,17 @@ namespace traph
|
|
|
|
|
|
return true;
|
|
|
}
|
|
|
+
|
|
|
+ inline std::shared_ptr<TensorInterface> sort_strides(std::shared_ptr<TensorInterface> t)
|
|
|
+ {
|
|
|
+ DimVector indices(t->ndimension());
|
|
|
+ for (idx_type i = 0; i < t->ndimension(); i++)
|
|
|
+ indices[i] = i;
|
|
|
+
|
|
|
+ std::sort(&indices[0], &indices[indices.size() - 1], std::greater<idx_type>());
|
|
|
+ std::shared_ptr<TensorInterface> ret = t->permute(indices);
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#endif
|