|
|
@@ -27,7 +27,6 @@ namespace traph
|
|
|
private:
|
|
|
std::shared_ptr<TensorBase<T>> _data;
|
|
|
std::shared_ptr<TensorBase<f32>> _grad;
|
|
|
- bool _requires_grad;
|
|
|
std::shared_ptr<OpBase> _grad_fn;
|
|
|
std::vector<VariableInterfacePtr> _inputs;
|
|
|
// std::vector<std::weak_ptr<VariableInterface>> _outputs;
|
|
|
@@ -44,22 +43,8 @@ namespace traph
|
|
|
|
|
|
~Variable();
|
|
|
|
|
|
- template<class T>
|
|
|
- friend std::shared_ptr<Variable<T>> sum(std::shared_ptr<Variable<T>> input);
|
|
|
-
|
|
|
- template<class T>
|
|
|
- friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
|
|
|
-
|
|
|
- template<class T>
|
|
|
- friend std::shared_ptr<Variable<T>> matmul(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
|
|
|
-
|
|
|
- template<class T>
|
|
|
- friend std::shared_ptr<Variable<T>> select(std::shared_ptr<Variable<T>> input, const SliceVector& slice);
|
|
|
-
|
|
|
- template<class T>
|
|
|
- friend std::shared_ptr<Variable<T>> sin(std::shared_ptr<Variable<T>> input);
|
|
|
-
|
|
|
virtual void backward() override;
|
|
|
+ virtual void clear_graph() override;
|
|
|
virtual TensorInterfacePtr data() override;
|
|
|
virtual void data_(TensorInterfacePtr d) override;
|
|
|
virtual device_id device() override;
|
|
|
@@ -113,7 +98,6 @@ namespace traph
|
|
|
template<typename T>
|
|
|
Variable<T>::Variable()
|
|
|
:_data(new Tensor<T>), _grad(nullptr),
|
|
|
- _requires_grad(false),
|
|
|
_grad_fn(nullptr), _inputs()
|
|
|
{
|
|
|
|
|
|
@@ -122,7 +106,6 @@ namespace traph
|
|
|
template<typename T>
|
|
|
Variable<T>::Variable(std::shared_ptr<TensorBase<T>> data)
|
|
|
:_data(data), _grad(nullptr),
|
|
|
- _requires_grad(false),
|
|
|
_grad_fn(nullptr), _inputs()
|
|
|
{
|
|
|
}
|
|
|
@@ -130,7 +113,6 @@ namespace traph
|
|
|
template<typename T>
|
|
|
Variable<T>::Variable(const DimVector& dim)
|
|
|
:_data(new Tensor<T>(dim)), _grad(nullptr),
|
|
|
- _requires_grad(false),
|
|
|
_grad_fn(nullptr), _inputs()
|
|
|
{
|
|
|
}
|
|
|
@@ -138,7 +120,6 @@ namespace traph
|
|
|
template<typename T>
|
|
|
Variable<T>::Variable(std::initializer_list<idx_type> l)
|
|
|
:_data(new Tensor<T>()), _grad(nullptr),
|
|
|
- _requires_grad(false),
|
|
|
_grad_fn(nullptr), _inputs()
|
|
|
{
|
|
|
DimVector dim;
|
|
|
@@ -178,6 +159,19 @@ namespace traph
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // TODO:retain_graph
|
|
|
+ clear_graph();
|
|
|
+ }
|
|
|
+
|
|
|
+ template<typename T>
|
|
|
+ void Variable<T>::clear_graph()
|
|
|
+ {
|
|
|
+ _grad_fn = nullptr;
|
|
|
+ for(auto &each:_inputs)
|
|
|
+ {
|
|
|
+ each->clear_graph();
|
|
|
+ }
|
|
|
+ _inputs.clear();
|
|
|
}
|
|
|
|
|
|
template<typename T>
|
|
|
@@ -281,13 +275,12 @@ namespace traph
|
|
|
template<typename T>
|
|
|
bool Variable<T>::requires_grad() const
|
|
|
{
|
|
|
- return _requires_grad;
|
|
|
+ return bool(_grad);
|
|
|
}
|
|
|
|
|
|
template<typename T>
|
|
|
void Variable<T>::requires_grad_(bool requires_grad)
|
|
|
{
|
|
|
- _requires_grad = requires_grad;
|
|
|
if (requires_grad)
|
|
|
{
|
|
|
_grad = _data->create_grad();
|