|
|
@@ -1,38 +1,48 @@
|
|
|
#include <traph/nn/executor.h>
|
|
|
|
|
|
+#include <map>
|
|
|
+
|
|
|
namespace traph
|
|
|
{
|
|
|
std::vector<VariableInterface*> Executor::topology_sort(VariableInterface* root)
|
|
|
{
|
|
|
std::set<VariableInterface*> all_nodes = collect_backward_tensors(root);
|
|
|
- std::vector<VariableInterface*> visited_nodes;
|
|
|
+ // std::set<VariableInterface*> visited_nodes;
|
|
|
+ std::vector<VariableInterface*> sort_result;
|
|
|
+ std::map<VariableInterface*, int> indegrees;
|
|
|
std::size_t all_size = all_nodes.size();
|
|
|
|
|
|
+ for (std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
|
|
|
+ indegrees[*it] = 0;
|
|
|
+
|
|
|
+ for (std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
|
|
|
+ {
|
|
|
+ std::vector<VariableInterfacePtr>& cur_inputs = (*it)->inputs();
|
|
|
+ for (auto &each : cur_inputs)
|
|
|
+ if(indegrees.find(each.get()) != indegrees.end())
|
|
|
+ indegrees[each.get()]++;
|
|
|
+ }
|
|
|
+
|
|
|
for(int i = 0; i<all_size; ++i)
|
|
|
{
|
|
|
for(std::set<VariableInterface*>::iterator it = all_nodes.begin(); it != all_nodes.end(); ++it)
|
|
|
{
|
|
|
std::vector<VariableInterfacePtr> cur_inputs = (*it)->inputs();
|
|
|
- std::vector<VariableInterface*> cur_raw_inputs;
|
|
|
-
|
|
|
- std::set<VariableInterface*> sorted_cur_raw_inputs(cur_raw_inputs.begin(), cur_raw_inputs.end());
|
|
|
- std::set<VariableInterface*> sorted_visited_nodes(visited_nodes.begin(), visited_nodes.end());
|
|
|
-
|
|
|
- for (auto &each : cur_inputs)
|
|
|
- cur_raw_inputs.push_back(each.get());
|
|
|
- std::vector<VariableInterface*> cur_inputs_no_visited;
|
|
|
- std::set_difference(sorted_cur_raw_inputs.begin(), sorted_cur_raw_inputs.end(), sorted_visited_nodes.begin(), sorted_visited_nodes.end(),
|
|
|
- std::inserter(cur_inputs_no_visited, cur_inputs_no_visited.begin()));
|
|
|
- if(cur_inputs_no_visited.empty())
|
|
|
- {
|
|
|
- visited_nodes.push_back(*it);
|
|
|
- all_nodes.erase(it);
|
|
|
- break;
|
|
|
- }
|
|
|
+ if (indegrees[*it] == 0)
|
|
|
+ {
|
|
|
+ sort_result.push_back(*it);
|
|
|
+ all_nodes.erase(it);
|
|
|
+ for (auto& each:cur_inputs)
|
|
|
+ {
|
|
|
+ if (indegrees.find(each.get()) != indegrees.end())
|
|
|
+ indegrees[each.get()]--;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return visited_nodes;
|
|
|
+ return sort_result;
|
|
|
}
|
|
|
|
|
|
std::set<VariableInterface*> Executor::collect_backward_tensors(VariableInterface* root)
|