|
|
@@ -15,6 +15,54 @@
|
|
|
|
|
|
namespace traph
|
|
|
{
|
|
|
+
|
|
|
+#define UNARY_OP(name, op_name) \
|
|
|
+ VariableInterfacePtr name(VariableInterfacePtr input) \
|
|
|
+ { \
|
|
|
+ DimVector result_dim; \
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true); \
|
|
|
+ std::shared_ptr<op_name> op(new op_name); \
|
|
|
+ std::vector<VariableInterfacePtr> result_inputs{ input }; \
|
|
|
+ result->data_(op->forward({ input->data() })); \
|
|
|
+ if (input->requires_grad()) \
|
|
|
+ { \
|
|
|
+ result->grad_(result->data()->create_grad()); \
|
|
|
+ result->grad()->fill_(0); \
|
|
|
+ result->requires_grad_(true); \
|
|
|
+ result->grad_fn_(op); \
|
|
|
+ result->inputs_(result_inputs); \
|
|
|
+ } \
|
|
|
+ else \
|
|
|
+ { \
|
|
|
+ result->requires_grad_(false); \
|
|
|
+ } \
|
|
|
+ return result; \
|
|
|
+ }
|
|
|
+
|
|
|
+#define BINARY_OP(name, op_name) \
|
|
|
+ VariableInterfacePtr name(VariableInterfacePtr left, VariableInterfacePtr right) \
|
|
|
+ { \
|
|
|
+ DimVector result_dim; \
|
|
|
+ VariableInterfacePtr result = left->new_empty(result_dim, true); \
|
|
|
+ std::shared_ptr<op_name> op(new op_name); \
|
|
|
+ result->data_(op->forward({ left->data(), right->data() })); \
|
|
|
+ if (left->requires_grad() || right->requires_grad()) \
|
|
|
+ { \
|
|
|
+ std::vector<VariableInterfacePtr> result_inputs{ left, right }; \
|
|
|
+ result->grad_(result->data()->create_grad()); \
|
|
|
+ result->grad()->fill_(0); \
|
|
|
+ result->requires_grad_(true); \
|
|
|
+ result->grad_fn_(op); \
|
|
|
+ result->inputs_(result_inputs); \
|
|
|
+ } \
|
|
|
+ else \
|
|
|
+ { \
|
|
|
+ result->requires_grad_(false); \
|
|
|
+ } \
|
|
|
+ return result; \
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
// creation function
|
|
|
template<typename T>
|
|
|
VariableInterfacePtr empty(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
@@ -63,83 +111,12 @@ namespace traph
|
|
|
}
|
|
|
|
|
|
// arithmetic function
|
|
|
- VariableInterfacePtr sum(VariableInterfacePtr input)
|
|
|
- {
|
|
|
- DimVector result_dim(1);
|
|
|
- result_dim[0] = 1;
|
|
|
-
|
|
|
- VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
- std::shared_ptr<SumOp> op(new SumOp);
|
|
|
-
|
|
|
- result->data_(op->forward({ input->data() }));
|
|
|
- if(input->requires_grad())
|
|
|
- {
|
|
|
- std::vector<VariableInterfacePtr> result_inputs { input };
|
|
|
- result->grad_(result->data()->create_grad());
|
|
|
- result->grad()->fill_(0);
|
|
|
- result->requires_grad_(true);
|
|
|
- result->grad_fn_(op);
|
|
|
- result->inputs_(result_inputs);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- result->requires_grad_(false);
|
|
|
- }
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
-
|
|
|
- VariableInterfacePtr add(VariableInterfacePtr left, VariableInterfacePtr right)
|
|
|
- {
|
|
|
- DimVector result_dim;
|
|
|
-
|
|
|
- VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
- std::shared_ptr<AddOp> op(new AddOp);
|
|
|
- if (left->requires_grad() || right->requires_grad())
|
|
|
- {
|
|
|
- std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
|
- result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- result->grad_(result->data()->create_grad());
|
|
|
- result->grad()->fill_(0);
|
|
|
- result->requires_grad_(true);
|
|
|
- result->grad_fn_(op);
|
|
|
- result->inputs_(result_inputs);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- result->requires_grad_(false);
|
|
|
- }
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
-
|
|
|
- VariableInterfacePtr matmul(VariableInterfacePtr left, VariableInterfacePtr right)
|
|
|
- {
|
|
|
- DimVector result_dim;
|
|
|
-
|
|
|
- VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
- std::shared_ptr<MatmulOp> op(new MatmulOp);
|
|
|
- if (left->requires_grad() || right->requires_grad())
|
|
|
- {
|
|
|
- std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
|
- result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- result->grad_(result->data()->create_grad());
|
|
|
- result->grad()->fill_(0);
|
|
|
- result->requires_grad_(true);
|
|
|
- result->grad_fn_(op);
|
|
|
- result->inputs_(result_inputs);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- result->requires_grad_(false);
|
|
|
- }
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
+ UNARY_OP(sum, SumOp)
|
|
|
|
|
|
+ BINARY_OP(add, AddOp)
|
|
|
|
|
|
+ BINARY_OP(matmul, MatmulOp)
|
|
|
+
|
|
|
VariableInterfacePtr select(VariableInterfacePtr input, const SliceVector& slice)
|
|
|
{
|
|
|
DimVector result_dim;
|
|
|
@@ -167,56 +144,9 @@ namespace traph
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
+ UNARY_OP(sin, SinOp)
|
|
|
|
|
|
- VariableInterfacePtr sin(VariableInterfacePtr input)
|
|
|
- {
|
|
|
- DimVector result_dim;
|
|
|
-
|
|
|
- VariableInterfacePtr result = input->new_empty(result_dim, true);
|
|
|
- std::shared_ptr<SinOp> op(new SinOp);
|
|
|
-
|
|
|
- std::vector<VariableInterfacePtr> result_inputs{ input };
|
|
|
- result->data_(op->forward({ input->data() }));
|
|
|
-
|
|
|
- if (input->requires_grad())
|
|
|
- {
|
|
|
- result->grad_(result->data()->create_grad());
|
|
|
- result->grad()->fill_(0);
|
|
|
- result->requires_grad_(true);
|
|
|
- result->grad_fn_(op);
|
|
|
- result->inputs_(result_inputs);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- result->requires_grad_(false);
|
|
|
- }
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
-
|
|
|
- VariableInterfacePtr sub(VariableInterfacePtr left, VariableInterfacePtr right)
|
|
|
- {
|
|
|
- DimVector result_dim;
|
|
|
-
|
|
|
- VariableInterfacePtr result = left->new_empty(result_dim, true);
|
|
|
- std::shared_ptr<SubOp> op(new SubOp);
|
|
|
- result->data_(op->forward({ left->data(), right->data() }));
|
|
|
- if (left->requires_grad() || right->requires_grad())
|
|
|
- {
|
|
|
- std::vector<VariableInterfacePtr> result_inputs{ left, right };
|
|
|
- result->grad_(result->data()->create_grad());
|
|
|
- result->grad()->fill_(0);
|
|
|
- result->requires_grad_(true);
|
|
|
- result->grad_fn_(op);
|
|
|
- result->inputs_(result_inputs);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- result->requires_grad_(false);
|
|
|
- }
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
+ BINARY_OP(sub, SubOp)
|
|
|
|
|
|
VariableInterfacePtr transpose(VariableInterfacePtr input, idx_type dim0, idx_type dim1)
|
|
|
{
|