|
|
@@ -2,6 +2,7 @@
|
|
|
#define TRAPH_NN_FUNCTION_H_
|
|
|
|
|
|
#include <utility>
|
|
|
+#include <random>
|
|
|
#include <cmath>
|
|
|
|
|
|
#include <traph/core/type.h>
|
|
|
@@ -16,27 +17,27 @@
|
|
|
namespace traph
|
|
|
{
|
|
|
|
|
|
-#define UNARY_OP(name, op_name) \
|
|
|
- VariableInterfacePtr name(VariableInterfacePtr input) \
|
|
|
- { \
|
|
|
- DimVector result_dim; \
|
|
|
- VariableInterfacePtr result = input->new_empty(result_dim, true); \
|
|
|
- std::shared_ptr<op_name> op(new op_name); \
|
|
|
- std::vector<VariableInterfacePtr> result_inputs{ input }; \
|
|
|
- result->data_(op->forward({ input->data() })); \
|
|
|
- if (input->requires_grad()) \
|
|
|
- { \
|
|
|
- result->grad_(result->data()->create_grad()); \
|
|
|
- result->grad()->fill_(0); \
|
|
|
- result->requires_grad_(true); \
|
|
|
- result->grad_fn_(op); \
|
|
|
- result->inputs_(result_inputs); \
|
|
|
- } \
|
|
|
- else \
|
|
|
- { \
|
|
|
- result->requires_grad_(false); \
|
|
|
- } \
|
|
|
- return result; \
|
|
|
+#define UNARY_OP(name, op_name) \
|
|
|
+ VariableInterfacePtr name(VariableInterfacePtr input) \
|
|
|
+ { \
|
|
|
+ DimVector result_dim; \
|
|
|
+ VariableInterfacePtr result = input->new_empty(result_dim, true); \
|
|
|
+ std::shared_ptr<op_name> op(new op_name); \
|
|
|
+ std::vector<VariableInterfacePtr> result_inputs{ input }; \
|
|
|
+ result->data_(op->forward({ input->data() })); \
|
|
|
+ if (input->requires_grad()) \
|
|
|
+ { \
|
|
|
+ result->grad_(result->data()->create_grad()); \
|
|
|
+ result->grad()->fill_(0); \
|
|
|
+ result->requires_grad_(true); \
|
|
|
+ result->grad_fn_(op); \
|
|
|
+ result->inputs_(result_inputs); \
|
|
|
+ } \
|
|
|
+ else \
|
|
|
+ { \
|
|
|
+ result->requires_grad_(false); \
|
|
|
+ } \
|
|
|
+ return result; \
|
|
|
}
|
|
|
|
|
|
#define BINARY_OP(name, op_name) \
|
|
|
@@ -102,6 +103,26 @@ namespace traph
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
+ template<typename T>
|
|
|
+ VariableInterfacePtr randn(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
+ {
|
|
|
+ DimVector dim;
|
|
|
+ for (auto i : l)
|
|
|
+ dim.push_back(i);
|
|
|
+
|
|
|
+ std::random_device rd{};
|
|
|
+ std::mt19937 gen{rd()};
|
|
|
+ std::normal_distribution<> d{0,1};
|
|
|
+
|
|
|
+ std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
|
|
|
+ std::shared_ptr<TensorBase<T>> result_data = std::dynamic_pointer_cast<TensorBase<T>>(result->data());
|
|
|
+ result_data->apply_([&d, &gen](T n){
|
|
|
+ return d(gen);
|
|
|
+ });
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
template<typename T>
|
|
|
VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
|
|
|
{
|