import json
import random
from typing import List, Dict, Tuple

from adaptor.evaluators.generative import ROUGE, BLEU
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.utils import AdaptationArguments, StoppingStrategy
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter
import wandb
|
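# Paths to the database schema dump and to the Spider / Spider-Syn question splits.
# db_schemas.json is expected to map each db_id to its schema serialized as a string,
# since that is how it is indexed in create_vals_and_labels below.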
db_path = 'db_schemas.json'
spider_dataset_train_path = 'spider/train_spider.json'
spider_dataset_dev_path = 'spider/dev.json'
spider_syn_train_path = 'Spider-Syn/train_spider.json'
spider_syn_dev_path = 'Spider-Syn/dev.json'
|
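# Load the schema dump and the four question splits.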
with open(db_path, 'r') as file_db:
    database_schemas = json.load(file_db)

with open(spider_dataset_train_path, 'r') as file_spider:
    spider_train_dataset = json.load(file_spider)

with open(spider_dataset_dev_path, 'r') as file_spider:
    spider_dev_dataset = json.load(file_spider)

with open(spider_syn_train_path, 'r') as file_spider:
    spider_syn_train_dataset = json.load(file_spider)

with open(spider_syn_dev_path, 'r') as file_spider:
    spider_syn_dev_dataset = json.load(file_spider)
|
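# Extend both splits with the Spider-Syn examples whose paraphrased question
# actually differs from the original Spider question, then shuffle.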
spider_train_dataset.extend([question for question in spider_syn_train_dataset
                             if question['SpiderQuestion'] != question['SpiderSynQuestion']])
spider_dev_dataset.extend([question for question in spider_syn_dev_dataset
                           if question['SpiderQuestion'] != question['SpiderSynQuestion']])

random.shuffle(spider_train_dataset)
random.shuffle(spider_dev_dataset)
|
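# Helpers: flatten a (question, schema) pair into a single text prompt and pair
# every prompt with its gold SQL query as the label. Spider-Syn entries carry the
# paraphrase under 'SpiderSynQuestion' instead of 'question'.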
def create_prompt(question: str, schema: str) -> str:
    return " ".join(["Question:", question, "Schema:", schema])


def create_vals_and_labels(dataset: List[dict], db_dict: Dict[str, str]) -> Tuple[List[str], List[str]]:
    list_labels = [data["query"] for data in dataset]
    list_prompts = [create_prompt(data["question"], db_dict[data["db_id"]])
                    if "question" in data
                    else create_prompt(data["SpiderSynQuestion"], db_dict[data["db_id"]])
                    for data in dataset]
    return list_prompts, list_labels
|
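# Build the training and validation prompt/label lists; the asserts guard the
# one-to-one prompt/label pairing the objective expects.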
prompts_train, labels_train = create_vals_and_labels(spider_train_dataset, database_schemas)
assert len(prompts_train) == len(labels_train)

prompts_eval, labels_eval = create_vals_and_labels(spider_dev_dataset, database_schemas)
assert len(prompts_eval) == len(labels_eval)
|
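# Base model, generation metrics and experiment tracking. ROUGE is constructed with
# decides_convergence=True, so it should be the metric the convergence-based
# stopping strategy watches.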
lang_module = LangModule("google/t5-large-lm-adapt")
evaluators = [BLEU(), ROUGE(decides_convergence=True)]

wandb.init(project="chatbot")
|
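# Single sequence-to-sequence objective mapping each prompt to its SQL query.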
seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=prompts_train,
                           labels_or_path=labels_train,
                           val_texts_or_path=prompts_eval,
                           val_labels_or_path=labels_eval,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="txt2SQL_Spider")
|
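# Effective train batch size is batch_size * gradient_accumulation_steps = 4 * 8 = 32;
# StoppingStrategy.ALL_OBJECTIVES_CONVERGED ends training early once every
# objective has converged.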
training_arguments = AdaptationArguments(output_dir="checkpoints-txt2sql",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=100,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=200,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")
|
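# Schedule the single objective and run training.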
schedule = ParallelSchedule(objectives=[seq_qa],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)

adapter.train()