Ver código fonte

add transpose

JasonWang 6 anos atrás
pai
commit
6a34d5b232

+ 14 - 0
traph/include/traph/core/index.h

@@ -163,8 +163,19 @@ namespace traph
 			return flat_size;
 		}
 
+        bool in_range(idx_type dim) const
+        {
+            if(dim < 0)
+                dim = dim_num + dim;
+            
+            return dim >= 0 && dim < dim_num;
+        }
+
         idx_type& operator[](idx_type dim)
         {
+            if(dim < 0)
+                dim = dim_num + dim;
+            
             if(dim<0 || dim >= dim_num)
                 throw std::runtime_error("index out of dim vector size");
             
@@ -176,6 +187,9 @@ namespace traph
 
         idx_type operator[](idx_type dim) const
         {
+            if(dim < 0)
+                dim = dim_num + dim;
+            
             if(dim<0 || dim >= dim_num)
                 throw std::runtime_error("index out of dim vector size");
             

+ 2 - 0
traph/include/traph/core/variable.h

@@ -25,6 +25,7 @@ namespace traph
         virtual std::shared_ptr<OpBase> grad_fn() = 0;
         virtual std::vector<VariableInterfacePtr>& inputs() = 0;
         virtual bool is_leaf() const = 0;
+		virtual void leaf_(bool state) = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
         virtual platform_type platform() = 0;
@@ -65,6 +66,7 @@ namespace traph
         virtual std::vector<VariableInterfacePtr>& inputs() = 0;
         virtual bool is_leaf() const = 0;
         virtual T item() const = 0;
+		virtual void leaf_(bool state) = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
         virtual platform_type platform() = 0;

+ 31 - 0
traph/include/traph/nn/arithmetic.h

@@ -15,6 +15,37 @@
 
 namespace traph
 {
+	// variable constructor
+	template<class T>
+	VariablePtr<T> zeros(std::initializer_list<idx_type> l, bool requires_grad = false)
+	{
+		DimVector dim;
+		for (auto i : l)
+			dim.push_back(i);
+
+		std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
+		result->leaf_(true);
+		result->fill_(0);
+
+		return result;
+	}
+
+	template<class T>
+	VariablePtr<T> ones(std::initializer_list<idx_type> l, bool requires_grad = false)
+	{
+		DimVector dim;
+		for (auto i : l)
+			dim.push_back(i);
+
+		std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
+		result->leaf_(true);
+		result->fill_(1);
+
+		return result;
+	}
+
+
+	// arithmetic
     template<class T>
 	VariablePtr<T> sum(VariablePtr<T> input)
     {

+ 7 - 27
traph/include/traph/nn/variable.h

@@ -70,6 +70,7 @@ namespace traph
         virtual std::vector<VariableInterfacePtr>& inputs() override;
         virtual bool is_leaf() const override;
         virtual T item() const override;
+		virtual void leaf_(bool state) override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
         virtual platform_type platform() override;
@@ -239,6 +240,12 @@ namespace traph
 		return _data->item();
 	}
 
+	template<typename T>
+	void Variable<T>::leaf_(bool state)
+	{
+		_leaf = state;
+	}
+
 	template<typename T>
 	idx_type Variable<T>::offset() const
 	{
@@ -307,33 +314,6 @@ namespace traph
 	{
 		return _data->stride();
 	}
-
-    // variable constructor
-    template<class T>
-    VariablePtr<T> zeros(std::initializer_list<idx_type> l, bool requires_grad = false)
-    {
-        DimVector dim;
-		for (auto i : l)
-			dim.push_back(i);
-
-        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
-        result->fill_(0);
-
-        return result;
-    }
-
-    template<class T>
-    VariablePtr<T> ones(std::initializer_list<idx_type> l, bool requires_grad = false)
-    {
-		DimVector dim;
-		for (auto i : l)
-			dim.push_back(i);
-
-        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
-        result->fill_(1);
-
-        return result;
-    }
 }
 
 

+ 10 - 2
traph/source/tensor/byte_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<u8>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<u8>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/char_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<i8>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<i8>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/double_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<f64>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<f64>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/float_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<f32>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<f32>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/int_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<i32>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<i32>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/long_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<i64>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<i64>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }

+ 10 - 2
traph/source/tensor/short_tensor.cpp

@@ -406,11 +406,19 @@ namespace traph
 
     void Tensor<i16>::transpose_(idx_type dim0, idx_type dim1)
     {
-
+        if(dim0 != dim1 &&
+            _dimensions.in_range(dim0) &&
+            _dimensions.in_range(dim1))
+        {
+            std::swap(_dimensions[dim0], _dimensions[dim1]);
+            std::swap(_strides[dim0], _strides[dim1]);
+        }
     }
 
     std::shared_ptr<TensorInterface> Tensor<i16>::transpose(idx_type dim0, idx_type dim1)
     {
-
+        std::shared_ptr<TensorInterface> result= this->clone();
+        result->transpose_(dim0, dim1);
+        return result;
     }
 }