Spaces:
Sleeping
Sleeping
from functools import lru_cache

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from example_strings import example1, example2, example3
# Prompt skeleton handed to the model: the three user-supplied sections,
# followed by the literal "SELECT" the completion continues from.
# Consecutive sections are separated by the same whitespace run
# (newline, newline, space, newline, newline).
_SECTION_SEPARATOR = "\n\n \n\n"
template_str = _SECTION_SEPARATOR.join(
    ["{table_schemas}", "{task_spec}", "{prompt}", "SELECT"]
)
@lru_cache(maxsize=1)
def load_model(model_name: str):
    """Load the tokenizer and causal-LM model for a NumbersStation checkpoint.

    The result is memoized so that repeated calls with the same model name
    reuse the already-instantiated objects instead of loading them again on
    every call. Only the most recently requested model is kept
    (``maxsize=1``): asking for a different model evicts the previous cache
    entry rather than accumulating several large models in the cache.

    Args:
        model_name: Repository name under the ``NumbersStation`` namespace,
            e.g. ``"nsql-350M"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    repo_id = f"NumbersStation/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return tokenizer, model
def build_complete_prompt(table_schemas: str, task_spec: str, prompt: str) -> str:
    """Fill the module-level ``template_str`` with the three prompt sections.

    Args:
        table_schemas: Text substituted for ``{table_schemas}``.
        task_spec: Text substituted for ``{task_spec}``.
        prompt: Text substituted for ``{prompt}``.

    Returns:
        The formatted template string.
    """
    sections = {
        "table_schemas": table_schemas,
        "task_spec": task_spec,
        "prompt": prompt,
    }
    return template_str.format(**sections)
def infer(table_schemas: str, task_spec: str, prompt: str, model_choice: str = "nsql-350M"):
    """Run the selected NSQL model on a prompt built from the given sections.

    Args:
        table_schemas: Table schema text placed at the start of the prompt.
        task_spec: Task instruction placed after the schemas.
        prompt: Natural-language question placed after the task instruction.
        model_choice: Model name passed to ``load_model``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    tokenizer, model = load_model(model_choice)
    input_text = build_complete_prompt(table_schemas, task_spec, prompt)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Limit the number of *newly generated* tokens. ``max_length`` also counts
    # the prompt tokens, so a long table schema would leave little or no room
    # for the generated query.
    generated_ids = model.generate(input_ids, max_new_tokens=500)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Markdown text passed to the Gradio interface as its description.
description = """The NSQL model family was published by [Numbers Station](https://www.numbersstation.ai/) and is available in three flavors:
- [nsql-6B](https://huggingface.co/NumbersStation/nsql-6B)
- [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
- [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)
This demo lets you choose from all of them and provides the three examples you can also find in their model cards.
In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
The model will then generate a SQL query that you can run against your database.
"""
# Input widgets, listed in the positional order of `infer`'s parameters:
# table schemas, task specification, prompt, model choice.
input_components = [
    gr.Text(label="Table schemas", placeholder="Insert your table schemas here"),
    gr.Text(
        label="Specify Task",
        value="Using valid SQLite, answer the following questions for the tables provided above.",
    ),
    gr.Text(label="Prompt", placeholder="Put your natural language prompt here"),
    gr.Dropdown(["nsql-6B", "nsql-2B", "nsql-350M"], value="nsql-6B"),
]

# Wire the widgets to `infer` and start the app.
iface = gr.Interface(
    fn=infer,
    inputs=input_components,
    outputs="text",
    title="Text to SQL with NSQL",
    description=description,
    examples=[example1, example2, example3],
)
iface.launch()