JasonWang 6 年 前
コミット
9140181c28
2 ファイル変更29 行追加19 行削除
  1. 1 1
      traph/include/traph/nn/variable.h
  2. 28 18
      traph/source/nn/executor.cpp

+ 1 - 1
traph/include/traph/nn/variable.h

@@ -145,7 +145,7 @@ namespace traph
 		_grad->fill_(1);
 
 		std::vector<VariableInterface*> sorted_node = Executor::topology_sort(dynamic_cast<VariableInterface*>(this));
-		for (int i = static_cast<int>(sorted_node.size()) - 1; i >= 0; --i)
+		for (int i = 0; i < static_cast<int>(sorted_node.size()); ++i)
 		{
 			VariableInterface* cur_node = sorted_node[i];
 			if (cur_node->is_leaf()) continue;

+ 28 - 18
traph/source/nn/executor.cpp

@@ -1,38 +1,48 @@
 #include <traph/nn/executor.h>
 
+#include <map>
+
 namespace traph
 {
     std::vector<VariableInterface*> Executor::topology_sort(VariableInterface* root)
     {
         std::set<VariableInterface*> all_nodes = collect_backward_tensors(root);
-        std::vector<VariableInterface*> visited_nodes;
+		// std::set<VariableInterface*> visited_nodes;
+        std::vector<VariableInterface*> sort_result;
+		std::map<VariableInterface*, int> indegrees;
         std::size_t all_size = all_nodes.size();
 
+		for (std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
+			indegrees[*it] = 0;
+
+		for (std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
+		{
+			std::vector<VariableInterfacePtr>& cur_inputs = (*it)->inputs();
+			for (auto &each : cur_inputs)
+				if(indegrees.find(each.get()) != indegrees.end())
+					indegrees[each.get()]++;
+		}
+
         for(int i = 0; i<all_size; ++i)
         {
             for(std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
             {
                 std::vector<VariableInterfacePtr> cur_inputs = (*it)->inputs();
-				std::vector<VariableInterface*> cur_raw_inputs;
-
-				std::set<VariableInterface*> sorted_cur_raw_inputs(cur_raw_inputs.begin(), cur_raw_inputs.end());
-				std::set<VariableInterface*> sorted_visited_nodes(visited_nodes.begin(), visited_nodes.end());
-
-				for (auto &each : cur_inputs)
-					cur_raw_inputs.push_back(each.get());
-                std::vector<VariableInterface*> cur_inputs_no_visited;
-                std::set_difference(sorted_cur_raw_inputs.begin(), sorted_cur_raw_inputs.end(), sorted_visited_nodes.begin(), sorted_visited_nodes.end(),
-                        std::inserter(cur_inputs_no_visited, cur_inputs_no_visited.begin()));
-                if(cur_inputs_no_visited.empty())
-                {
-                    visited_nodes.push_back(*it);
-                    all_nodes.erase(it);
-                    break;
-                }
+				if (indegrees[*it] == 0)
+				{
+					sort_result.push_back(*it);
+					all_nodes.erase(it);
+					for (auto& each:cur_inputs)
+					{
+						if (indegrees.find(each.get()) != indegrees.end())
+							indegrees[each.get()]--;
+					}
+					break;
+				}
             }
         }
 
-        return visited_nodes;
+		return sort_result;
     }
 
     std::set<VariableInterface*> Executor::collect_backward_tensors(VariableInterface* root)