|
|
@@ -58,17 +58,16 @@ bigdl.Graph = class {
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._nodes = [];
|
|
|
- this._loadModule(metadata, '', module);
|
|
|
+ this._loadModule(metadata, module);
|
|
|
}
|
|
|
|
|
|
- _loadModule(metadata, group, module) {
|
|
|
+ _loadModule(metadata, module) {
|
|
|
switch (module.moduleType) {
|
|
|
- case 'com.intel.analytics.bigdl.nn.StaticGraph': {
|
|
|
- this._loadStaticGraph(metadata, group, module);
|
|
|
- break;
|
|
|
- }
|
|
|
+ case 'com.intel.analytics.bigdl.nn.StaticGraph':
|
|
|
case 'com.intel.analytics.bigdl.nn.Sequential': {
|
|
|
- this._loadSequential(metadata, group, module);
|
|
|
+ for (const submodule of module.subModules) {
|
|
|
+ this._loadModule(metadata, submodule);
|
|
|
+ }
|
|
|
break;
|
|
|
}
|
|
|
case 'com.intel.analytics.bigdl.nn.Input': {
|
|
|
@@ -78,30 +77,12 @@ bigdl.Graph = class {
|
|
|
break;
|
|
|
}
|
|
|
default: {
|
|
|
- this._nodes.push(new bigdl.Node(metadata, group, module));
|
|
|
+ this._nodes.push(new bigdl.Node(metadata, module));
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- _loadSequential(metadata, group, module) {
|
|
|
- group = group.length > 0 ? group + '.' + module.namePostfix : module.namePostfix;
|
|
|
- for (const submodule of module.subModules) {
|
|
|
- this._loadModule(metadata, group, submodule);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- _loadStaticGraph(metadata, group, module) {
|
|
|
- group = group.length > 0 ? group + '.' + module.namePostfix : module.namePostfix;
|
|
|
- for (const submodule of module.subModules) {
|
|
|
- this._loadModule(metadata, group, submodule);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- get groups() {
|
|
|
- return this._groups || false;
|
|
|
- }
|
|
|
-
|
|
|
get type() {
|
|
|
return this._type;
|
|
|
}
|
|
|
@@ -168,15 +149,14 @@ bigdl.Argument = class {
|
|
|
|
|
|
bigdl.Node = class {
|
|
|
|
|
|
- constructor(metadata, group, module) {
|
|
|
- this._group = group;
|
|
|
- const type = module.moduleType.split('.').pop();
|
|
|
+ constructor(metadata, module) {
|
|
|
+ const type = module.moduleType;
|
|
|
this._name = module.name;
|
|
|
this._attributes = [];
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._inputs.push(new bigdl.Parameter('input', module.preModules.map((id) => new bigdl.Argument(id, null, null))));
|
|
|
- this._type = metadata.type(type);
|
|
|
+ this._type = metadata.type(type) || { name: type };
|
|
|
const inputs = (this._type && this._type.inputs) ? this._type.inputs.slice() : [];
|
|
|
inputs.shift();
|
|
|
if (module.weight) {
|
|
|
@@ -226,10 +206,6 @@ bigdl.Node = class {
|
|
|
]));
|
|
|
}
|
|
|
|
|
|
- get group() {
|
|
|
- return this._group;
|
|
|
- }
|
|
|
-
|
|
|
get type() {
|
|
|
return this._type;
|
|
|
}
|