Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| import csv | |
| import argparse | |
| import random | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import pandas as pd | |
| import re | |
| from torch.utils.data import DataLoader | |
| try: | |
| from torch_geometric.data import Batch | |
| except ImportError: | |
| pass | |
def set_seed(seed):
    """Seed all random number generators and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (and,
    when CUDA is available, the CUDA RNG as well). Then it turns off cuDNN
    autotuning and requests deterministic cuDNN algorithms so repeated runs
    produce the same results.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # Autotuning picks kernels by timing, which can vary between runs.
    cudnn.benchmark = False
def move_to(obj, device):
    """Recursively move ``obj``, and everything nested inside it, to ``device``.

    Containers are rebuilt with their contents moved:

    * ``dict`` -> a new ``dict`` with the same keys and moved values;
    * ``list`` -> a new ``list``;
    * ``tuple`` -> a new tuple (namedtuples keep their own type).

    Python numbers (``int``, ``float``, ``bool``), strings and ``None`` have
    no device, so they are returned unchanged. Anything else is assumed to
    support ``.to(device)``, e.g. a tensor or a torch_geometric ``Batch``
    (used for MolPCBA).

    Args:
        obj: object or nested container to move.
        device: target device, passed unchanged to ``.to``.

    Returns:
        The moved object, or a rebuilt container holding moved objects.
    """
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # A namedtuple's constructor takes its fields positionally, while a
        # plain tuple's takes a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return tuple(moved)
    if obj is None or isinstance(obj, (int, float, str)):
        return obj
    return obj.to(device)
def detach_and_clone(obj):
    """Return a copy of ``obj`` in which every tensor is detached and cloned.

    Tensors are detached from the autograd graph and cloned into new storage.
    Dicts and lists are rebuilt, with each element handled recursively.
    Python ints and floats are returned as-is.

    Raises:
        TypeError: if ``obj``, or anything nested inside it, is of any other
            type.
    """
    if isinstance(obj, (float, int)):
        return obj
    if isinstance(obj, list):
        return [detach_and_clone(item) for item in obj]
    if isinstance(obj, dict):
        return {key: detach_and_clone(val) for key, val in obj.items()}
    if not torch.is_tensor(obj):
        raise TypeError("Invalid type for detach_and_clone")
    return obj.detach().clone()
def collate_list(vec):
    """Collate a list of per-batch results into a single object.

    The type of the first element decides how the whole list is combined:

    * tensors are concatenated along the first dimension with ``torch.cat``;
    * lists are joined into one flat list, one level deep only. Their
      elements are not collated further, so each one may itself be, e.g.,
      a dict;
    * dicts (expected to share the first dict's keys) become a single dict.
      For each key, ``collate_list`` is applied recursively to that key's
      entries.

    Args:
        vec: non-empty list of tensors, lists, or dicts.

    Returns:
        A tensor, list, or dict, matching the kind of ``vec[0]``.

    Raises:
        TypeError: if ``vec`` is not a list, or its elements are not tensors,
            lists, or dicts.
        ValueError: if ``vec`` is empty, because there is no element to infer
            the collation strategy from.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    if not vec:
        # vec[0] below picks the strategy. Fail clearly instead of with an
        # IndexError.
        raise ValueError("collate_list cannot collate an empty list")
    elem = vec[0]
    if torch.is_tensor(elem):
        return torch.cat(vec)
    if isinstance(elem, list):
        return [obj for sublist in vec for obj in sublist]
    if isinstance(elem, dict):
        return {k: collate_list([d[k] for d in vec]) for k in elem}
    raise TypeError(
        "Elements of the list to collate must be tensors, lists, or dicts."
    )