|
|
@@ -10,7 +10,7 @@ namespace traph
|
|
|
{
|
|
|
class Optimizer
|
|
|
{
|
|
|
- private:
|
|
|
+ protected:
|
|
|
std::vector<std::shared_ptr<VariableInterface>> _params;
|
|
|
public:
|
|
|
Optimizer(std::vector<std::shared_ptr<VariableInterface>> params)
|
|
|
@@ -28,6 +28,31 @@ namespace traph
|
|
|
}
|
|
|
}
|
|
|
};
|
|
|
+
|
|
|
+ class SGD:public Optimizer
|
|
|
+ {
|
|
|
+ private:
|
|
|
+ float _lr;
|
|
|
+ public:
|
|
|
+ SGD(std::vector<std::shared_ptr<VariableInterface>> params,
|
|
|
+ float lr, float momentum=0, float dampening=0, float weight_decay=0,
|
|
|
+ bool nesterov=false)
|
|
|
+ :Optimizer(params), _lr(lr)
|
|
|
+ {
|
|
|
+ }
|
|
|
+
|
|
|
+ virtual void step() override
|
|
|
+ {
|
|
|
+ for(auto& each:_params)
|
|
|
+ {
|
|
|
+ auto d_p = each->grad();
|
|
|
+
|
|
|
+ auto cloned_d_p = std::dynamic_pointer_cast<TensorBase<f32>>(d_p->clone());
|
|
|
+ cloned_d_p->mul_(_lr);
|
|
|
+ each->data()->add_(cloned_d_p);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
}
|
|
|
|
|
|
#endif
|