ソースを参照

add mul_ and PowOp

JasonWang 6 年 前
コミット
689084d669

+ 1 - 0
traph/include/traph/core/tensor.h

@@ -89,6 +89,7 @@ namespace traph
         virtual std::shared_ptr<TensorInterface> inverse() const = 0;
         virtual T item() const = 0;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
+        virtual void mul_(T value) = 0;
         virtual void neg_() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;

+ 54 - 124
traph/include/traph/nn/function.h

@@ -15,6 +15,54 @@
 
 namespace traph
 {
+
+#define UNARY_OP(name, op_name)                                           \
+	VariableInterfacePtr name(VariableInterfacePtr input)                 \
+	{                                                                     \
+		DimVector result_dim;                                             \
+        VariableInterfacePtr result = input->new_empty(result_dim, true); \
+		std::shared_ptr<op_name> op(new op_name);                         \
+		std::vector<VariableInterfacePtr> result_inputs{ input };         \
+		result->data_(op->forward({ input->data() }));                    \
+		if (input->requires_grad())                                       \
+		{                                                                 \
+			result->grad_(result->data()->create_grad());                 \
+			result->grad()->fill_(0);                                     \
+			result->requires_grad_(true);                                 \
+			result->grad_fn_(op);                                         \
+			result->inputs_(result_inputs);                               \
+		}                                                                 \
+		else                                                              \
+		{                                                                 \
+			result->requires_grad_(false);                                \
+		}                                                                 \
+		return result;                                                    \
+	}
+
+#define BINARY_OP(name, op_name)                                                           \
+	VariableInterfacePtr name(VariableInterfacePtr left, VariableInterfacePtr right)       \
+	{                                                                                      \
+		DimVector result_dim;                                                              \
+        VariableInterfacePtr result = left->new_empty(result_dim, true);                   \
+		std::shared_ptr<op_name> op(new op_name);                                          \
+		result->data_(op->forward({ left->data(), right->data() }));                       \
+		if (left->requires_grad() || right->requires_grad())                               \
+		{                                                                                  \
+			std::vector<VariableInterfacePtr> result_inputs{ left, right };                \
+			result->grad_(result->data()->create_grad());                                  \
+			result->grad()->fill_(0);                                                      \
+			result->requires_grad_(true);                                                  \
+			result->grad_fn_(op);                                                          \
+			result->inputs_(result_inputs);                                                \
+		}                                                                                  \
+		else                                                                               \
+		{                                                                                  \
+			result->requires_grad_(false);                                                 \
+		}                                                                                  \
+		return result;                                                                     \
+	}
+
+
 	// creation function
 	template<typename T>
 	VariableInterfacePtr empty(std::initializer_list<idx_type> l, bool requires_grad = false)
@@ -63,83 +111,12 @@ namespace traph
 	}
 
 	// arithmetic function
-	VariableInterfacePtr sum(VariableInterfacePtr input)
-    {
-		DimVector result_dim(1);
-		result_dim[0] = 1;
-
-        VariableInterfacePtr result = input->new_empty(result_dim, true);
-        std::shared_ptr<SumOp> op(new SumOp);
-
-		result->data_(op->forward({ input->data() }));
-        if(input->requires_grad())
-        {
-			std::vector<VariableInterfacePtr> result_inputs { input };
-			result->grad_(result->data()->create_grad());
-			result->grad()->fill_(0);
-            result->requires_grad_(true);
-            result->grad_fn_(op);
-            result->inputs_(result_inputs);
-        }
-        else
-        {
-            result->requires_grad_(false);
-        }
-
-        return result;
-    }
-
-	VariableInterfacePtr add(VariableInterfacePtr left, VariableInterfacePtr right)
-	{
-		DimVector result_dim;
-
-        VariableInterfacePtr result = left->new_empty(result_dim, true);
-		std::shared_ptr<AddOp> op(new AddOp);
-		if (left->requires_grad() || right->requires_grad())
-		{
-			std::vector<VariableInterfacePtr> result_inputs{ left, right };
-			result->data_(op->forward({ left->data(), right->data() }));
-			result->grad_(result->data()->create_grad());
-			result->grad()->fill_(0);
-			result->requires_grad_(true);
-			result->grad_fn_(op);
-			result->inputs_(result_inputs);
-		}
-		else
-		{
-			result->data_(op->forward({ left->data(), right->data() }));
-			result->requires_grad_(false);
-		}
-
-		return result;
-	}
-
-	VariableInterfacePtr matmul(VariableInterfacePtr left, VariableInterfacePtr right)
-	{
-		DimVector result_dim;
-
-        VariableInterfacePtr result = left->new_empty(result_dim, true);
-		std::shared_ptr<MatmulOp> op(new MatmulOp);
-		if (left->requires_grad() || right->requires_grad())
-		{
-			std::vector<VariableInterfacePtr> result_inputs{ left, right };
-			result->data_(op->forward({ left->data(), right->data() }));
-			result->grad_(result->data()->create_grad());
-			result->grad()->fill_(0);
-			result->requires_grad_(true);
-			result->grad_fn_(op);
-			result->inputs_(result_inputs);
-		}
-		else
-		{
-			result->data_(op->forward({ left->data(), right->data() }));
-			result->requires_grad_(false);
-		}
-
-		return result;
-	}
+	UNARY_OP(sum, SumOp)
 
+	BINARY_OP(add, AddOp)
 	
+	BINARY_OP(matmul, MatmulOp)
+
 	VariableInterfacePtr select(VariableInterfacePtr input, const SliceVector& slice)
 	{
 		DimVector result_dim;
@@ -167,56 +144,9 @@ namespace traph
 		return result;
 	}
 
+	UNARY_OP(sin, SinOp)
 
-	VariableInterfacePtr sin(VariableInterfacePtr input)
-	{
-		DimVector result_dim;
-
-        VariableInterfacePtr result = input->new_empty(result_dim, true);
-		std::shared_ptr<SinOp> op(new SinOp);
-
-		std::vector<VariableInterfacePtr> result_inputs{ input };
-		result->data_(op->forward({ input->data() }));
-
-		if (input->requires_grad())
-		{
-			result->grad_(result->data()->create_grad());
-			result->grad()->fill_(0);
-			result->requires_grad_(true);
-			result->grad_fn_(op);
-			result->inputs_(result_inputs);
-		}
-		else
-		{
-			result->requires_grad_(false);
-		}
-
-		return result;
-	}
-
-	VariableInterfacePtr sub(VariableInterfacePtr left, VariableInterfacePtr right)
-	{
-		DimVector result_dim;
-
-        VariableInterfacePtr result = left->new_empty(result_dim, true);
-		std::shared_ptr<SubOp> op(new SubOp);
-		result->data_(op->forward({ left->data(), right->data() }));
-		if (left->requires_grad() || right->requires_grad())
-		{
-			std::vector<VariableInterfacePtr> result_inputs{ left, right };
-			result->grad_(result->data()->create_grad());
-			result->grad()->fill_(0);
-			result->requires_grad_(true);
-			result->grad_fn_(op);
-			result->inputs_(result_inputs);
-		}
-		else
-		{
-			result->requires_grad_(false);
-		}
-
-		return result;
-	}
+	BINARY_OP(sub, SubOp)
 
 	VariableInterfacePtr transpose(VariableInterfacePtr input, idx_type dim0, idx_type dim1)
 	{

+ 29 - 0
traph/include/traph/nn/operation.h

@@ -107,6 +107,35 @@ namespace traph
 		}
 	};
 
+	class PowOp: public OpBase
+	{
+	private:
+		float _exp;
+	public:
+		void set_exp(float exp)
+		{
+			_exp = exp;
+		}
+
+		virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+		{
+			assert(inputs.size() == 1);
+
+			TensorInterfacePtr input = inputs[0];
+			auto output = input->clone();
+			output->pow_(_exp);
+			
+			return output;
+		}
+
+		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+		{
+			auto output = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->clone());
+			output->mul_(_exp);
+			return { output };
+		}
+	};
+
 	class SelectOp : public OpBase
 	{
 	public:

+ 2 - 1
traph/include/traph/tensor/byte_tensor.h

@@ -69,7 +69,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual u8 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(u8 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/char_tensor.h

@@ -67,7 +67,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i8 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(i8 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/double_tensor.h

@@ -67,7 +67,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f64 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(f64 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/float_tensor.h

@@ -68,7 +68,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f32 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(f32 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/int_tensor.h

@@ -67,7 +67,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i32 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(i32 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/long_tensor.h

@@ -67,7 +67,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i64 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(i64 value) override;
+        virtual void neg_() override;
 		virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/short_tensor.h

@@ -67,7 +67,8 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i16 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(i16 value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
 		virtual platform_type platform() override;

+ 2 - 1
traph/include/traph/tensor/tensor.h

@@ -69,7 +69,8 @@ namespace traph
         virtual std::shared_ptr<TensorInterface> inverse() const override;
         virtual T item() const override;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void neg_() override;
+		virtual void mul_(T value) override;
+        virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
         virtual platform_type platform() override;

+ 5 - 0
traph/source/tensor/byte_tensor.cpp

@@ -262,6 +262,11 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<u8>::mul_(u8 value)
+    {
+        apply_([value](u8 a)->u8 {return a*value; });
+    }
+
     void Tensor<u8>::neg_()
     {
         apply_([](u8 a)->u8 {return -a; });

+ 5 - 0
traph/source/tensor/char_tensor.cpp

@@ -262,6 +262,11 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<i8>::mul_(i8 value)
+    {
+        apply_([value](i8 a)->i8 {return a*value; });
+    }
+
     void Tensor<i8>::neg_()
     {
         apply_([](i8 a)->i8 {return -a; });

+ 5 - 0
traph/source/tensor/double_tensor.cpp

@@ -262,6 +262,11 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<f64>::mul_(f64 value)
+    {
+        apply_([value](f64 a)->f64 {return a*value; });
+    }
+
     void Tensor<f64>::neg_()
     {
         apply_([](f64 a)->f64 {return -a; });

+ 6 - 0
traph/source/tensor/float_tensor.cpp

@@ -263,6 +263,12 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<f32>::mul_(f32 value)
+    {
+        apply_([value](f32 a)->f32 {return a*value; });
+    }
+
+
     void Tensor<f32>::neg_()
     {
         apply_([](f32 a)->f32 {return -a; });

+ 6 - 0
traph/source/tensor/int_tensor.cpp

@@ -262,6 +262,12 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<i32>::mul_(i32 value)
+    {
+        apply_([value](i32 a)->i32 {return a*value; });
+    }
+
+
     void Tensor<i32>::neg_()
     {
         apply_([](i32 a)->i32 {return -a; });

+ 5 - 0
traph/source/tensor/long_tensor.cpp

@@ -262,6 +262,11 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<i64>::mul_(i64 value)
+    {
+        apply_([value](i64 a)->i64 {return a*value; });
+    }
+
     void Tensor<i64>::neg_()
     {
         apply_([](i64 a)->i64 {return -a; });

+ 5 - 0
traph/source/tensor/short_tensor.cpp

@@ -262,6 +262,11 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    void Tensor<i16>::mul_(i16 value)
+    {
+        apply_([value](i16 a)->i16 {return a*value; });
+    }
+
     void Tensor<i16>::neg_()
     {
         apply_([](i16 a)->i16 {return -a; });

+ 6 - 0
traph/source/tensor/tensor.cpp

@@ -121,6 +121,12 @@ namespace traph
 		throw std::runtime_error("No implement");
     }
 
+    template<typename T>
+    void Tensor<T>::mul_(T value)
+    {
+        throw std::runtime_error("No implement");
+    }
+
     template<typename T>
     void Tensor<T>::neg_()
     {