Spaces:
Running
Running
import argparse | |
import gc | |
import os | |
import sys | |
import warnings | |
import pandas as pd | |
import torch | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
from generation_utils import ( | |
ReactionT5Dataset, | |
decode_output, | |
save_multiple_predictions, | |
) | |
from train import preprocess_df | |
from utils import seed_everything | |
warnings.filterwarnings("ignore") | |
def parse_args(): | |
parser = argparse.ArgumentParser( | |
description="Script for reaction retrosynthesis prediction." | |
) | |
parser.add_argument( | |
"--input_data", | |
type=str, | |
required=True, | |
help="Path to the input data.", | |
) | |
parser.add_argument( | |
"--input_max_length", | |
type=int, | |
default=400, | |
help="Maximum token length of input.", | |
) | |
parser.add_argument( | |
"--output_min_length", | |
type=int, | |
default=1, | |
help="Minimum token length of output.", | |
) | |
parser.add_argument( | |
"--output_max_length", | |
type=int, | |
default=300, | |
help="Maximum token length of output.", | |
) | |
parser.add_argument( | |
"--model_name_or_path", | |
type=str, | |
default="sagawa/ReactionT5v2-retrosynthesis", | |
help="Name or path of the finetuned model for prediction. Can be a local model or one from Hugging Face.", | |
) | |
parser.add_argument( | |
"--num_beams", type=int, default=5, help="Number of beams used for beam search." | |
) | |
parser.add_argument( | |
"--num_return_sequences", | |
type=int, | |
default=5, | |
help="Number of predictions returned. Must be less than or equal to num_beams.", | |
) | |
parser.add_argument( | |
"--batch_size", type=int, default=5, help="Batch size for prediction." | |
) | |
parser.add_argument( | |
"--output_dir", | |
type=str, | |
default="./", | |
help="Directory where predictions are saved.", | |
) | |
parser.add_argument( | |
"--debug", action="store_true", default=False, help="Use debug mode." | |
) | |
parser.add_argument( | |
"--seed", type=int, default=42, help="Seed for reproducibility." | |
) | |
return parser.parse_args() | |
if __name__ == "__main__": | |
CFG = parse_args() | |
CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
if not os.path.exists(CFG.output_dir): | |
os.makedirs(CFG.output_dir) | |
seed_everything(seed=CFG.seed) | |
CFG.tokenizer = AutoTokenizer.from_pretrained( | |
os.path.abspath(CFG.model_name_or_path) | |
if os.path.exists(CFG.model_name_or_path) | |
else CFG.model_name_or_path, | |
return_tensors="pt", | |
) | |
model = AutoModelForSeq2SeqLM.from_pretrained( | |
os.path.abspath(CFG.model_name_or_path) | |
if os.path.exists(CFG.model_name_or_path) | |
else CFG.model_name_or_path | |
).to(CFG.device) | |
model.eval() | |
input_data = pd.read_csv(CFG.input_data) | |
input_data = preprocess_df(input_data, drop_duplicates=False) | |
dataset = ReactionT5Dataset(CFG, input_data) | |
dataloader = DataLoader( | |
dataset, | |
batch_size=CFG.batch_size, | |
shuffle=False, | |
num_workers=4, | |
pin_memory=True, | |
drop_last=False, | |
) | |
all_sequences, all_scores = [], [] | |
for inputs in tqdm(dataloader, total=len(dataloader)): | |
inputs = {k: v.to(CFG.device) for k, v in inputs.items()} | |
with torch.no_grad(): | |
output = model.generate( | |
**inputs, | |
min_length=CFG.output_min_length, | |
max_length=CFG.output_max_length, | |
num_beams=CFG.num_beams, | |
num_return_sequences=CFG.num_return_sequences, | |
return_dict_in_generate=True, | |
output_scores=True, | |
) | |
sequences, scores = decode_output(output, CFG) | |
all_sequences.extend(sequences) | |
if scores: | |
all_scores.extend(scores) | |
del output | |
torch.cuda.empty_cache() | |
gc.collect() | |
output_df = save_multiple_predictions(input_data, all_sequences, all_scores, CFG) | |
output_df.to_csv(os.path.join(CFG.output_dir, "output.csv"), index=False) | |