ソースを参照

add sin operator

JasonWang 7 年 前
コミット
83bd8ac242

+ 22 - 0
traph/include/traph/core/operation.h

@@ -79,6 +79,28 @@ namespace traph
 			return { output_grad, output_grad };
 		}
 	};
+
+	class SinOp : public OpBase
+	{
+	public:
+		virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+		{
+			assert(inputs.size() == 1);
+
+			TensorInterfacePtr input = inputs[0];
+			TensorInterfacePtr result = input->clone();
+			result->sin_();
+
+			return result;
+		}
+
+		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+		{
+			TensorBasePtr<f32> result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->clone());
+			result->cos_();
+			return { result };
+		}
+	};
 }
 
 #endif

+ 2 - 0
traph/include/traph/core/variable.h

@@ -28,6 +28,7 @@ namespace traph
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
         virtual platform_type platform() = 0;
+        virtual bool requires_grad() const = 0;
         virtual void requires_grad_(bool requires_grad) = 0;
         virtual void reshape_(const DimVector& dims) = 0;
         virtual void resize_(const DimVector& dims) = 0;
@@ -67,6 +68,7 @@ namespace traph
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
         virtual platform_type platform() = 0;
+        virtual bool requires_grad() const = 0;
         virtual void requires_grad_(bool requires_grad) = 0;
         virtual void reshape_(const DimVector& dims) = 0;
         virtual void resize_(const DimVector& dims) = 0;

+ 26 - 0
traph/include/traph/nn/arithmetic.h

@@ -66,6 +66,32 @@ namespace traph
 		return result;
 	}
 
+	template<class T>
+	VariablePtr<T> sin(VariablePtr<T> input)
+	{
+		VariablePtr<T> result(new Variable<T>);
+		std::shared_ptr<SinOp> op(new SinOp);
+
+		std::vector<VariableInterfacePtr> result_inputs{ input };
+		result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
+		result->_leaf = false;
+
+		if (input->requires_grad())
+		{
+			result->_grad = result->_data->create_grad();
+			result->_grad->fill_(0);
+			result->_requires_grad = true;
+			result->_grad_fn = op;
+			result->_inputs = result_inputs;
+		}
+		else
+		{
+			result->_requires_grad = false;
+		}
+
+		return result;
+	}
+
 }
 
 #endif

+ 10 - 0
traph/include/traph/nn/variable.h

@@ -52,6 +52,9 @@ namespace traph
 		template<class T>
 		friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
 
+		template<class T>
+		friend std::shared_ptr<Variable<T>> sin(std::shared_ptr<Variable<T>> input);
+
         virtual void backward() override;
         virtual TensorInterfacePtr data() override;
         virtual device_id device() override;
@@ -64,6 +67,7 @@ namespace traph
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
         virtual platform_type platform() override;
+		virtual bool requires_grad() const override;
         virtual void requires_grad_(bool requires_grad) override;
         virtual void reshape_(const DimVector& dims) override;
         virtual void resize_(const DimVector& dims) override;
@@ -247,6 +251,12 @@ namespace traph
 		return _data->platform();
 	}
 
+	template<typename T>
+	bool Variable<T>::requires_grad() const
+	{
+		return _requires_grad;
+	}
+
 	template<typename T>
 	void Variable<T>::requires_grad_(bool requires_grad)
 	{

+ 2 - 1
traph/source/nn/executor.cpp

@@ -53,7 +53,8 @@ namespace traph
             std::vector<VariableInterfacePtr>& cur_inputs(cur->inputs());
             for(int i = 0; i<cur_inputs.size(); ++i)
             {
-                variable_queue.push_back(cur_inputs[i].get());
+                if(cur_inputs[i]->requires_grad())
+                    variable_queue.push_back(cur_inputs[i].get());
             }
         }
 

+ 7 - 5
traph/source/test/main.cpp

@@ -36,13 +36,15 @@ int main()
 	*/
 	// auto a = traph::Variable<traph::f32>({ 2, 3 });
 
-	auto a = traph::ones<traph::f32>({ 2,3 });
+	auto a = traph::ones<traph::f32>({ 2,3,2 });
 	a->requires_grad_(true);
-	auto b = traph::ones<traph::f32>({ 2,3 });
-	auto c = traph::add<traph::f32>(a, b);
-	auto d = traph::sum<traph::f32>(c);
+	auto b = traph::sin<traph::f32>(a);
+	auto c = traph::ones<traph::f32>({ 2,3,2 });
+	c->requires_grad_(true);
+	auto d = traph::add<traph::f32>(b, c);
+	auto e = traph::sum<traph::f32>(d);
 
-	d->backward();
+	e->backward();
 
 	std::cout << a->grad()->to_string();