|
|
@@ -52,6 +52,9 @@ namespace traph
|
|
|
template<class T>
|
|
|
friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
|
|
|
|
|
|
+ template<class T>
|
|
|
+ friend std::shared_ptr<Variable<T>> sin(std::shared_ptr<Variable<T>> input);
|
|
|
+
|
|
|
virtual void backward() override;
|
|
|
virtual TensorInterfacePtr data() override;
|
|
|
virtual device_id device() override;
|
|
|
@@ -64,6 +67,7 @@ namespace traph
|
|
|
virtual idx_type offset() const override;
|
|
|
virtual layout_type order() const override;
|
|
|
virtual platform_type platform() override;
|
|
|
+ virtual bool requires_grad() const override;
|
|
|
virtual void requires_grad_(bool requires_grad) override;
|
|
|
virtual void reshape_(const DimVector& dims) override;
|
|
|
virtual void resize_(const DimVector& dims) override;
|
|
|
@@ -247,6 +251,12 @@ namespace traph
|
|
|
return _data->platform();
|
|
|
}
|
|
|
|
|
|
+ template<typename T>
|
|
|
+ bool Variable<T>::requires_grad() const
|
|
|
+ {
|
|
|
+ return _requires_grad;
|
|
|
+ }
|
|
|
+
|
|
|
template<typename T>
|
|
|
void Variable<T>::requires_grad_(bool requires_grad)
|
|
|
{
|