texttosql / app.py
Aaryaparikh's picture
Upload app.py
a56e433 verified
raw
history blame contribute delete
No virus
3.01 kB
import streamlit as st
import re
import pandas as pd
import numpy as np
import time
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Streamlit app
st.title("Private Sample")


@st.cache_resource
def _load_tokenizer_and_model(device_name):
    """Load the T5 tokenizer and the text-to-SQL model a single time.

    Streamlit re-executes this script from top to bottom on every widget
    interaction, so loading the weights at module level would repeat the
    load on each click. ``st.cache_resource`` keeps one shared instance
    alive across reruns instead.

    Args:
        device_name: Device type string (``"cuda"`` or ``"cpu"``). A plain
            string is used so the value can serve as a cache key.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to the
        requested device and switched to evaluation mode.
    """
    loaded_tokenizer = T5Tokenizer.from_pretrained('t5-small')
    loaded_model = T5ForConditionalGeneration.from_pretrained('cssupport/t5-small-awesome-text-to-sql')
    loaded_model = loaded_model.to(torch.device(device_name))
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Load the model (cached across reruns); use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer, model = _load_tokenizer_and_model(device.type)
def generate_sql(input_prompt):
    """Generate text (a SQL query in this app) from a prompt with the T5 model.

    Args:
        input_prompt: Prompt text handed to the tokenizer unchanged.

    Returns:
        The first generated sequence, decoded to a string with special
        tokens stripped.
    """
    # Encode the prompt as PyTorch tensors and place them on `device`.
    encoded = tokenizer(
        input_prompt,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Generation only: no gradients are needed.
    with torch.no_grad():
        generated_ids = model.generate(**encoded, max_length=512)

    # Only the first (and here, only) sequence is decoded.
    first_sequence = generated_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Table definitions sent to the model as context ahead of the user's question.
# The text is passed through verbatim, so spacing inside it is kept as-is.
_SCHEMA_DDL = (
    "CREATE TABLE AppDrug_allergy_dataset( Id,ExtraProperties,ConcurrencyStamp,"
    "CreationTime,CreatorId,LastModificationTime,LastModifierId,IsDeleted,"
    "DeleterId,DeletionTime,Drug_Name,Chemical_Structure,Immunogenecity,"
    "Individual_Sensitivity,Prior_Allergic_Reaction,Cross_Reactivity,"
    "Route_of_administration,Dose,Duration,Hypersensitivity_Reaction,Allergic) "
    "CREATE TABLE AppSaless( Id,ExtraProperties,ConcurrencyStamp,CreationTime, "
    "CreatorId,LastModificationTime,LastModifierId,IsDeleted,DeleterId,"
    "DeletionTime,Month,Target,Customers_,Revenue)"
)

prompt = st.text_input("Enter Prompt: ", "get target from app saless")
button_clicked = st.button("Generate")
if button_clicked:
    if not prompt.strip():
        # Nothing to translate; tell the user instead of running the model.
        st.warning("Please enter a prompt before generating.")
    else:
        input_prompt = "tables:\n" + _SCHEMA_DDL + "\n" + "query for:" + prompt
        generated_sql = generate_sql(input_prompt)
        # Keep the console log line for server-side debugging.
        print(f"The generated SQL query is: {generated_sql}")
        # print() only reaches the server log; render the result in the page
        # so the person who clicked "Generate" actually sees it.
        st.code(generated_sql, language="sql")
# Example prompts for manual testing (commented out; uncomment one to try it):
#input_prompt = "tables:\n" + "CREATE TABLE Catalogs (date_of_latest_revision VARCHAR)" + "\n" +"query for: Find the dates on which more than one revisions were made."
#input_prompt = "tables:\n" + "CREATE TABLE table_22767 ( \"Year\" real, \"World\" real, \"Asia\" text, \"Africa\" text, \"Europe\" text, \"Latin America/Caribbean\" text, \"Northern America\" text, \"Oceania\" text )" + "\n" +"query for:what will the population of Asia be when Latin America/Caribbean is 783 (7.5%)?."
# input_prompt = "Retrieve the names of all employees who work in the IT department."
#OUTPUT: The generated SQL query is: SELECT student_id FROM students WHERE NOT student_id IN (SELECT student_id FROM student_course_attendance)
# Animated demo: a progress bar, a status line and a line chart that is
# extended with random data over 100 steps.
# NOTE(review): indentation was lost in SOURCE; this section is assumed to run
# at module top level (not inside the "Generate" handler) - confirm.
progress_bar = st.progress(0)
status_text = st.empty()
chart = st.line_chart(np.random.randn(10, 2))

for step in range(1, 101):
    # Advance the progress bar to the current step (1..100).
    progress_bar.progress(step)

    # Draw a fresh batch of random rows and report the last value drawn.
    new_rows = np.random.randn(10, 2)
    latest_value = new_rows[-1, 1]
    status_text.text('The latest random number is: ' + str(latest_value))

    # Extend the chart with the new batch.
    chart.add_rows(new_rows)

    # Short pause so the animation is visible.
    time.sleep(0.1)

status_text.text('Done!')
st.balloons()