|
|
@@ -70,6 +70,7 @@ namespace traph
|
|
|
virtual std::vector<VariableInterfacePtr>& inputs() override;
|
|
|
virtual bool is_leaf() const override;
|
|
|
virtual T item() const override;
|
|
|
+ virtual void leaf_(bool state) override;
|
|
|
virtual idx_type offset() const override;
|
|
|
virtual layout_type order() const override;
|
|
|
virtual platform_type platform() override;
|
|
|
@@ -239,6 +240,12 @@ namespace traph
|
|
|
return _data->item();
|
|
|
}
|
|
|
|
|
|
+ template<typename T>
|
|
|
+ void Variable<T>::leaf_(bool state)
|
|
|
+ {
|
|
|
+ _leaf = state;
|
|
|
+ }
|
|
|
+
|
|
|
template<typename T>
|
|
|
idx_type Variable<T>::offset() const
|
|
|
{
|
|
|
@@ -307,33 +314,6 @@ namespace traph
|
|
|
{
|
|
|
return _data->stride();
|
|
|
}
|
|
|
-
|
|
|
- // variable constructor
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> zeros(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
- {
|
|
|
- DimVector dim;
|
|
|
- for (auto i : l)
|
|
|
- dim.push_back(i);
|
|
|
-
|
|
|
- std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
|
|
|
- result->fill_(0);
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
-
|
|
|
- template<class T>
|
|
|
- VariablePtr<T> ones(std::initializer_list<idx_type> l, bool requires_grad = false)
|
|
|
- {
|
|
|
- DimVector dim;
|
|
|
- for (auto i : l)
|
|
|
- dim.push_back(i);
|
|
|
-
|
|
|
- std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
|
|
|
- result->fill_(1);
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
|