Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
# pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog") | |
# def predict(input_img): | |
# predictions = pipeline(input_img) | |
# return input_img, {p["label"]: p["score"] for p in predictions} | |
from typing import List | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Hugging Face Hub checkpoint used for both the tokenizer and the
# sequence-to-sequence model, so the two always come from the same repo.
MODEL_CHECKPOINT = "juierror/text-to-sql-with-table-schema"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
def prepare_input(question: str, table: List[str], max_length: int = 700):
    """Build the model prompt for *question* over *table* and tokenize it.

    The prompt has the form ``"question: <question> table: <col1>,<col2>,..."``.

    Args:
        question: Natural-language question to translate into SQL.
        table: Column names of the table the question refers to.
        max_length: Token budget for the prompt; longer prompts are truncated
            to this many tokens. Defaults to 700, the previous fixed value.

    Returns:
        The token IDs as a PyTorch tensor (``return_tensors="pt"``), suitable
        for passing to ``model.generate``.
    """
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # Request truncation explicitly. Passing max_length alone makes the
    # tokenizer fall back to truncation anyway, but it logs a warning.
    return tokenizer(
        inputs, max_length=max_length, truncation=True, return_tensors="pt"
    ).input_ids
def inference(question: str, table: str) -> str:
    """Translate a natural-language question into SQL for the given table.

    Args:
        question: Question about the table, e.g.
            ``"get people name with age equal 25"``.
        table: Comma-separated column names, e.g. ``"id,name,age"``.
            Whitespace around each name is ignored and empty entries
            (e.g. from a trailing comma) are dropped.

    Returns:
        The SQL text decoded from the first sequence returned by
        ``model.generate``, with special tokens removed.

    Example:
        >>> inference(question="get people name with age equal 25",
        ...           table="id,name,age")  # doctest: +SKIP
    """
    # Normalise "id, name, age" to ["id", "name", "age"] so the prompt uses
    # the same "id,name,age" layout regardless of how the user typed it.
    cols = [col.strip() for col in table.split(",") if col.strip()]
    input_data = prepare_input(question=question, table=cols)
    input_data = input_data.to(model.device)
    # NOTE(review): top_k only influences sampling-based generation; with
    # plain beam search it has no effect. Kept to match the existing call.
    outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
# Web UI: two text inputs (question, comma-separated column names) mapped
# through `inference`, with the generated SQL shown as plain text.
gradio_app = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Question", placeholder="get people name with age equal 25"),
        gr.Textbox(label="Table columns (comma-separated)", placeholder="id,name,age"),
    ],
    # `inference` returns a SQL string. A Textbox shows it as copyable text,
    # whereas the "label" component is meant for classification results.
    outputs=gr.Textbox(label="SQL"),
    title="Text To SQL",
)

if __name__ == "__main__":
    gradio_app.launch()