import functools

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from example_strings import example1, example2, example3

# Prompt layout fed to the model: table schemas, task instruction, the user's
# question, and a trailing "SELECT" so the model continues with the query body.
template_str = """{table_schemas} \n \n {task_spec} \n \n {prompt} \n \n SELECT"""

# Checkpoint names offered in the UI; each is loaded from "NumbersStation/<name>".
MODEL_CHOICES = ["nsql-6B", "nsql-2B", "nsql-350M"]
# The description below states only the 350M checkpoint fits in the available
# memory, so it is the default both for `infer` and for the UI dropdown.
DEFAULT_MODEL = "nsql-350M"


@functools.lru_cache(maxsize=1)
def load_model(model_name: str):
    """Load the tokenizer and causal-LM model for ``NumbersStation/<model_name>``.

    The result is cached so repeated requests reuse the already-loaded weights
    instead of reloading them on every call. ``maxsize=1`` keeps at most one
    model referenced by the cache, so switching models drops the previous one.

    Args:
        model_name: Checkpoint name under the ``NumbersStation`` organisation,
            e.g. ``"nsql-350M"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(f"NumbersStation/{model_name}")
    model = AutoModelForCausalLM.from_pretrained(f"NumbersStation/{model_name}")
    return tokenizer, model


def build_complete_prompt(table_schemas: str, task_spec: str, prompt: str) -> str:
    """Fill ``template_str`` with the schemas, task instruction and question."""
    return template_str.format(table_schemas=table_schemas, task_spec=task_spec, prompt=prompt)


def infer(table_schemas: str, task_spec: str, prompt: str, model_choice: str = DEFAULT_MODEL):
    """Generate a SQL query for a natural-language question.

    Args:
        table_schemas: Table definitions the question refers to.
        task_spec: Instruction text placed between the schemas and the question.
        prompt: The natural-language question.
        model_choice: Checkpoint name passed to ``load_model``.

    Returns:
        The decoded model output (prompt plus generated continuation), with
        special tokens removed.
    """
    tokenizer, model = load_model(model_choice)
    input_text = build_complete_prompt(table_schemas, task_spec, prompt)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Bound only the newly generated tokens: ``max_length`` also counts the
    # prompt, so a long schema could leave no room for the query at all.
    generated_ids = model.generate(input_ids, max_new_tokens=500)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


description = """The NSQL model family was published by [Numbers Station](https://www.numbersstation.ai/) and is available in three flavors:

- [nsql-6B](https://huggingface.co/NumbersStation/nsql-6B)
- [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
- [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)

For now you can only use the 350M version of the model here, as the file size of the other models exceeds the max memory available in spaces.

In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question. The model will then generate a SQL query that you can run against your database.
"""

iface = gr.Interface(
    title="Text to SQL with NSQL",
    description=description,
    fn=infer,
    inputs=[
        gr.Text(label="Table schemas", placeholder="Insert your table schemas here"),
        gr.Text(
            label="Specify Task",
            value="Using valid SQLite, answer the following questions for the tables provided above.",
        ),
        gr.Text(label="Prompt", placeholder="Put your natural language prompt here"),
        # Default to the checkpoint that actually fits in memory (see description).
        gr.Dropdown(MODEL_CHOICES, value=DEFAULT_MODEL, label="Model"),
    ],
    outputs="text",
    examples=[example1, example2, example3],
)

if __name__ == "__main__":
    iface.launch()