Просмотр исходного кода

Clear the graph iteratively instead of recursively

JasonWang 6 лет назад
Родитель
Коммит
1654e891ae

+ 0 - 2
traph/include/traph/core/variable.h

@@ -19,7 +19,6 @@ namespace traph
 
     public:
         virtual void backward() = 0;
-        virtual void clear_graph() = 0;
         virtual TensorInterfacePtr data() = 0;
         virtual void data_(TensorInterfacePtr d) = 0;
         virtual device_id device() = 0;
@@ -63,7 +62,6 @@ namespace traph
         using ByteVariableBase = VariableBase<u8>;
     public:
         virtual void backward() = 0;
-        virtual void clear_graph() = 0;
         virtual TensorInterfacePtr data() = 0;
         virtual void data_(TensorInterfacePtr d) = 0;
         virtual device_id device() = 0;

+ 0 - 35
traph/include/traph/nn/graph.h

@@ -1,35 +0,0 @@
-#ifndef TRAPH_NN_GRAPH_H_
-#define TRAPH_NN_GRAPH_H_
-
-#include <utility>
-#include <cmath>
-#include <string>
-#include <vector>
-
-#include <traph/core/type.h>
-
-
-namespace traph
-{
-    class FlowGraphNode
-    {
-    public:
-
-    };
-
-    class FlowGraphEdge
-    {
-    public:
-
-    };
-
-    class FlowGraph
-    {
-    private:
-        std::vector<FlowGraphNode> _nodes;
-        std::vector<FlowGraphEdge> _edges;
-    public:
-    };
-}
-
-#endif

+ 20 - 19
traph/include/traph/nn/operation.h

@@ -40,25 +40,6 @@ namespace traph
         virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) = 0;
     };
 
-    class SumOp: public OpBase
-    {
-    public:
-        virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
-        {
-            assert(inputs.size() == 1);
-            
-			TensorInterfacePtr input = inputs[0];
-			TensorInterfacePtr result = input->sum();
-
-			return result;
-        }
-
-        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
-        {
-            return {output_grad};
-        }
-    };
-
 	class AddOp : public OpBase
 	{
 	public:
@@ -191,6 +172,7 @@ namespace traph
 
 		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
 		{
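+			// FIXME: wrong gradient: this returns cos(output_grad) and never reads the forward input; presumably it should be output_grad * cos(input) (confirm which op this is).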
 			TensorBasePtr<f32> result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->clone());
 			result->cos_();
 			return { result };
@@ -221,6 +203,25 @@ namespace traph
 		}
 	};
 
+	class SumOp: public OpBase
+    {
+    public:
+        virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+        {
+            assert(inputs.size() == 1);
+            
+			TensorInterfacePtr input = inputs[0];
+			TensorInterfacePtr result = input->sum();
+
+			return result;
+        }
+
+        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+        {
+            return {output_grad};
+        }
+    };
+
 	class TransposeOp : public OpBase
 	{
 	private:

+ 3 - 11
traph/include/traph/nn/variable.h

@@ -44,7 +44,6 @@ namespace traph
         ~Variable();
 
         virtual void backward() override;
-		virtual void clear_graph() override;
         virtual TensorInterfacePtr data() override;
 		virtual void data_(TensorInterfacePtr d) override;
         virtual device_id device() override;
@@ -160,18 +159,11 @@ namespace traph
 		}
 
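 		// TODO: honor retain_graph. FIXME: the teardown loop below never uses i, so each pass only resets this variable's own _grad_fn/_inputs rather than those of sorted_node[i].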
-		clear_graph();
-	}
-
-	template<typename T>
-	void Variable<T>::clear_graph()
-	{
-		for(auto &each:_inputs)
+		for (int i = static_cast<int>(sorted_node.size()) - 1; i >= 0; --i)
 		{
-			each->clear_graph();
+			_grad_fn = nullptr;
+			_inputs.clear();
 		}
-		_grad_fn = nullptr;
-		_inputs.clear();
 	}
 
     template<typename T>

+ 1 - 1
traph/source/test/main.cpp

@@ -67,7 +67,7 @@ int main()
 
 	traph::Linear linear_model(4, 2, false);
 	traph::MSELoss criterion;
-	traph::SGD optimizer(linear_model.parameters(), 0.0001f);
+	traph::SGD optimizer(linear_model.parameters(), 0.001f);
 	std::cout << y->data()->to_string() << std::endl;
 
 	std::cout << "Start Training..." << std::endl;