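"""Create mean-pooled embeddings of reaction inputs with a trained ReactionT5 encoder.

Reads a CSV of reactions, optionally drops rows that also appear in a given test
set, encodes each reaction with T5EncoderModel, and saves the attention-mask-
averaged embeddings to ``embedding_mean.npy``.
"""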
import argparse
import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, T5EncoderModel

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from generation_utils import ReactionT5Dataset
from train import preprocess_df, preprocess_USPTO
from utils import filter_out, seed_everything

warnings.filterwarnings("ignore")


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--test_data",
        type=str,
        required=False,
        help="Path to the test data. If provided, reactions that also appear in the test data are removed from the input data.",
    )
    parser.add_argument(
        "--input_max_length",
        type=int,
        default=400,
        help="Maximum token length of the input.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-forward",
        help="Name or path of the finetuned model used for prediction. Can be a local path or a model on Hugging Face.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=5, help="Batch size for prediction."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where predictions are saved.",
    )
    parser.add_argument(
        "--debug", action="store_true", default=False, help="Use debug mode."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed for reproducibility."
    )
    return parser.parse_args()


def create_embedding(dataloader, model, device):
    """Return a list of per-batch arrays of mean-pooled encoder embeddings.

    Embeddings are averaged over the token dimension, weighted by the
    attention mask so that padding tokens do not contribute.
    """
    outputs_mean = []
    model.eval()
    model.to(device)
    for inputs in dataloader:
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model(**inputs)
        last_hidden_states = output[0]
        # Expand the attention mask to the hidden size so padded positions
        # are zeroed out before averaging.
        input_mask_expanded = (
            inputs["attention_mask"]
            .unsqueeze(-1)
            .expand(last_hidden_states.size())
            .float()
        )
        sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-6)  # avoid division by zero
        mean_embeddings = sum_embeddings / sum_mask
        outputs_mean.append(mean_embeddings.detach().cpu().numpy())
    return outputs_mean


if __name__ == "__main__":
    CFG = parse_args()
    CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(CFG.output_dir):
        os.makedirs(CFG.output_dir)
    seed_everything(seed=CFG.seed)

    CFG.tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )
    model = T5EncoderModel.from_pretrained(CFG.model_name_or_path).to(CFG.device)
    model.eval()

    # Load the input reactions and drop rows with invalid REACTANT/PRODUCT entries.
    input_data = filter_out(pd.read_csv(CFG.input_data), ["REACTANT", "PRODUCT"])
    input_data = preprocess_df(input_data, drop_duplicates=False)

    # If test data is given, remove input reactions that also appear in it,
    # so the saved embeddings contain no overlap with the test set.
    if CFG.test_data:
        input_data_copy = preprocess_USPTO(input_data.copy())
        test_data = filter_out(pd.read_csv(CFG.test_data), ["REACTANT", "PRODUCT"])
        USPTO_test = preprocess_USPTO(test_data)
        input_data = input_data[
            ~input_data_copy["pair"].isin(USPTO_test["pair"])
        ].reset_index(drop=True)
    input_data.to_csv(os.path.join(CFG.output_dir, "input_data.csv"), index=False)

    dataset = ReactionT5Dataset(CFG, input_data)
    dataloader = DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    # Compute per-batch embeddings, stack them, and save to disk.
    outputs = create_embedding(dataloader, model, CFG.device)
    outputs = np.concatenate(outputs, axis=0)
    np.save(os.path.join(CFG.output_dir, "embedding_mean.npy"), outputs)
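
# Example invocation (a sketch; the script filename and CSV paths are assumptions,
# the flags match parse_args above):
#   python create_embedding.py \
#       --input_data data/reactions.csv \
#       --model_name_or_path sagawa/ReactionT5v2-forward \
#       --batch_size 32 \
#       --output_dir ./embeddings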