"""Fine-tune a T5 model to translate questions into SQL on Spider and Spider-Syn.

The script loads the database schemas and the Spider / Spider-Syn splits from
JSON files, builds (prompt, SQL query) pairs and trains a single
sequence-to-sequence objective with the Adaptor library, logging to
Weights & Biases.
"""

import json
import random
from typing import Any, Dict, List, Tuple

import wandb
from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU, ROUGE
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy


def _load_json(path: str) -> Any:
    """Return the parsed content of the JSON file at ``path``.

    The file is opened explicitly as UTF-8 so that decoding does not depend
    on the platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


# Dataset creation

## Define paths to JSON files
db_path = 'db_schemas.json'
spider_dataset_train_path = 'spider/train_spider.json'
spider_dataset_dev_path = 'spider/dev.json'
spider_syn_train_path = 'Spider-Syn/train_spider.json'
spider_syn_dev_path = 'Spider-Syn/dev.json'

## Open files
database_schemas = _load_json(db_path)
spider_train_dataset = _load_json(spider_dataset_train_path)
spider_dev_dataset = _load_json(spider_dataset_dev_path)
spider_syn_train_dataset = _load_json(spider_syn_train_path)
spider_syn_dev_dataset = _load_json(spider_syn_dev_path)

## Include spider questions with synonyms (questions include text which is not in DB columns).
## Entries whose synonym question equals the original Spider question are skipped,
## since they would only duplicate examples that are already in the dataset.
spider_train_dataset.extend(
    [
        question
        for question in spider_syn_train_dataset
        if question['SpiderQuestion'] != question['SpiderSynQuestion']
    ]
)
spider_dev_dataset.extend(
    [
        question
        for question in spider_syn_dev_dataset
        if question['SpiderQuestion'] != question['SpiderSynQuestion']
    ]
)

# Mix the appended Spider-Syn examples in with the original Spider examples.
# NOTE(review): no random seed is set, so the order differs between runs.
random.shuffle(spider_train_dataset)
random.shuffle(spider_dev_dataset)


def create_prompt(question: str, schema: str) -> str:
    """Build the model input from a question and a database schema string.

    Args:
        question: Natural-language question.
        schema: Textual description of the database schema.

    Returns:
        The parts ``"Question: "``, ``question``, ``"Schema:"`` and ``schema``
        joined by single spaces.
    """
    # "Question: " already ends with a space, so the joined prompt has two
    # spaces after the colon. Kept as-is so the prompt format stays unchanged.
    return " ".join(["Question: ", question, "Schema:", schema])


def create_vals_and_labels(dataset: List[dict], db_dict: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Turn dataset entries into parallel lists of prompts and target queries.

    Args:
        dataset: Entries with a ``"query"`` and a ``"db_id"`` key, plus the
            question under ``"question"`` or, when that key is absent, under
            ``"SpiderSynQuestion"``.
        db_dict: Mapping from ``db_id`` to the schema string of that database.

    Returns:
        A tuple ``(prompts, labels)`` of equal length, in dataset order.

    Raises:
        KeyError: If an entry lacks a required key or its ``db_id`` is not
            present in ``db_dict``.
    """
    list_labels = [data["query"] for data in dataset]
    list_prompts = [
        create_prompt(
            data["question"] if "question" in data else data["SpiderSynQuestion"],
            db_dict[data["db_id"]],
        )
        for data in dataset
    ]
    return list_prompts, list_labels


## Training prompts and labels
prompts_train, labels_train = create_vals_and_labels(spider_train_dataset, database_schemas)
assert len(prompts_train) == len(labels_train)

## Evaluation prompts and labels
prompts_eval, labels_eval = create_vals_and_labels(spider_dev_dataset, database_schemas)
assert len(prompts_eval) == len(labels_eval)

# Training
lang_module = LangModule("google/t5-large-lm-adapt")

# ROUGE is the evaluator flagged to decide convergence; BLEU is passed
# without that flag.
evaluators = [BLEU(), ROUGE(decides_convergence=True)]

wandb.init(project="chatbot")

seq_qa = Sequence2Sequence(
    lang_module,
    texts_or_path=prompts_train,
    labels_or_path=labels_train,
    val_texts_or_path=prompts_eval,
    val_labels_or_path=labels_eval,
    batch_size=4,
    val_evaluators=evaluators,
    objective_id="txt2SQL_Spider",
)

training_arguments = AdaptationArguments(
    output_dir="checkpoints-txt2sql",
    learning_rate=5e-5,
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    stopping_patience=8,
    save_total_limit=8,
    do_train=True,
    do_eval=True,
    bf16=True,
    warmup_steps=100,
    gradient_accumulation_steps=8,
    logging_steps=10,
    # Evaluation and checkpointing use the same 200-step interval.
    eval_steps=200,
    save_steps=200,
    num_train_epochs=10,
    evaluation_strategy="steps",
)

schedule = ParallelSchedule(objectives=[seq_qa], args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()