|
@@ -99,10 +99,10 @@ class QM9DataModule(DataModule):
|
|
|
def _collate(self, samples):
|
|
def _collate(self, samples):
|
|
|
graphs, y, *bases = map(list, zip(*samples))
|
|
graphs, y, *bases = map(list, zip(*samples))
|
|
|
batched_graph = dgl.batch(graphs)
|
|
batched_graph = dgl.batch(graphs)
|
|
|
- edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
|
|
|
|
|
|
|
+ edge_feats = {'0': batched_graph.edata['edge_attr'][:, :self.EDGE_FEATURE_DIM, None]}
|
|
|
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
|
|
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
|
|
|
# get node features
|
|
# get node features
|
|
|
- node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
|
|
|
|
|
|
|
+ node_feats = {'0': batched_graph.ndata['attr'][:, :self.NODE_FEATURE_DIM, None]}
|
|
|
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
|
|
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
|
|
|
|
|
|
|
|
if bases:
|
|
if bases:
|