utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2. #
  3. # Permission is hereby granted, free of charge, to any person obtaining a
  4. # copy of this software and associated documentation files (the "Software"),
  5. # to deal in the Software without restriction, including without limitation
  6. # the rights to use, copy, modify, merge, publish, distribute, sublicense,
  7. # and/or sell copies of the Software, and to permit persons to whom the
  8. # Software is furnished to do so, subject to the following conditions:
  9. #
  10. # The above copyright notice and this permission notice shall be included in
  11. # all copies or substantial portions of the Software.
  12. #
  13. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  14. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  15. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  16. # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  17. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  18. # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  19. # DEALINGS IN THE SOFTWARE.
  20. #
  21. # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
  22. # SPDX-License-Identifier: MIT
  23. import dgl
  24. import torch
  25. def get_random_graph(N, num_edges_factor=18):
  26. graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
  27. return graph
  28. def assign_relative_pos(graph, coords):
  29. src, dst = graph.edges()
  30. graph.edata['rel_pos'] = coords[src] - coords[dst]
  31. return graph
  32. def get_max_diff(a, b):
  33. return (a - b).abs().max().item()
  34. def rot_z(gamma):
  35. return torch.tensor([
  36. [torch.cos(gamma), -torch.sin(gamma), 0],
  37. [torch.sin(gamma), torch.cos(gamma), 0],
  38. [0, 0, 1]
  39. ], dtype=gamma.dtype)
  40. def rot_y(beta):
  41. return torch.tensor([
  42. [torch.cos(beta), 0, torch.sin(beta)],
  43. [0, 1, 0],
  44. [-torch.sin(beta), 0, torch.cos(beta)]
  45. ], dtype=beta.dtype)
  46. def rot(alpha, beta, gamma):
  47. return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)