Edit model card

This is an upgraded version of https://huggingface.co/juierror/flan-t5-text2sql-with-schema.

It supports the `<` (less-than) operator in generated queries and can handle schemas with multiple tables.

How to use

from typing import Dict, List

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub id of the fine-tuned text-to-SQL checkpoint; the tokenizer
# and the seq2seq model are both loaded from it.
MODEL_ID = "juierror/flan-t5-text2sql-with-schema-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

def get_prompt(tables, question):
    """Build the instruction prompt fed to the text-to-SQL model.

    Args:
        tables: Schema text to embed; ``prepare_input`` passes tables rendered
            as ``name(col1,col2), other(col1,...)``.
        question: Natural-language question to translate into SQL.

    Returns:
        The prompt string, with both values interpolated via ``str.format``.
    """
    template = "convert question and table into SQL query. tables: {}. question: {}"
    return template.format(tables, question)

def prepare_input(question: str, tables: Dict[str, List[str]]):
    """Tokenize a question together with its table schema.

    Each table is rendered as ``name(col1,col2,...)``. The tables are joined
    with ``", "`` and embedded in the prompt built by ``get_prompt``.

    Args:
        question: Natural-language question.
        tables: Mapping of table name to its column names.

    Returns:
        The ``input_ids`` tensor produced by the tokenizer with
        ``return_tensors="pt"``, truncated to at most 512 tokens.
    """
    schema = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # Request truncation explicitly: passing max_length alone relies on the
    # tokenizer's implicit fallback strategy and emits a warning on every call.
    input_ids = tokenizer(
        prompt, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
    return input_ids

def inference(question: str, tables: Dict[str, List[str]]) -> str:
    """Generate a SQL query answering ``question`` over the given schema.

    Args:
        question: Natural-language question.
        tables: Mapping of table name to its column names.

    Returns:
        The decoded text of the first returned sequence (``outputs[0]``),
        with special tokens stripped.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    # Plain beam search. top_k was removed: it only applies to sampling, and
    # generate() ignores it (with a warning) when do_sample is not enabled.
    outputs = model.generate(inputs=input_ids, num_beams=10, max_length=512)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)

# Example 1: the schema is split across two tables.
print(inference("how many people with name jui and age less than 25", {
    "people_name": ["id", "name"],
    "people_age": ["people_id", "age"]
}))

# Example 2: a single table holds all of the columns.
print(inference("what is id with name jui and age less than 25", {
    "people_name": ["id", "name", "age"]
}))

Dataset

Downloads last month
1,259
Safetensors
Model size
248M params
Tensor type
F32
·
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

1 Space using juierror/flan-t5-text2sql-with-schema-v2