Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import csv | |
import argparse | |
import random | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import pandas as pd | |
import re | |
from torch.utils.data import DataLoader | |
try: | |
from torch_geometric.data import Batch | |
except ImportError: | |
pass | |
def set_seed(seed):
    """Seed every random number generator used by this code base.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (plus the current CUDA device when CUDA is available), then configures
    cuDNN for deterministic, non-benchmarked kernel selection.

    Args:
        seed: value passed unchanged to each generator's seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def move_to(obj, device):
    """Recursively transfer *obj* to *device*.

    Dicts and lists are rebuilt with every element moved; Python ``int`` and
    ``float`` scalars are returned unchanged. Anything else is expected to
    provide a ``.to(device)`` method (e.g. a tensor, or a ``Batch`` object
    as used for MolPCBA).
    """
    if isinstance(obj, (float, int)):
        return obj
    if isinstance(obj, list):
        return [move_to(item, device) for item in obj]
    if isinstance(obj, dict):
        return {key: move_to(value, device) for key, value in obj.items()}
    # Tensors and tensor-like containers know how to move themselves.
    return obj.to(device)
def detach_and_clone(obj):
    """Return a copy of *obj* that is cut off from the autograd graph.

    Tensors are detached and cloned into fresh storage; dicts and lists are
    rebuilt recursively; ``int``/``float`` scalars pass through unchanged.

    Raises:
        TypeError: if *obj*, or any nested element, is of another type.
    """
    if isinstance(obj, (float, int)):
        return obj
    if isinstance(obj, list):
        return list(map(detach_and_clone, obj))
    if isinstance(obj, dict):
        return {key: detach_and_clone(value) for key, value in obj.items()}
    if not torch.is_tensor(obj):
        raise TypeError("Invalid type for detach_and_clone")
    return obj.detach().clone()
def collate_list(vec):
    """Collate a list of per-batch outputs into a single object.

    The type of ``vec[0]`` decides how the list is combined:

    * Tensors are concatenated along the first dimension with ``torch.cat``.
      Zero-dimensional (scalar) tensors cannot be concatenated, so they are
      stacked into a 1-D tensor instead.
    * Lists are joined into one flat list. Their elements are not collated
      recursively, so each element may itself be, e.g., its own dict.
    * Dicts (each with the same keys as the first dict) produce a single
      dict with the same keys; each key's entries are collated recursively.

    Args:
        vec: non-empty list of tensors, lists, or dicts.

    Returns:
        A tensor, list, or dict, matching the type of the elements.

    Raises:
        TypeError: if ``vec`` is not a list, or its elements are not
            tensors, lists, or dicts.
        ValueError: if ``vec`` is empty, since there is no element from
            which to infer the output type.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    if not vec:
        raise ValueError("collate_list must take in a non-empty list")
    elem = vec[0]
    if torch.is_tensor(elem):
        if elem.dim() == 0:
            # torch.cat rejects 0-d tensors; stacking yields the natural 1-D result.
            return torch.stack(vec)
        return torch.cat(vec)
    if isinstance(elem, list):
        return [obj for sublist in vec for obj in sublist]
    if isinstance(elem, dict):
        # Keys come from the first dict; a later dict missing a key raises KeyError.
        return {k: collate_list([d[k] for d in vec]) for k in elem}
    raise TypeError(
        "Elements of the list to collate must be tensors, lists, or dicts."
    )