|
@@ -24,7 +24,6 @@ namespace traph
|
|
|
dim.push_back(i);
|
|
dim.push_back(i);
|
|
|
|
|
|
|
|
std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
- result->leaf_(true);
|
|
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
@@ -36,8 +35,7 @@ namespace traph
|
|
|
for (auto i : l)
|
|
for (auto i : l)
|
|
|
dim.push_back(i);
|
|
dim.push_back(i);
|
|
|
|
|
|
|
|
- std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
|
|
- result->leaf_(true);
|
|
|
|
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
|
|
|
std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(0);
|
|
std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(0);
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
@@ -50,8 +48,7 @@ namespace traph
|
|
|
for (auto i : l)
|
|
for (auto i : l)
|
|
|
dim.push_back(i);
|
|
dim.push_back(i);
|
|
|
|
|
|
|
|
- std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
|
|
|
|
|
- result->leaf_(true);
|
|
|
|
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
|
|
|
std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(1);
|
|
std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(1);
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
@@ -61,7 +58,6 @@ namespace traph
|
|
|
VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
|
|
VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
|
|
|
{
|
|
{
|
|
|
std::shared_ptr<VariableInterface> result(new Variable<T>(input->size(), false));
|
|
std::shared_ptr<VariableInterface> result(new Variable<T>(input->size(), false));
|
|
|
- result->leaf_(true);
|
|
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
@@ -74,22 +70,20 @@ namespace traph
|
|
|
|
|
|
|
|
VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
std::shared_ptr<SumOp> op(new SumOp);
|
|
std::shared_ptr<SumOp> op(new SumOp);
|
|
|
|
|
+
|
|
|
|
|
+ result->data_(op->forward({ input->data() }));
|
|
|
if(input->requires_grad())
|
|
if(input->requires_grad())
|
|
|
{
|
|
{
|
|
|
std::vector<VariableInterfacePtr> result_inputs { input };
|
|
std::vector<VariableInterfacePtr> result_inputs { input };
|
|
|
- result->data_(op->forward({ input->data() }));
|
|
|
|
|
result->grad_(result->data()->create_grad());
|
|
result->grad_(result->data()->create_grad());
|
|
|
result->grad()->fill_(0);
|
|
result->grad()->fill_(0);
|
|
|
result->requires_grad_(true);
|
|
result->requires_grad_(true);
|
|
|
- result->leaf_(false);
|
|
|
|
|
result->grad_fn_(op);
|
|
result->grad_fn_(op);
|
|
|
result->inputs_(result_inputs);
|
|
result->inputs_(result_inputs);
|
|
|
}
|
|
}
|
|
|
else
|
|
else
|
|
|
{
|
|
{
|
|
|
- result->data_(op->forward({ input->data() }));
|
|
|
|
|
result->requires_grad_(false);
|
|
result->requires_grad_(false);
|
|
|
- result->leaf_(false);
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
@@ -108,7 +102,6 @@ namespace traph
|
|
|
result->grad_(result->data()->create_grad());
|
|
result->grad_(result->data()->create_grad());
|
|
|
result->grad()->fill_(0);
|
|
result->grad()->fill_(0);
|
|
|
result->requires_grad_(true);
|
|
result->requires_grad_(true);
|
|
|
- result->leaf_(false);
|
|
|
|
|
result->grad_fn_(op);
|
|
result->grad_fn_(op);
|
|
|
result->inputs_(result_inputs);
|
|
result->inputs_(result_inputs);
|
|
|
}
|
|
}
|
|
@@ -116,7 +109,6 @@ namespace traph
|
|
|
{
|
|
{
|
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
|
result->requires_grad_(false);
|
|
result->requires_grad_(false);
|
|
|
- result->leaf_(false);
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
@@ -135,7 +127,6 @@ namespace traph
|
|
|
result->grad_(result->data()->create_grad());
|
|
result->grad_(result->data()->create_grad());
|
|
|
result->grad()->fill_(0);
|
|
result->grad()->fill_(0);
|
|
|
result->requires_grad_(true);
|
|
result->requires_grad_(true);
|
|
|
- result->leaf_(false);
|
|
|
|
|
result->grad_fn_(op);
|
|
result->grad_fn_(op);
|
|
|
result->inputs_(result_inputs);
|
|
result->inputs_(result_inputs);
|
|
|
}
|
|
}
|
|
@@ -143,7 +134,6 @@ namespace traph
|
|
|
{
|
|
{
|
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
|
result->requires_grad_(false);
|
|
result->requires_grad_(false);
|
|
|
- result->leaf_(false);
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
return result;
|
|
@@ -160,7 +150,6 @@ namespace traph
|
|
|
|
|
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
result->data_(op->forward({ input->data() }));
|
|
result->data_(op->forward({ input->data() }));
|
|
|
- result->leaf_(false);
|
|
|
|
|
|
|
|
|
|
if (input->requires_grad())
|
|
if (input->requires_grad())
|
|
|
{
|
|
{
|
|
@@ -188,7 +177,6 @@ namespace traph
|
|
|
|
|
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
result->data_(op->forward({ input->data() }));
|
|
result->data_(op->forward({ input->data() }));
|
|
|
- result->leaf_(false);
|
|
|
|
|
|
|
|
|
|
if (input->requires_grad())
|
|
if (input->requires_grad())
|
|
|
{
|
|
{
|
|
@@ -213,7 +201,6 @@ namespace traph
|
|
|
VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
std::shared_ptr<SubOp> op(new SubOp);
|
|
std::shared_ptr<SubOp> op(new SubOp);
|
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- result->leaf_(false);
|
|
|
|
|
if (left->requires_grad() || right->requires_grad())
|
|
if (left->requires_grad() || right->requires_grad())
|
|
|
{
|
|
{
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
@@ -241,7 +228,6 @@ namespace traph
|
|
|
|
|
|
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
result->data_(op->forward({ input->data() }));
|
|
result->data_(op->forward({ input->data() }));
|
|
|
- result->leaf_(false);
|
|
|
|
|
|
|
|
|
|
if (input->requires_grad())
|
|
if (input->requires_grad())
|
|
|
{
|
|
{
|