|
|
@@ -16,162 +16,219 @@
|
|
|
namespace traph
|
|
|
{
|
|
|
// creation function
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> zeros(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
+ template<typename T>
|
|
|
+ VariableInterfacePtr empty(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
{
|
|
|
DimVector dim;
|
|
|
for (auto i : l)
|
|
|
dim.push_back(i);
|
|
|
|
|
|
- std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
result->leaf_(true);
|
|
|
- result->fill_(0);
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> ones(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
+ template<typename T>
|
|
|
+ VariableInterfacePtr zeros(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
{
|
|
|
DimVector dim;
|
|
|
for (auto i : l)
|
|
|
dim.push_back(i);
|
|
|
|
|
|
- std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
+ result->leaf_(true);
|
|
|
+ std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(0);
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ template<typename T>
|
|
|
+ VariableInterfacePtr ones(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
+ {
|
|
|
+ DimVector dim;
|
|
|
+ for (auto i : l)
|
|
|
+ dim.push_back(i);
|
|
|
+
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
+ result->leaf_(true);
|
|
|
+ std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(1);
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ template<typename T>
|
|
|
+ VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
|
|
|
+ {
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(input->size(), false));
|
|
|
result->leaf_(true);
|
|
|
- result->fill_(1);
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
// arithmetic function
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> sum(VariablePtr<T> input)
|
|
|
+ VariableInterfacePtr sum(VariableInterfacePtr input)
|
|
|
{
|
|
|
- VariablePtr<T> result(new Variable<T>);
|
|
|
+ DimVector result_dim(1);
|
|
|
+ result_dim[0] = 1;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
std::shared_ptr<SumOp> op(new SumOp);
|
|
|
- if(input->_requires_grad)
|
|
|
+ if(input->requires_grad())
|
|
|
{
|
|
|
- std::vector<VariableInterfacePtr> result_inputs { std::dynamic_pointer_cast<VariableInterface>(input) };
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
|
|
|
- result->_grad = result->_data->create_grad();
|
|
|
- result->_requires_grad = true;
|
|
|
- result->_leaf = false;
|
|
|
- result->_grad_fn = op;
|
|
|
- result->_inputs = result_inputs;
|
|
|
+ std::vector<VariableInterfacePtr> result_inputs { input };
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->leaf_(false);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
|
|
|
- result->_requires_grad = false;
|
|
|
- result->_leaf = false;
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
+ result->requires_grad_(false);
|
|
|
+ result->leaf_(false);
|
|
|
}
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> add(VariablePtr<T> left, VariablePtr<T> right)
|
|
|
+ VariableInterfacePtr add(VariableInterfacePtr left, VariableInterfacePtr right)
|
|
|
{
|
|
|
- VariablePtr<T> result(new Variable<T>);
|
|
|
+ DimVector result_dim;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
std::shared_ptr<AddOp> op(new AddOp);
|
|
|
- if (left->_requires_grad || right->_requires_grad)
|
|
|
+ if (left->requires_grad() || right->requires_grad())
|
|
|
{
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
|
|
|
- result->_grad = result->_data->create_grad();
|
|
|
- result->_grad->fill_(0);
|
|
|
- result->_requires_grad = true;
|
|
|
- result->_leaf = false;
|
|
|
- result->_grad_fn = op;
|
|
|
- result->_inputs = result_inputs;
|
|
|
+ result->data_(op->forward({ left->data(), right->data() }));
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->leaf_(false);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
|
|
|
- result->_requires_grad = false;
|
|
|
- result->_leaf = false;
|
|
|
+ result->data_(op->forward({ left->data(), right->data() }));
|
|
|
+ result->requires_grad_(false);
|
|
|
+ result->leaf_(false);
|
|
|
}
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> matmul(VariablePtr<T> left, VariablePtr<T> right)
|
|
|
+ VariableInterfacePtr matmul(VariableInterfacePtr left, VariableInterfacePtr right)
|
|
|
{
|
|
|
- VariablePtr<T> result(new Variable<T>);
|
|
|
+ DimVector result_dim;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
std::shared_ptr<MatmulOp> op(new MatmulOp);
|
|
|
- if (left->_requires_grad || right->_requires_grad)
|
|
|
+ if (left->requires_grad() || right->requires_grad())
|
|
|
{
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
|
|
|
- result->_grad = result->_data->create_grad();
|
|
|
- result->_grad->fill_(0);
|
|
|
- result->_requires_grad = true;
|
|
|
- result->_leaf = false;
|
|
|
- result->_grad_fn = op;
|
|
|
- result->_inputs = result_inputs;
|
|
|
+ result->data_(op->forward({ left->data(), right->data() }));
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->leaf_(false);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
|
|
|
- result->_requires_grad = false;
|
|
|
- result->_leaf = false;
|
|
|
+ result->data_(op->forward({ left->data(), right->data() }));
|
|
|
+ result->requires_grad_(false);
|
|
|
+ result->leaf_(false);
|
|
|
}
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> select(VariablePtr<T> input, const SliceVector& slice)
|
|
|
+ VariableInterfacePtr select(VariableInterfacePtr input, const SliceVector& slice)
|
|
|
{
|
|
|
- VariablePtr<T> result(new Variable<T>);
|
|
|
+ DimVector result_dim;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
std::shared_ptr<SelectOp> op(new SelectOp);
|
|
|
op->set_slice(slice);
|
|
|
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
|
|
|
- result->_leaf = false;
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
+ result->leaf_(false);
|
|
|
|
|
|
if (input->requires_grad())
|
|
|
{
|
|
|
- result->_grad = result->_data->create_grad();
|
|
|
- result->_grad->fill_(0);
|
|
|
- result->_requires_grad = true;
|
|
|
- result->_grad_fn = op;
|
|
|
- result->_inputs = result_inputs;
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- result->_requires_grad = false;
|
|
|
+ result->requires_grad_(false);
|
|
|
}
|
|
|
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> sin(VariablePtr<T> input)
|
|
|
+ VariableInterfacePtr sin(VariableInterfacePtr input)
|
|
|
{
|
|
|
- VariablePtr<T> result(new Variable<T>);
|
|
|
+ DimVector result_dim;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
std::shared_ptr<SinOp> op(new SinOp);
|
|
|
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
- result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
|
|
|
- result->_leaf = false;
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
+ result->leaf_(false);
|
|
|
+
|
|
|
+ if (input->requires_grad())
|
|
|
+ {
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ result->requires_grad_(false);
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ VariableInterfacePtr transpose(VariableInterfacePtr input, idx_type dim0, idx_type dim1)
|
|
|
+ {
|
|
|
+ DimVector result_dim;
|
|
|
+
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
+ std::shared_ptr<TransposeOp> op(new TransposeOp);
|
|
|
+ op->set_dim(dim0, dim1);
|
|
|
+
|
|
|
+ std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
+ result->leaf_(false);
|
|
|
|
|
|
if (input->requires_grad())
|
|
|
{
|
|
|
- result->_grad = result->_data->create_grad();
|
|
|
- result->_grad->fill_(0);
|
|
|
- result->_requires_grad = true;
|
|
|
- result->_grad_fn = op;
|
|
|
- result->_inputs = result_inputs;
|
|
|
+ result->grad_(result->data()->create_grad());
|
|
|
+ result->grad()->fill_(0);
|
|
|
+ result->requires_grad_(true);
|
|
|
+ result->grad_fn_(op);
|
|
|
+ result->inputs_(result_inputs);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- result->_requires_grad = false;
|
|
|
+ result->requires_grad_(false);
|
|
|
}
|
|
|
|
|
|
return result;
|