---
language:
- en
datasets:
- wikisql
widget:
- text: "English to SQL: Show me the average age of wines in Italy by provinces"
- text: "English to SQL: What is the current series where the new series began in June 2011?"
---

```python
# import transformers
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# load model
model = T5ForConditionalGeneration.from_pretrained('dsivakumar/text2sql')
tokenizer = T5Tokenizer.from_pretrained('dsivakumar/text2sql')


# predict function
def get_sql(query, tokenizer, model):
    """Translate an English question into SQL.

    Args:
        query: natural-language question (without the task prefix).
        tokenizer: tokenizer matching ``model``.
        model: seq2seq model fine-tuned on the "English to SQL: " prompt.

    Returns:
        A list of decoded SQL strings, one per generated sequence.
    """
    source_text = "English to SQL: " + query
    # Collapse runs of whitespace (including newlines) into single spaces.
    source_text = ' '.join(source_text.split())
    source = tokenizer(
        [source_text],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    # The tokenizer already returns int64 tensors; just place them on the
    # same device as the model so this also works when the model is on GPU.
    source_ids = source['input_ids'].to(model.device)
    source_mask = source['attention_mask'].to(model.device)
    generated_ids = model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    preds = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for g in generated_ids
    ]
    return preds


# example
query = "Show me the average age of wines in Italy by provinces"
sql = get_sql(query, tokenizer, model)
print(sql)


# Comparison with a different model:
# https://huggingface.co/mrm8488/t5-small-finetuned-wikiSQL
# That model was trained on a different prompt prefix ("translate English to SQL: "),
# so it must be paired with its own tokenizer/model, not the ones loaded above.
wikisql_tokenizer = T5Tokenizer.from_pretrained('mrm8488/t5-small-finetuned-wikiSQL')
wikisql_model = T5ForConditionalGeneration.from_pretrained('mrm8488/t5-small-finetuned-wikiSQL')


def get_sql_wikisql(query, tokenizer, model):
    """Translate an English question into SQL with the mrm8488 WikiSQL model.

    Named differently from ``get_sql`` so it does not shadow it.
    Returns a single decoded SQL string (special tokens stripped).
    """
    input_text = f"translate English to SQL: {query}"
    features = tokenizer([input_text], return_tensors='pt').to(model.device)
    output = model.generate(
        input_ids=features['input_ids'],
        attention_mask=features['attention_mask'],
    )
    # skip_special_tokens drops <pad> / </s> from the generated SQL.
    return tokenizer.decode(output[0], skip_special_tokens=True)


query = "How many models were finetuned using BERT as base model?"
print(get_sql_wikisql(query, wikisql_tokenizer, wikisql_model))
```