Browse Source

add sgd optim and ScalarType

JasonWang 6 năm trước cách đây
mục cha
commit
680d732420

+ 13 - 0
traph/include/traph/core/type.h

@@ -1,6 +1,7 @@
 #ifndef TRAPH_CORE_TYPE_H_
 #define TRAPH_CORE_TYPE_H_
 
+#include <variant>
 #include <cstdint>
 
 namespace traph
@@ -44,6 +45,18 @@ namespace traph
         FLOAT,
         DOUBLE
     };
+
+    class ScalarType
+    {
+    private:
+        std::variant<u8, i8, i16, i32, i64, f32, f64> _scalar;
+        DataType _dtype;
+    public:
+        DataType dtype() const
+        {
+            return _dtype;
+        }
+    };
 }
 
 #endif

+ 23 - 1
traph/include/traph/nn/function.h

@@ -114,9 +114,31 @@ namespace traph
 	UNARY_OP(sum, SumOp)
 
 	BINARY_OP(add, AddOp)
-	
+
 	BINARY_OP(matmul, MatmulOp)
 
+	VariableInterfacePtr pow(VariableInterfacePtr input, float exp)
+	{
+		DimVector result_dim;
+        VariableInterfacePtr result = input->new_empty(result_dim, true);
+		std::shared_ptr<PowOp> op(new PowOp);
+		op->set_exp(exp);
+		result->data_(op->forward({ input->data() }));
+		if (input->requires_grad())
+		{
+			result->grad_(result->data()->create_grad());
+			result->grad()->fill_(0);
+			result->requires_grad_(true);
+			result->grad_fn_(op);
+			result->inputs_({ input });
+		}
+		else
+		{
+			result->requires_grad_(false);
+		}
+		return result;
+	}
+
 	VariableInterfacePtr select(VariableInterfacePtr input, const SliceVector& slice)
 	{
 		DimVector result_dim;

+ 3 - 3
traph/include/traph/nn/layers/loss.h

@@ -27,16 +27,16 @@ namespace traph
             std::shared_ptr<VariableInterface> ret;
             if(_reduction == MSELossReduction::SUM)
             {
-                ret = sum(sub(input, target));
+                ret = sum(pow(sub(input, target), 2));
             }
             else if(_reduction == MSELossReduction::MEAN)
             {
                 // fixme: use mean if it impled
-                ret = sum(sub(input, target));
+                ret = sum(pow(sub(input, target), 2));
             }
             else
             {
-                ret = sum(sub(input, target));
+                ret = pow(sub(input, target), 2);
             }
             return ret;
         }

+ 26 - 1
traph/include/traph/nn/optim.h

@@ -10,7 +10,7 @@ namespace traph
 {
     class Optimizer
     {
-    private:
+    protected:
         std::vector<std::shared_ptr<VariableInterface>> _params;
     public:
         Optimizer(std::vector<std::shared_ptr<VariableInterface>> params)
@@ -28,6 +28,31 @@ namespace traph
             }
         }
     };
+
+    class SGD:public Optimizer
+    {
+    private:
+        float _lr;
+    public:
+        SGD(std::vector<std::shared_ptr<VariableInterface>> params, 
+            float lr, float momentum=0, float dampening=0, float weight_decay=0,
+            bool nesterov=false)
+            :Optimizer(params), _lr(lr)
+        {
+        }
+
+        virtual void step() override
+        {
+            for(auto& each:_params)
+            {
+                auto d_p = each->grad();
+
+                auto cloned_d_p = std::dynamic_pointer_cast<TensorBase<f32>>(d_p->clone());
+                cloned_d_p->mul_(_lr);
+                each->data()->add_(cloned_d_p);
+            }
+        }
+    };
 }
 
 #endif