1
0
JasonWang 6 жил өмнө
parent
commit
70f36760af

+ 2 - 0
traph/include/traph/core/variable.h

@@ -19,6 +19,7 @@ namespace traph
 
     public:
         virtual void backward() = 0;
+        virtual void clear_graph() = 0;
         virtual TensorInterfacePtr data() = 0;
         virtual void data_(TensorInterfacePtr d) = 0;
         virtual device_id device() = 0;
@@ -62,6 +63,7 @@ namespace traph
         using ByteVariableBase = VariableBase<u8>;
     public:
         virtual void backward() = 0;
+        virtual void clear_graph() = 0;
         virtual TensorInterfacePtr data() = 0;
         virtual void data_(TensorInterfacePtr d) = 0;
         virtual device_id device() = 0;

+ 13 - 1
traph/include/traph/nn/function.h

@@ -74,6 +74,9 @@ namespace traph
 
 		std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
 
+		if(requires_grad)
+			result->requires_grad_(true);
+
 		return result;
 	}
 
@@ -87,6 +90,9 @@ namespace traph
 		std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
 		std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(0);
 
+		if(requires_grad)
+			result->requires_grad_(true);
+
 		return result;
 	}
 
@@ -100,6 +106,9 @@ namespace traph
 		std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
 		std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(1);
 
+		if(requires_grad)
+			result->requires_grad_(true);
+
 		return result;
 	}
 
@@ -119,6 +128,8 @@ namespace traph
 		result_data->apply_([&d, &gen](T n){
 			return d(gen);
 		});
+		if(requires_grad)
+			result->requires_grad_(true);
 
 		return result;
 	}
@@ -127,7 +138,8 @@ namespace traph
 	VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
 	{
 		std::shared_ptr<VariableInterface> result(new Variable<T>(input->size(), false));
-
+		if(requires_grad)
+			result->requires_grad_(true);
 		return result;
 	}
 

+ 4 - 4
traph/include/traph/nn/layers/linear.h

@@ -18,12 +18,12 @@ namespace traph
         {
             _in_features = in_features;
             _out_features = out_features;
-            _weight = std::shared_ptr<VariableInterface>(new FloatParameter({out_features, in_features}));
+            _weight = randn<f32>({out_features, in_features}, true);
             if(bias)
-                _bias = std::shared_ptr<VariableInterface>(new FloatParameter({out_features}));
+                _bias = randn<f32>({out_features}, true);
             
-            register_parameter("weight", std::dynamic_pointer_cast<FloatParameter>(_weight));
-            register_parameter("bias", std::dynamic_pointer_cast<FloatParameter>(_bias));
+            register_parameter("weight", _weight);
+            register_parameter("bias", _bias);
         }
 
         std::shared_ptr<VariableInterface> forward(std::shared_ptr<VariableInterface> input)

+ 3 - 2
traph/include/traph/nn/module.h

@@ -38,14 +38,15 @@ namespace traph
             return result;
         }
 
-        std::vector<std::shared_ptr<VariableInterface>> parameters(bool recurse)
+        std::vector<std::shared_ptr<VariableInterface>> parameters(bool recurse=true)
         {
             std::vector<std::shared_ptr<VariableInterface>> result;
             if(recurse)
             {
                 // fixme: children params recurse
                 for (const auto &p : _parameters)
-                    result.push_back(p.second);
+					if(p.second)
+						result.push_back(p.second);
             }
             else
             {

+ 2 - 1
traph/include/traph/nn/optim.h

@@ -24,7 +24,8 @@ namespace traph
         {
             for(auto& each_param: _params)
             {
-                each_param->grad()->fill_(0);
+				if(each_param->grad())
+					each_param->grad()->fill_(0);
             }
         }
     };

+ 15 - 22
traph/include/traph/nn/variable.h

@@ -27,7 +27,6 @@ namespace traph
     private:
         std::shared_ptr<TensorBase<T>> _data;
         std::shared_ptr<TensorBase<f32>> _grad;
-        bool _requires_grad;
         std::shared_ptr<OpBase> _grad_fn;
         std::vector<VariableInterfacePtr> _inputs;
         // std::vector<std::weak_ptr<VariableInterface>> _outputs;
@@ -44,22 +43,8 @@ namespace traph
 
         ~Variable();
 
-		template<class T>
-		friend std::shared_ptr<Variable<T>> sum(std::shared_ptr<Variable<T>> input);
-
-		template<class T>
-		friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
-
-		template<class T>
-		friend std::shared_ptr<Variable<T>> matmul(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
-
-		template<class T>
-		friend std::shared_ptr<Variable<T>> select(std::shared_ptr<Variable<T>> input, const SliceVector& slice);
-
-		template<class T>
-		friend std::shared_ptr<Variable<T>> sin(std::shared_ptr<Variable<T>> input);
-
         virtual void backward() override;
+		virtual void clear_graph() override;
         virtual TensorInterfacePtr data() override;
 		virtual void data_(TensorInterfacePtr d) override;
         virtual device_id device() override;
@@ -113,7 +98,6 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable()
 		:_data(new Tensor<T>), _grad(nullptr),
-		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 
@@ -122,7 +106,6 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable(std::shared_ptr<TensorBase<T>> data)
 		:_data(data), _grad(nullptr),
-		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 	}
@@ -130,7 +113,6 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable(const DimVector& dim)
 		:_data(new Tensor<T>(dim)), _grad(nullptr),
-		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 	}
@@ -138,7 +120,6 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable(std::initializer_list<idx_type> l)
 		:_data(new Tensor<T>()), _grad(nullptr),
-		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 		DimVector dim;
@@ -178,6 +159,19 @@ namespace traph
 			}
 		}
 
+		// TODO:retain_graph
+		clear_graph();
+	}
+
+	template<typename T>
+	void Variable<T>::clear_graph()
+	{
+		_grad_fn = nullptr;
+		for(auto &each:_inputs)
+		{
+			each->clear_graph();
+		}
+		_inputs.clear();
 	}
 
     template<typename T>
@@ -281,13 +275,12 @@ namespace traph
 	template<typename T>
 	bool Variable<T>::requires_grad() const
 	{
-		return _requires_grad;
+		return bool(_grad);
 	}
 
 	template<typename T>
 	void Variable<T>::requires_grad_(bool requires_grad)
 	{
-		_requires_grad = requires_grad;
 		if (requires_grad)
 		{
 			_grad = _data->create_grad();

+ 1 - 0
traph/source/nn/CMakeLists.txt

@@ -14,6 +14,7 @@ SET(NN_LIST
 	${HEADER_PATH}/function.h
 	${HEADER_PATH}/operation.h
 	${SOURCE_PATH}/operation.cpp
+	${HEADER_PATH}/optim.h
 )
 
 ADD_LIBRARY(${LIB_OUTNAME} ${NN_LIST})

+ 20 - 16
traph/source/test/main.cpp

@@ -4,7 +4,7 @@
 #include <traph/nn/layers/linear.h>
 #include <traph/nn/layers/loss.h>
 #include <traph/core/tensor.h>
-#include <traph/tensor/byte_tensor.h>
+#include <traph/nn/optim.h>
 
 #include <iostream>
 
@@ -59,28 +59,32 @@ int main()
 	d->backward();
 	std::cout << a->grad()->to_string();
 	*/
-/*
+
 	int batch_size = 16;
 	
-	auto x = traph::randn<traph::f32>({ batch_size,4 });
-	auto y = traph::randn<traph::f32>({ batch_size,2 });
+	auto x = traph::ones<traph::f32>({ batch_size,4 });
+	auto y = traph::ones<traph::f32>({ batch_size,2 });
 
 	traph::Linear linear_model(4, 2, false);
-	traph::MSELoss loss;
-
-	auto out = linear_model.forward(x);
-	auto result = loss.forward(out, y);
+	traph::MSELoss criterion;
+	traph::SGD optimizer(linear_model.parameters(), 0.001f);
+	std::cout << y->data()->to_string() << std::endl;
 
-	result->backward();
-	std::cout << x->data()->to_string() << std::endl;
-	std::cout << linear_model.parameters(true)[0]->grad()->to_string() << std::endl;
-*/
+	std::cout << "Start Training..." << std::endl;
 
+	for (int epoch = 0; epoch < 100; ++epoch)
+	{
+		float loss100 = 0.f;
 
-	auto x = traph::ones<traph::u8>({ 2, 3, 4 });
-	x->data()->transpose_(1, 2);
-	std::dynamic_pointer_cast<traph::TensorBase<traph::u8>>(x->data())->fill_(5);
-	std::cout << x->data()->to_string();
+		optimizer.zero_grad();
+		auto out = linear_model.forward(x);
+		auto loss = criterion.forward(out, y);
+		loss->backward();
+		optimizer.step();
+		// loss100 += loss->item();
+		std::cout << loss->data()->to_string()<<std::endl;
+	}
+	
 	
     return 0;
 }