Ver Fonte

add ndimension

JasonWang há 6 anos atrás
pai
commit
c905ac79e3

+ 13 - 0
traph/include/traph/core/tensor.h

@@ -39,6 +39,7 @@ namespace traph
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
         virtual std::shared_ptr<TensorInterface> mean() const = 0;
         virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
+        virtual idx_type ndimension() const = 0;
         virtual void neg_() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
@@ -96,6 +97,7 @@ namespace traph
         virtual TensorInterfacePtr mean() const = 0;
         virtual void mul_(T value) = 0;
         virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
+        virtual idx_type ndimension() const = 0;
         virtual void neg_() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
@@ -154,6 +156,17 @@ namespace traph
 
         return true;
     }
+
+    inline std::shared_ptr<TensorInterface> sort_strides(std::shared_ptr<TensorInterface> t)
+    {
+        DimVector indices(t->ndimension());
+        for (idx_type i = 0; i < t->ndimension(); i++)
+            indices[i] = i;
+
+        std::sort(&indices[0], &indices[indices.size() - 1], std::greater<idx_type>());
+        std::shared_ptr<TensorInterface> ret = t->permute(indices);
+        return ret;
+    }
 }
 
 #endif

+ 1 - 0
traph/include/traph/tensor/byte_tensor.h

@@ -71,6 +71,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(u8 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/char_tensor.h

@@ -69,6 +69,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(i8 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/double_tensor.h

@@ -69,6 +69,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(f64 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/float_tensor.h

@@ -70,6 +70,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(f32 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/int_tensor.h

@@ -69,6 +69,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(i32 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/long_tensor.h

@@ -69,6 +69,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(i64 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
 		virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/short_tensor.h

@@ -69,6 +69,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(i16 value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 1 - 0
traph/include/traph/tensor/tensor.h

@@ -73,6 +73,7 @@ namespace traph
 		virtual TensorInterfacePtr mean() const override;
         virtual void mul_(T value) override;
         virtual void mul_(std::shared_ptr<TensorInterface> other) override;
+        virtual idx_type ndimension() const override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 5 - 0
traph/source/tensor/byte_tensor.cpp

@@ -376,6 +376,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<u8>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<u8>::neg_()
     {
         apply_([](u8 a)->u8 {return -a; });

+ 5 - 0
traph/source/tensor/char_tensor.cpp

@@ -376,6 +376,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<i8>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<i8>::neg_()
     {
         apply_([](i8 a)->i8 {return -a; });

+ 5 - 0
traph/source/tensor/double_tensor.cpp

@@ -377,6 +377,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<f64>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<f64>::neg_()
     {
         apply_([](f64 a)->f64 {return -a; });

+ 4 - 0
traph/source/tensor/float_tensor.cpp

@@ -378,6 +378,10 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<f32>::ndimension() const
+    {
+        return _dimensions.size();
+    }
 
     void Tensor<f32>::neg_()
     {

+ 5 - 0
traph/source/tensor/int_tensor.cpp

@@ -377,6 +377,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<i32>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<i32>::neg_()
     {
         apply_([](i32 a)->i32 {return -a; });

+ 5 - 0
traph/source/tensor/long_tensor.cpp

@@ -377,6 +377,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<i64>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<i64>::neg_()
     {
         apply_([](i64 a)->i64 {return -a; });

+ 5 - 0
traph/source/tensor/short_tensor.cpp

@@ -377,6 +377,11 @@ namespace traph
 		mul_impl(-1, -1, lhs->offset(), rhs->offset());
     }
 
+    idx_type Tensor<i16>::ndimension() const
+    {
+        return _dimensions.size();
+    }
+
     void Tensor<i16>::neg_()
     {
         apply_([](i16 a)->i16 {return -a; });

+ 6 - 0
traph/source/tensor/tensor.cpp

@@ -146,6 +146,12 @@ namespace traph
         throw std::runtime_error("No implement");
     }
 
+    template<typename T>
+    idx_type Tensor<T>::ndimension() const
+    {
+        throw std::runtime_error("No implement");
+    }
+
     template<typename T>
     void Tensor<T>::neg_()
     {