How did I fine-tune T5 using the WikiSQL dataset?

by dsivakumar

Thanks to Shivanandroy: this notebook, https://github.com/Shivanandroy/T5-Finetuning-PyTorch/blob/main/notebook/T5_Fine_tuning_with_PyTorch.ipynb, helped me understand the process and adapt it to text-to-SQL.

The only major changes are the Dataloader and the WikiSQL data itself, which I converted into just two columns: inputs as 'query' and targets as 'sql' (a sketch of that conversion follows the sample table below).

Sample Data

| source_text | target_text |
|---|---|
| What is the season year where the rank is 39? | SELECT tv season WHERE rank EQL 39 |
| What is the number of season premieres were 10.17 people watched? | SELECT count(season premiere) WHERE viewers (millions) EQL 10.17 |
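The post does not show the exact conversion I used, but a minimal sketch along these lines produces such a two-column frame from the Hugging Face `wikisql` dataset; the field names (`question`, `sql['human_readable']`) come from that dataset's schema and are assumptions about how the targets were derived.

# Hypothetical sketch: build a two-column DataFrame ('query', 'sql') from the
# Hugging Face `wikisql` dataset. Not the exact conversion used in the post.
import pandas as pd
from datasets import load_dataset

wikisql = load_dataset("wikisql", split="train")

df = pd.DataFrame({
    "query": wikisql["question"],                              # plain-English question
    "sql": [ex["human_readable"] for ex in wikisql["sql"]],    # human-readable SQL string
})
print(df.head())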

Dataset class

import torch
from torch.utils.data import Dataset

class CSQLSetClass(Dataset):
  """
  Wraps the two-column WikiSQL dataframe so it can be read by a DataLoader
  and passed to the network for fine-tuning.
  """

  def __init__(self, dataframe, tokenizer, source_len, target_len, source_text, target_text):
    self.tokenizer = tokenizer
    self.data = dataframe
    self.source_len = source_len
    self.summ_len = target_len

    # Add the task prefix to the inputs and wrap the targets with the decoder
    # start token and EOS *before* extracting the column Series below;
    # extracting first would leave the Series without these changes.
    self.data["query"] = "English to SQL: " + self.data["query"]
    self.data["sql"] = "<pad>" + self.data["sql"] + "</s>"

    self.target_text = self.data[target_text]
    self.source_text = self.data[source_text]
    
  def __len__(self):
    return len(self.target_text)

  def __getitem__(self, index):
    source_text = str(self.source_text[index])
    target_text = str(self.target_text[index])

    # collapse stray whitespace/newlines into single spaces
    source_text = ' '.join(source_text.split())
    target_text = ' '.join(target_text.split())

    source = self.tokenizer.batch_encode_plus([source_text], max_length=self.source_len, truncation=True, padding="max_length", return_tensors='pt')
    target = self.tokenizer.batch_encode_plus([target_text], max_length=self.summ_len, truncation=True, padding="max_length", return_tensors='pt')

    source_ids = source['input_ids'].squeeze()
    source_mask = source['attention_mask'].squeeze()
    target_ids = target['input_ids'].squeeze()
    target_mask = target['attention_mask'].squeeze()

    return {
        'source_ids': source_ids.to(dtype=torch.long), 
        'source_mask': source_mask.to(dtype=torch.long), 
        'target_ids': target_ids.to(dtype=torch.long),
        'target_ids_y': target_ids.to(dtype=torch.long)
    }
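To round out the picture, here is a minimal sketch of wiring this Dataset into a DataLoader and running a training step, following the pattern of the linked notebook. The model name (t5-small), batch size, learning rate, max lengths and the train_df variable are illustrative assumptions, not the exact values behind this checkpoint.

# Sketch: dataset -> DataLoader -> one fine-tuning epoch (assumed hyperparameters)
import torch
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

train_set = CSQLSetClass(train_df, tokenizer, source_len=128, target_len=150,
                         source_text="query", target_text="sql")
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
for batch in train_loader:
    y = batch["target_ids"]
    decoder_input_ids = y[:, :-1].contiguous()           # starts with the prepended <pad>
    labels = y[:, 1:].clone()
    labels[y[:, 1:] == tokenizer.pad_token_id] = -100    # ignore padding in the loss
    outputs = model(input_ids=batch["source_ids"],
                    attention_mask=batch["source_mask"],
                    decoder_input_ids=decoder_input_ids,
                    labels=labels)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()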

The prediction function takes a plain-English question and returns the generated SQL:

# Prediction function: encode the question, generate, and decode the SQL
def get_sql(query, tokenizer, model):
    source_text = "English to SQL: " + query
    source_text = ' '.join(source_text.split())
    source = tokenizer.batch_encode_plus([source_text], max_length=128, truncation=True, padding="max_length", return_tensors='pt')
    source_ids = source['input_ids']
    source_mask = source['attention_mask']
    generated_ids = model.generate(
      input_ids = source_ids.to(dtype=torch.long),
      attention_mask = source_mask.to(dtype=torch.long), 
      max_length=150, 
      num_beams=2,
      repetition_penalty=2.5, 
      length_penalty=1.0, 
      early_stopping=True
      )
    preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
    return preds
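For example, calling it on one of the sample questions above (with the fine-tuned model and tokenizer loaded) should give output along these lines:

# Example usage with a question from the sample table above
question = "What is the season year where the rank is 39?"
print(get_sql(question, tokenizer, model))
# expected output similar to: ['SELECT tv season WHERE rank EQL 39']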
