
Merge: [SE3Transformer/DGLPyT] Benchmarking fixes and tweaks

Krzysztof Kudrynski, 3 years ago
commit f4b9bdfd2f

+ 2 - 2
DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py

@@ -99,10 +99,10 @@ class QM9DataModule(DataModule):
     def _collate(self, samples):
         graphs, y, *bases = map(list, zip(*samples))
         batched_graph = dgl.batch(graphs)
-        edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
+        edge_feats = {'0': batched_graph.edata['edge_attr'][:, :self.EDGE_FEATURE_DIM, None]}
         batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
         # get node features
-        node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
+        node_feats = {'0': batched_graph.ndata['attr'][:, :self.NODE_FEATURE_DIM, None]}
         targets = (torch.cat(y) - self.targets_mean) / self.targets_std
 
         if bases:
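
Note: this hunk replaces the hard-coded node slice width (6) and the previously unsliced edge tensor with named class constants, so the QM9 feature widths are defined in one place. A minimal illustrative sketch of the slicing shape (constant names come from the diff; the edge width of 4 is an assumption, matching the usual QM9 one-hot bond types):

    # Illustrative sketch, not the repo's code: slicing with named widths
    # instead of literals keeps the collate function and dataset in sync.
    import torch

    NODE_FEATURE_DIM = 6   # matches the old hard-coded node slice
    EDGE_FEATURE_DIM = 4   # assumed one-hot bond-type width

    attr = torch.randn(10, 9)                        # per-node attributes
    node_feats = {'0': attr[:, :NODE_FEATURE_DIM, None]}
    assert node_feats['0'].shape == (10, NODE_FEATURE_DIM, 1)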

+ 3 - 1
DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py

@@ -127,7 +127,9 @@ class SE3Transformer(nn.Module):
                                      fiber_edge=fiber_edge,
                                      self_interaction=True,
                                      use_layer_norm=use_layer_norm,
-                                     max_degree=self.max_degree))
+                                     max_degree=self.max_degree,
+                                     fuse_level=self.fuse_level,
+                                     low_memory=low_memory))
         self.graph_modules = Sequential(*graph_modules)
 
         if pooling is not None:
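
Note: this hunk forwards fuse_level and low_memory into every stacked graph module, so kernel fusion and the memory-saving path apply to the whole stack rather than falling back to defaults. A self-contained sketch of the general pattern being fixed (Block and build_stack are illustrative names, not the repo's API):

    # Sketch: per-layer options must be forwarded to *every* stacked block,
    # or later blocks silently run with their defaults.
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self, dim, fuse_level=None, low_memory=False):
            super().__init__()
            self.fuse_level = fuse_level     # kernel-fusion setting
            self.low_memory = low_memory     # recompute-to-save-memory path
            self.linear = nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear(x)

    def build_stack(num_layers, dim, **layer_kwargs):
        # Pass the same kwargs to every layer, as the diff now does.
        return nn.Sequential(*(Block(dim, **layer_kwargs)
                               for _ in range(num_layers)))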

+ 14 - 9
DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py

@@ -116,18 +116,23 @@ if __name__ == '__main__':
     torch.set_float32_matmul_precision('high')
 
     test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
-    evaluate(model,
-             test_dataloader,
-             callbacks,
-             args)
+    if not args.benchmark:
+        evaluate(model,
+                 test_dataloader,
+                 callbacks,
+                 args)
 
-    for callback in callbacks:
-        callback.on_validation_end()
+        for callback in callbacks:
+            callback.on_validation_end()
 
-    if args.benchmark:
+    else:
         world_size = dist.get_world_size() if dist.is_initialized() else 1
-        callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
-        for _ in range(6):
+        callbacks = [PerformanceCallback(
+            logger, args.batch_size * world_size,
+            warmup_epochs=1 if args.epochs > 1 else 0,
+            mode='inference'
+        )]
+        for _ in range(args.epochs):
             evaluate(model,
                      test_dataloader,
                      callbacks,
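
Note: benchmark mode no longer also runs the accuracy evaluation, the benchmark loop repeats args.epochs times instead of a hard-coded 6, and warmup is skipped when only one epoch is requested (otherwise zero epochs would be measured). The warmup rule itself, as a tiny runnable sketch:

    # Sketch of the warmup rule introduced in this hunk: with a single
    # requested epoch there is nothing to discard.
    def warmup_epochs_for(epochs: int) -> int:
        return 1 if epochs > 1 else 0

    assert warmup_epochs_for(1) == 0   # single pass: measure everything
    assert warmup_epochs_for(6) == 1   # discard the first, cold-cache pass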

+ 3 - 1
DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py

@@ -221,7 +221,9 @@ if __name__ == '__main__':
     if args.benchmark:
         logging.info('Running benchmark mode')
         world_size = dist.get_world_size() if dist.is_initialized() else 1
-        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
+        callbacks = [PerformanceCallback(
+            logger, args.batch_size * world_size, warmup_epochs=1 if args.epochs > 1 else 0
+        )]
     else:
         callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                      QM9LRSchedulerCallback(logger, epochs=args.epochs)]
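
Note: the same warmup rule is applied to the training benchmark. For context, a minimal sketch (not the repo's PerformanceCallback) of why excluding warmup epochs matters for throughput numbers: the first epoch pays one-off costs such as CUDA context creation, kernel autotuning, and data-loader spin-up.

    # Hedged, self-contained sketch of a warmup-aware throughput timer;
    # all names here are illustrative.
    import time

    class ThroughputTimer:
        def __init__(self, batch_size, warmup_epochs=1):
            self.batch_size = batch_size
            self.warmup_epochs = warmup_epochs
            self.samples = 0
            self.start = None

        def on_epoch_begin(self, epoch):
            if epoch == self.warmup_epochs:   # start timing after warmup
                self.start = time.perf_counter()

        def on_batch_end(self, epoch, num_batches=1):
            if epoch >= self.warmup_epochs:   # count only measured epochs
                self.samples += self.batch_size * num_batches

        def throughput(self):
            return self.samples / (time.perf_counter() - self.start)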