|
|
@@ -5,12 +5,15 @@
|
|
|
#include <functional>
|
|
|
#include <initializer_list>
|
|
|
#include <vector>
|
|
|
+#include <list>
|
|
|
+#include <cassert>
|
|
|
|
|
|
#include <traph/core/index.h>
|
|
|
#include <traph/core/tensor.h>
|
|
|
#include <traph/core/variable.h>
|
|
|
#include <traph/tensor/tensor.h>
|
|
|
#include <traph/nn/operation.h>
|
|
|
+#include <traph/nn/executor.h>
|
|
|
|
|
|
namespace traph
|
|
|
{
|
|
|
@@ -31,16 +34,17 @@ namespace traph
|
|
|
using ByteVariable = Variable<u8>;
|
|
|
private:
|
|
|
std::shared_ptr<TensorBase<T>> _data;
|
|
|
- std::shared_ptr<TensorBase<T>> _grad;
|
|
|
+ std::shared_ptr<TensorBase<f32>> _grad;
|
|
|
bool _requires_grad;
|
|
|
bool _leaf;
|
|
|
std::shared_ptr<OpInterface<T>> _grad_fn;
|
|
|
std::vector<VariableInterfacePtr> _inputs;
|
|
|
+ std::vector<std::weak_ptr<VariableInterfacePtr>> _outputs;
|
|
|
public:
|
|
|
Variable()
|
|
|
:_data(new Tensor<T>), _grad(nullptr),
|
|
|
_requires_grad(false), _leaf(false),
|
|
|
- _grad_fn(nullptr)
|
|
|
+ _grad_fn(nullptr), _inputs(), _outputs()
|
|
|
{
|
|
|
|
|
|
}
|
|
|
@@ -48,21 +52,21 @@ namespace traph
|
|
|
Variable(std::shared_ptr<TensorBase<T>> data)
|
|
|
:_data(data), _grad(nullptr),
|
|
|
_requires_grad(false), _leaf(false),
|
|
|
- _grad_fn(nullptr)
|
|
|
+ _grad_fn(nullptr), _inputs(), _outputs()
|
|
|
{
|
|
|
}
|
|
|
|
|
|
Variable(const DimVector& dim)
|
|
|
:_data(new Tensor<T>(dim)), _grad(nullptr),
|
|
|
_requires_grad(false), _leaf(false),
|
|
|
- _grad_fn(nullptr)
|
|
|
+ _grad_fn(nullptr), _inputs(), _outputs()
|
|
|
{
|
|
|
}
|
|
|
|
|
|
Variable(const DimVector& dim, bool is_leaf)
|
|
|
:_data(new Tensor<T>(dim)), _grad(nullptr),
|
|
|
_requires_grad(false), _leaf(is_leaf),
|
|
|
- _grad_fn(nullptr)
|
|
|
+ _grad_fn(nullptr), _inputs(), _outputs()
|
|
|
{
|
|
|
if(is_leaf)
|
|
|
{
|
|
|
@@ -75,7 +79,7 @@ namespace traph
|
|
|
Variable(std::initializer_list<idx_type> l)
|
|
|
:_data(new Tensor<T>()), _grad(nullptr),
|
|
|
_requires_grad(false), _leaf(false),
|
|
|
- _grad_fn(nullptr)
|
|
|
+ _grad_fn(nullptr), _inputs(), _outputs()
|
|
|
{
|
|
|
DimVector dim;
|
|
|
for (auto i : l)
|
|
|
@@ -105,6 +109,20 @@ namespace traph
|
|
|
|
|
|
virtual void backward() override
|
|
|
{
|
|
|
+ _grad->fill_(1);
|
|
|
+
|
|
|
+ std::vector<VariableInterface*> sorted_node = Executor::topology_sort(this);
|
|
|
+ for(int i = sorted_node.size() - 1; i >=0; --i)
|
|
|
+ {
|
|
|
+ VariableInterface* cur_node = sorted_node[i];
|
|
|
+ std::vector<TensorBasePtr<T>> back_grad = cur_node->_grad_fn->backward(cur_node->grad());
|
|
|
+
|
|
|
+ assert(back_grad.size() == _inputs.size());
|
|
|
+ for(int i = 0; i < cur_node->inputs().size(); ++i)
|
|
|
+ {
|
|
|
+ cur_node->inputs()[i]->grad().add_(back_grad[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
}
|
|
|
virtual device_id device() override
|
|
|
@@ -115,6 +133,14 @@ namespace traph
|
|
|
{
|
|
|
return _data->fill_(value);
|
|
|
}
|
|
|
+ virtual TensorBasePtr<f32> grad() override
|
|
|
+ {
|
|
|
+ return _grad;
|
|
|
+ }
|
|
|
+ virtual std::vector<VariableInterfacePtr>& inputs() override
|
|
|
+ {
|
|
|
+ return _inputs;
|
|
|
+ }
|
|
|
virtual T item() const override
|
|
|
{
|
|
|
return _data->item();
|
|
|
@@ -127,6 +153,10 @@ namespace traph
|
|
|
{
|
|
|
return _data->order();
|
|
|
}
|
|
|
+ virtual std::vector<std::weak_ptr<VariableInterface>>& outputs() override
|
|
|
+ {
|
|
|
+ return _outputs;
|
|
|
+ }
|
|
|
virtual platform_type platform() override
|
|
|
{
|
|
|
return _data->platform();
|
|
|
@@ -135,10 +165,14 @@ namespace traph
|
|
|
{
|
|
|
_requires_grad = requires_grad;
|
|
|
if(requires_grad)
|
|
|
+ {
|
|
|
_grad = _data->create_grad();
|
|
|
+ _grad->fill_(0);
|
|
|
+ }
|
|
|
else
|
|
|
+ {
|
|
|
_grad = std::shared_ptr<TensorBase<T>>(nullptr);
|
|
|
-
|
|
|
+ }
|
|
|
}
|
|
|
virtual void reshape_(const DimVector& dims) override
|
|
|
{
|