Browse source

add some basic autograd functions

JasonWang 7 years ago
parent commit
4bb604d4db

+ 1 - 1
traph/include/traph/core/tensor.h

@@ -72,7 +72,7 @@ namespace traph
     public:
         virtual void apply_(std::function<T(T)> f) = 0;
         virtual void cos_() = 0;
-        virtual TensorBasePtr create_grad() = 0;
+        virtual std::shared_ptr<TensorBase<f32>> create_grad() = 0;
         virtual device_id device() = 0;
         virtual void fill_(T value) = 0;
         virtual T item() const = 0;
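
With this signature change, gradients are always stored as f32, whatever the element type T of the tensor that owns them. A minimal sketch of the resulting usage, relying only on calls visible in this diff (and assuming the traph headers are on the include path):

#include <memory>
#include <traph/tensor/tensor.h>

using namespace traph;

int main()
{
    // even an integer tensor hands back an f32 gradient buffer of its shape
    std::shared_ptr<TensorBase<u8>> t(new Tensor<u8>());
    std::shared_ptr<TensorBase<f32>> g = t->create_grad();
    g->fill_(0.0f); // fill_ is declared on TensorBase, as shown above
    return 0;
}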

+ 6 - 0
traph/include/traph/core/variable.h

@@ -18,8 +18,11 @@ namespace traph
     public:
         virtual void backward() = 0;
         virtual device_id device() = 0;
+        virtual TensorBasePtr<f32> grad() = 0;
+        virtual std::vector<VariableInterfacePtr>& inputs() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
+        virtual std::vector<std::weak_ptr<VariableInterface>>& outputs() = 0;
         virtual platform_type platform() = 0;
         virtual void requires_grad_(bool requires_grad) = 0;
         virtual void reshape_(const DimVector& dims) = 0;
@@ -51,9 +54,12 @@ namespace traph
         virtual void backward() = 0;
         virtual device_id device() = 0;
         virtual void fill_(T value) = 0;
+        virtual TensorBasePtr<f32> grad() = 0;
+        virtual std::vector<VariableInterfacePtr>& inputs() = 0;
         virtual T item() const = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
+        virtual std::vector<std::weak_ptr<VariableInterface>>& outputs() = 0;
         virtual platform_type platform() = 0;
         virtual void requires_grad_(bool requires_grad) = 0;
         virtual void reshape_(const DimVector& dims) = 0;

+ 2 - 0
traph/include/traph/nn/arithmetic.h

@@ -28,6 +28,8 @@ namespace traph
             result->_leaf = false;
             result->_grad_fn = op;
             result->_inputs = result_inputs;
+            
+            input->_outputs.push_back(result);
         }
         else
         {
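
The new `input->_outputs.push_back(result)` line relies on the implicit shared_ptr-to-weak_ptr conversion: a producer records only a non-owning reference to its consumer, while `_inputs` holds owning shared_ptrs in the other direction, so the graph cannot form a reference cycle. A standalone model of that link (Node is an illustrative stand-in, not a traph type):

#include <cassert>
#include <memory>
#include <vector>

struct Node { std::vector<std::weak_ptr<Node>> outputs; };

int main()
{
    std::shared_ptr<Node> input(new Node);
    std::shared_ptr<Node> result(new Node);
    input->outputs.push_back(result);    // implicit shared_ptr -> weak_ptr
    assert(!input->outputs[0].expired());
    result.reset();                      // the consumer dies...
    assert(input->outputs[0].expired()); // ...and the weak link observes it
    return 0;
}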

+ 61 - 0
traph/include/traph/nn/executor.h

@@ -2,6 +2,13 @@
 #define TRAPH_NN_EXECUTOR_H_
 
 #include <vector>
+#include <list>
+#include <set>
+#include <algorithm>
+#include <cstddef>
+#include <cassert>
+
+#include <traph/core/variable.h>
 
 namespace traph
 {
@@ -9,7 +16,61 @@ namespace traph
     {
     private:
     public:
+        static void backward(VariableInterface* root)
+        {
+            // placeholder: the traversal currently lives in Variable::backward
+        }
+
+        static std::vector<VariableInterface*> topology_sort(VariableInterface* root)
+        {
+            std::set<VariableInterface*> all_nodes = collect_backward_tensors(root);
+            std::vector<VariableInterface*> visited_nodes;
+            std::size_t all_size = all_nodes.size();
+
+            for(std::size_t i = 0; i < all_size; ++i)
+            {
+                for(std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
+                {
+                    // a node is ready once all of its inputs have been visited
+                    std::vector<VariableInterfacePtr>& cur_inputs = (*it)->inputs();
+                    bool ready = true;
+                    for(std::size_t j = 0; j < cur_inputs.size(); ++j)
+                    {
+                        if(std::find(visited_nodes.begin(), visited_nodes.end(), cur_inputs[j].get()) == visited_nodes.end())
+                        {
+                            ready = false;
+                            break;
+                        }
+                    }
+                    if(ready)
+                    {
+                        visited_nodes.push_back(*it);
+                        all_nodes.erase(it);
+                        break;
+                    }
+                }
+            }
+
+            return visited_nodes;
+        }
+
+        static std::set<VariableInterface*> collect_backward_tensors(VariableInterface* root)
+        {
+            // bfs
+            std::set<VariableInterface*> result_set;
+            std::list<VariableInterface*> variable_queue;
+            variable_queue.push_back(root);
+            while(!variable_queue.empty())
+            {
+                // get first
+                VariableInterface* cur = variable_queue.front();
+                variable_queue.pop_front();
+                // save; skip nodes already collected so shared inputs
+                // are not expanded more than once
+                if(!result_set.insert(cur).second)
+                    continue;
+
+                // get inputs and insert into queue then continue
+                std::vector<VariableInterfacePtr>& cur_inputs = cur->inputs();
+                for(std::size_t i = 0; i < cur_inputs.size(); ++i)
+                {
+                    variable_queue.push_back(cur_inputs[i].get());
+                }
+            }
 
+            return result_set;
+        }
     };
 }
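
topology_sort implements a Kahn-style selection: each pass removes one node whose inputs have all been visited, so producers appear before consumers and the backward sweep can walk the result in reverse. It is quadratic in the node count, which is fine at this scale. A standalone toy (Node and topo are illustrative stand-ins) showing the same loop on the chain a -> b -> c:

#include <algorithm>
#include <cstddef>
#include <set>
#include <vector>

struct Node { std::vector<Node*> inputs; };

std::vector<Node*> topo(std::set<Node*> pending)
{
    std::vector<Node*> visited;
    const std::size_t total = pending.size();
    for (std::size_t i = 0; i < total; ++i)
    {
        for (std::set<Node*>::iterator it = pending.begin(); it != pending.end(); ++it)
        {
            bool ready = true; // ready once every input is already visited
            for (Node* in : (*it)->inputs)
                if (std::find(visited.begin(), visited.end(), in) == visited.end())
                    ready = false;
            if (ready) { visited.push_back(*it); pending.erase(it); break; }
        }
    }
    return visited;
}

int main()
{
    Node a, b, c;
    b.inputs.push_back(&a); // b consumes a
    c.inputs.push_back(&b); // c consumes b
    std::vector<Node*> order = topo({&a, &b, &c});
    // order is {&a, &b, &c}; the backward pass iterates it from the back
    return 0;
}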
 

+ 3 - 2
traph/include/traph/nn/operation.h

@@ -39,6 +39,7 @@ namespace traph
     {
     public:
         OpContext context;
+        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) = 0;
     };
 
     template<class T>
@@ -46,7 +47,7 @@ namespace traph
     {
     public:
         virtual TensorBasePtr<T> forward(std::vector<TensorBasePtr<T>> inputs) = 0;
-        virtual std::vector<TensorBasePtr<T>> backward(TensorBasePtr<T> output_grad) = 0;
+        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) = 0;
     };
 
     template<class T>
@@ -63,7 +64,7 @@ namespace traph
 			return result;
         }
 
-        virtual std::vector<TensorBasePtr<T>> backward(TensorBasePtr<T> output_grad) override
+        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
         {
             return {output_grad};
         }
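
Under the revised contract, backward always consumes and produces f32 gradients, one per forward input, regardless of the op's compute type T. A sketch of a hypothetical op (not part of this commit) written against OpInterface:

#include <vector>
#include <traph/nn/operation.h>

namespace traph
{
    // identity op: forward passes the tensor through, so the upstream
    // gradient flows back to its single input unchanged
    template<class T>
    class IdentityOp : public OpInterface<T>
    {
    public:
        virtual TensorBasePtr<T> forward(std::vector<TensorBasePtr<T>> inputs) override
        {
            return inputs[0];
        }
        virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
        {
            return {output_grad}; // exactly one grad: one forward input
        }
    };
}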

+ 41 - 7
traph/include/traph/nn/variable.h

@@ -5,12 +5,15 @@
 #include <functional>
 #include <initializer_list>
 #include <vector>
+#include <list>
+#include <cassert>
 
 #include <traph/core/index.h>
 #include <traph/core/tensor.h>
 #include <traph/core/variable.h>
 #include <traph/tensor/tensor.h>
 #include <traph/nn/operation.h>
+#include <traph/nn/executor.h>
 
 namespace traph
 {
@@ -31,16 +34,17 @@ namespace traph
         using ByteVariable = Variable<u8>;
     private:
         std::shared_ptr<TensorBase<T>> _data;
-        std::shared_ptr<TensorBase<T>> _grad;
+        std::shared_ptr<TensorBase<f32>> _grad;
         bool _requires_grad;
         bool _leaf;
         std::shared_ptr<OpInterface<T>> _grad_fn;
         std::vector<VariableInterfacePtr> _inputs;
+        std::vector<std::weak_ptr<VariableInterface>> _outputs;
     public:
         Variable()
             :_data(new Tensor<T>), _grad(nullptr),
             _requires_grad(false), _leaf(false),
-            _grad_fn(nullptr)
+            _grad_fn(nullptr), _inputs(), _outputs()
         {
 
         }
@@ -48,21 +52,21 @@ namespace traph
         Variable(std::shared_ptr<TensorBase<T>> data)
             :_data(data), _grad(nullptr),
             _requires_grad(false), _leaf(false),
-            _grad_fn(nullptr)
+            _grad_fn(nullptr), _inputs(), _outputs()
         {
         }
 
         Variable(const DimVector& dim)
             :_data(new Tensor<T>(dim)), _grad(nullptr),
             _requires_grad(false), _leaf(false),
-            _grad_fn(nullptr)
+            _grad_fn(nullptr), _inputs(), _outputs()
         {
         }
 
         Variable(const DimVector& dim, bool is_leaf)
             :_data(new Tensor<T>(dim)), _grad(nullptr),
             _requires_grad(false), _leaf(is_leaf),
-            _grad_fn(nullptr)
+            _grad_fn(nullptr), _inputs(), _outputs()
         {
             if(is_leaf)
             {
@@ -75,7 +79,7 @@ namespace traph
         Variable(std::initializer_list<idx_type> l)
             :_data(new Tensor<T>()), _grad(nullptr),
             _requires_grad(false), _leaf(false),
-            _grad_fn(nullptr)
+            _grad_fn(nullptr), _inputs(), _outputs()
         {
             DimVector dim;
             for (auto i : l)
@@ -105,6 +109,20 @@ namespace traph
 
         virtual void backward() override
         {
+            // seed the root: d(out)/d(out) = 1
+            _grad->fill_(1);
+
+            // visit consumers before producers
+            std::vector<VariableInterface*> sorted_node = Executor::topology_sort(this);
+            for(int i = static_cast<int>(sorted_node.size()) - 1; i >= 0; --i)
+            {
+                VariableInterface* cur_node = sorted_node[i];
+                // leaves have no grad_fn; grad_fn() is assumed to be exposed
+                // on VariableInterface alongside grad()/inputs()/outputs()
+                if(!cur_node->grad_fn())
+                    continue;
+                std::vector<TensorBasePtr<f32>> back_grad = cur_node->grad_fn()->backward(cur_node->grad());
+
+                assert(back_grad.size() == cur_node->inputs().size());
+                for(std::size_t j = 0; j < cur_node->inputs().size(); ++j)
+                {
+                    // accumulate, so variables used by several ops sum their grads
+                    cur_node->inputs()[j]->grad()->add_(back_grad[j]);
+                }
+            }
 
         }
         virtual device_id device() override
@@ -115,6 +133,14 @@ namespace traph
         {
             return _data->fill_(value);
         }
+        virtual TensorBasePtr<f32> grad() override
+        {
+            return _grad;
+        }
+        virtual std::vector<VariableInterfacePtr>& inputs() override
+        {
+            return _inputs;
+        }
         virtual T item() const override
         {
             return _data->item();
@@ -127,6 +153,10 @@ namespace traph
         {
             return _data->order();
         }
+        virtual std::vector<std::weak_ptr<VariableInterface>>& outputs() override
+        {
+            return _outputs;
+        }
         virtual platform_type platform() override
         {
             return _data->platform();
@@ -135,10 +165,14 @@ namespace traph
         {
             _requires_grad = requires_grad;
             if(requires_grad)
+            {
                 _grad = _data->create_grad();
+                _grad->fill_(0);
+            }
             else
+            {
-            _grad = std::shared_ptr<TensorBase<T>>(nullptr);
-            
+                _grad = std::shared_ptr<TensorBase<f32>>(nullptr);
+            }
         }
         virtual void reshape_(const DimVector& dims) override
         {
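
A minimal end-to-end sketch using only the API visible in this diff: for a lone leaf variable there is no grad_fn to walk, so backward() simply seeds the gradient with ones.

#include <memory>
#include <traph/nn/variable.h>

using namespace traph;

int main()
{
    // 2x2 leaf variable via the initializer_list constructor above
    std::shared_ptr<Variable<f32>> x(new Variable<f32>({2, 2}));
    x->requires_grad_(true); // allocates and zero-fills x's f32 grad

    x->backward();           // seeds d(x)/d(x) = 1; the reverse sweep
                             // finds no grad_fn, so it stops at the leaf
    // x->grad() now holds a 2x2 tensor of ones
    return 0;
}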

+ 2 - 2
traph/include/traph/tensor/tensor.h

@@ -277,9 +277,9 @@ namespace traph
         {
 			apply_([](T a)->T {return std::cos(a); });
         }
-        virtual TensorBasePtr create_grad() override
+        virtual std::shared_ptr<TensorBase<f32>> create_grad() override
         {
-            return std::shared_ptr<TensorBase<T>>(new Tensor<T>(_dimensions));
+            return std::shared_ptr<TensorBase<f32>>(new Tensor<f32>(_dimensions));
         }
         virtual device_id device() override { return 0; }
         virtual void fill_(T value) override