import argparse
import os
import sys
import warnings
from pathlib import Path

import datasets
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    add_new_tokens,
    canonicalize,
    filter_out,
    get_accuracy_score,
    preprocess_dataset,
    seed_everything,
    space_clean,
)

# Suppress warnings and disable progress bars
warnings.filterwarnings("ignore")
datasets.utils.logging.disable_progress_bar()


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Training script for reaction prediction model."
    )
    parser.add_argument(
        "--train_data_path", type=str, required=True, help="Path to training data CSV."
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV.",
    )
    parser.add_argument("--test_data_path", type=str, help="Path to test data CSV.")
    parser.add_argument(
        "--USPTO_test_data_path",
        type=str,
        help="Path to data used for USPTO testing. A CSV file containing ['REACTANT', 'REAGENT', 'PRODUCT'] columns is expected.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="t5", help="Path of the output directory."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Pretrained model path or name.",
    )
    parser.add_argument(
        "--debug", action="store_true", default=False, help="Enable debug mode."
    )
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs.")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument(
        "--input_max_length", type=int, default=400, help="Max input token length."
    )
    parser.add_argument(
        "--target_max_length", type=int, default=150, help="Max target token length."
    )
    parser.add_argument(
        "--eval_beams",
        type=int,
        default=5,
        help="Number of beams used for beam search during evaluation.",
    )
    parser.add_argument(
        "--target_column", type=str, default="PRODUCT", help="Target column name."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay."
    )
    parser.add_argument(
        "--evaluation_strategy",
        type=str,
        default="epoch",
        help="Evaluation strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --eval_steps.",
    )
    parser.add_argument("--eval_steps", type=int, help="Evaluation steps.")
    parser.add_argument(
        "--save_strategy",
        type=str,
        default="epoch",
        help="Save strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --save_steps.",
    )
    parser.add_argument("--save_steps", type=int, default=500, help="Save steps.")
    parser.add_argument(
        "--logging_strategy",
        type=str,
        default="epoch",
        help="Logging strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --logging_steps.",
    )
    parser.add_argument(
        "--logging_steps", type=int, default=500, help="Logging steps."
    )
    parser.add_argument(
        "--save_total_limit", type=int, default=2, help="Limit of saved checkpoints."
    )
    parser.add_argument(
        "--fp16", action="store_true", default=False, help="Enable fp16 training."
    )
    parser.add_argument(
        "--disable_tqdm", action="store_true", default=False, help="Disable tqdm."
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    return parser.parse_args()


def preprocess_df(df, drop_duplicates=True):
    """Preprocess the dataframe by filling NaNs, dropping duplicates, and formatting the input."""
    for col in ["REACTANT", "PRODUCT", "CATALYST", "REAGENT", "SOLVENT"]:
        if col not in df.columns:
            df[col] = None
        df[col] = df[col].fillna(" ")
    if drop_duplicates:
        df = (
            df[["REACTANT", "PRODUCT", "CATALYST", "REAGENT", "SOLVENT"]]
            .drop_duplicates()
            .reset_index(drop=True)
        )
    # Merge catalyst, reagent, and solvent into a single dot-separated REAGENT field
    df["REAGENT"] = df["CATALYST"] + "." + df["REAGENT"] + "." + df["SOLVENT"]
    df["REAGENT"] = df["REAGENT"].apply(lambda x: space_clean(x))
    df["REAGENT"] = df["REAGENT"].apply(lambda x: canonicalize(x) if x != " " else " ")
    df["input"] = "REACTANT:" + df["REACTANT"] + "REAGENT:" + df["REAGENT"]
    return df


def preprocess_USPTO(df):
    """Sort molecules within each column so reactions compare order-independently, then build input/pair keys."""
    df["REACTANT"] = df["REACTANT"].apply(lambda x: str(sorted(x.split("."))))
    df["REAGENT"] = df["REAGENT"].apply(lambda x: str(sorted(x.split("."))))
    df["PRODUCT"] = df["PRODUCT"].apply(lambda x: str(sorted(x.split("."))))
    df["input"] = "REACTANT:" + df["REACTANT"] + "REAGENT:" + df["REAGENT"]
    df["pair"] = df["input"] + " - " + df["PRODUCT"].astype(str)
    return df


if __name__ == "__main__":
    CFG = parse_args()
    CFG.disable_tqdm = True  # tqdm is always disabled here, overriding the CLI flag
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(seed=CFG.seed)

    # Load and preprocess data
    train = preprocess_df(
        filter_out(pd.read_csv(CFG.train_data_path), ["REACTANT", "PRODUCT"])
    )
    valid = preprocess_df(
        filter_out(pd.read_csv(CFG.valid_data_path), ["REACTANT", "PRODUCT"])
    )

    if CFG.USPTO_test_data_path:
        # Remove training reactions that also appear in the USPTO test set
        train_copy = preprocess_USPTO(train.copy())
        USPTO_test = preprocess_USPTO(pd.read_csv(CFG.USPTO_test_data_path))
        train = train[~train_copy["pair"].isin(USPTO_test["pair"])].reset_index(
            drop=True
        )

    # Drop validation rows that leak from the training set
    train["pair"] = train["input"] + " - " + train["PRODUCT"]
    valid["pair"] = valid["input"] + " - " + valid["PRODUCT"]
    valid = valid[~valid["pair"].isin(train["pair"])].reset_index(drop=True)
    train.to_csv("train.csv", index=False)
    valid.to_csv("valid.csv", index=False)

    if CFG.test_data_path:
        test = preprocess_df(
            filter_out(pd.read_csv(CFG.test_data_path), ["REACTANT", "PRODUCT"])
        )
        test["pair"] = test["input"] + " - " + test["PRODUCT"]
        test = test[~test["pair"].isin(train["pair"])].reset_index(drop=True)
        test = test.drop_duplicates(subset=["pair"]).reset_index(drop=True)
        test.to_csv("test.csv", index=False)

    dataset = DatasetDict(
        {
            "train": Dataset.from_pandas(train[["input", "PRODUCT"]]),
            "validation": Dataset.from_pandas(valid[["input", "PRODUCT"]]),
        }
    )

    # load tokenizer and extend its vocabulary with reaction-specific tokens
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.pretrained_model_name_or_path)
        if os.path.exists(CFG.pretrained_model_name_or_path)
        else CFG.pretrained_model_name_or_path,
        return_tensors="pt",
    )
    tokenizer = add_new_tokens(
        tokenizer,
        Path(__file__).resolve().parent.parent / "data" / "additional_tokens.txt",
    )
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": tokenizer.additional_special_tokens
            + ["REACTANT:", "REAGENT:"]
        }
    )
    CFG.tokenizer = tokenizer

    # load model and resize embeddings to match the extended vocabulary
    model = AutoModelForSeq2SeqLM.from_pretrained(
        os.path.abspath(CFG.pretrained_model_name_or_path)
        if os.path.exists(CFG.pretrained_model_name_or_path)
        else CFG.pretrained_model_name_or_path
    )
    model.resize_token_embeddings(len(tokenizer))
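    # For reference, `preprocess_dataset` (imported from utils) is assumed to
    # tokenize the "input" column as encoder input and the target column as
    # labels, roughly like the hedged sketch below. This is an assumption
    # about the utils implementation, not the actual code:
    #
    #     def preprocess_dataset(examples, cfg):
    #         model_inputs = cfg.tokenizer(
    #             examples["input"],
    #             max_length=cfg.input_max_length,
    #             truncation=True,
    #         )
    #         labels = cfg.tokenizer(
    #             examples[cfg.target_column],
    #             max_length=cfg.target_max_length,
    #             truncation=True,
    #         )
    #         model_inputs["labels"] = labels["input_ids"]
    #         return model_inputs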
    tokenized_datasets = dataset.map(
        lambda examples: preprocess_dataset(examples, CFG),
        batched=True,
        remove_columns=dataset["train"].column_names,
        load_from_cache_file=False,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    args = Seq2SeqTrainingArguments(
        CFG.output_dir,
        evaluation_strategy=CFG.evaluation_strategy,
        eval_steps=CFG.eval_steps,
        save_strategy=CFG.save_strategy,
        save_steps=CFG.save_steps,
        logging_strategy=CFG.logging_strategy,
        logging_steps=CFG.logging_steps,
        learning_rate=CFG.lr,
        per_device_train_batch_size=CFG.batch_size,
        per_device_eval_batch_size=CFG.batch_size,
        weight_decay=CFG.weight_decay,
        save_total_limit=CFG.save_total_limit,
        num_train_epochs=CFG.epochs,
        predict_with_generate=True,
        fp16=CFG.fp16,
        disable_tqdm=CFG.disable_tqdm,
        push_to_hub=False,
        load_best_model_at_end=True,
    )

    model.config.eval_beams = CFG.eval_beams
    model.config.max_length = CFG.target_max_length

    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=lambda eval_preds: get_accuracy_score(eval_preds, CFG),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
    )

    # Resume from the latest checkpoint in output_dir if one exists;
    # otherwise start training from scratch.
    try:
        trainer.train(resume_from_checkpoint=True)
    except ValueError:  # raised when no valid checkpoint is found
        trainer.train(resume_from_checkpoint=None)
    trainer.save_model("./best_model")
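# Example invocation (a minimal sketch; the script name, data paths, and
# model name below are placeholders, not files shipped with this repository):
#
#     python train.py \
#         --train_data_path data/train.csv \
#         --valid_data_path data/valid.csv \
#         --pretrained_model_name_or_path t5-base \
#         --output_dir t5 \
#         --epochs 5 \
#         --batch_size 16 \
#         --fp16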