Sfoglia il codice sorgente

update module and optim methods

JasonWang 6 anni fa
parent
commit
13a16b8824

+ 25 - 5
traph/include/traph/nn/module.h

@@ -15,12 +15,32 @@ namespace traph
     class Module
     {
     private:
-        std::map<std::string, std::shared_ptr<ParameterInterface>> _parameters;
+        std::vector<std::pair<std::string, std::shared_ptr<VariableInterface>>> _parameters;
         std::vector<std::shared_ptr<Module>> _children;
     public:
-        std::vector<std::shared_ptr<ParameterInterface>> parameters(bool recurse)
+
+        std::vector<std::pair<std::string, std::shared_ptr<VariableInterface>>> named_parameters(bool recurse)
+        {
+            std::vector<std::pair<std::string, std::shared_ptr<VariableInterface>>> result;
+            if(recurse)
+            {
+                // fixme: children params recurse
+                for (const auto &p : _parameters)
+                    if(p.first != "")
+                        result.push_back(p);
+            }
+            else
+            {
+                for (const auto &p : _parameters)
+                    if(p.first != "")
+                        result.push_back(p);
+            }
+            return result;
+        }
+
+        std::vector<std::shared_ptr<VariableInterface>> parameters(bool recurse)
         {
-            std::vector<std::shared_ptr<ParameterInterface>> result;
+            std::vector<std::shared_ptr<VariableInterface>> result;
             if(recurse)
             {
                 // fixme: children params recurse
@@ -35,9 +55,9 @@ namespace traph
             return result;
         }
 
-        void register_parameter(const std::string& name, std::shared_ptr<ParameterInterface> param)
+        void register_parameter(const std::string& name, std::shared_ptr<VariableInterface> param)
         {
-            _parameters[name] = param;
+            _parameters.push_back(std::make_pair(name, param));
         }
     };
 } // traph

+ 6 - 2
traph/include/traph/nn/optim.h

@@ -11,9 +11,9 @@ namespace traph
     class Optimizer
     {
     private:
-        std::vector<std::shared_ptr<ParameterInterface>> _params;
+        std::vector<std::shared_ptr<VariableInterface>> _params;
     public:
-        Optimizer(std::vector<std::shared_ptr<ParameterInterface>> params)
+        Optimizer(std::vector<std::shared_ptr<VariableInterface>> params)
             :_params(params)
         {
         }
@@ -22,6 +22,10 @@ namespace traph
 
         void zero_grad()
         {
+            for(auto& each_param: _params)
+            {
+                each_param->grad()->fill_(0);
+            }
         }
     };
 }

+ 1 - 1
traph/include/traph/nn/parameter.h

@@ -10,7 +10,7 @@ namespace traph
     };
 
     template<typename T>
-    class Parameter:public Variable<T>, public ParameterInterface
+    class Parameter:public Variable<T>
     {
     public:
         Parameter();

+ 1 - 1
traph/source/test/main.cpp

@@ -68,7 +68,7 @@ int main()
 	auto result = loss.forward(out, y);
 
 	result->backward();
-	std::cout << result->data()->to_string();
+	std::cout << linear_model.parameters(true)[0]->grad()->to_string();
 
     return 0;
 }