|
|
|
""" |
|
Created on Fri May 26 14:07:22 2023 |
|
|
|
@author: vibin |
|
""" |
|
|
|
import streamlit as st |
|
from pandasql import sqldf |
|
import pandas as pd |
|
import re |
|
from typing import List |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
import re |
|
|
|
|
|
@st.cache_resource()
def tapas_model():
    """Build the TAPAS table-question-answering pipeline.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    pipeline object instead of rebuilding it.

    Returns:
        A ``transformers`` pipeline for the
        ``google/tapas-base-finetuned-wtq`` checkpoint.
    """
    task_name = "table-question-answering"
    checkpoint = "google/tapas-base-finetuned-wtq"
    qa_pipeline = pipeline(task=task_name, model=checkpoint)
    return qa_pipeline
|
|
|
@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Encode a question plus table column names into model input ids.

    The prompt has the form ``"question: <question> table: <c1>,<c2>,..."``.

    NOTE(review): this reads a module-level ``tokenizer`` that the Text2SQL
    branch assigns (via ``tokmod``) before calling ``inference``; it is not
    part of the cache key.

    Args:
        question: Natural-language question.
        table: Column names of the table being queried.

    Returns:
        The ``input_ids`` tensor (``return_tensors="pt"``), capped at 512
        tokens.
    """
    joined_columns = ",".join(table)
    prompt = f"question: {question} table: {joined_columns}"
    # truncation=True makes the 512-token cap explicit; passing max_length
    # alone only truncates through the tokenizer's warned-about fallback.
    encoded = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
|
|
|
@st.cache_resource()
def inference(question: str, table: List[str]) -> str:
    """Generate text (the model's SQL answer) for *question* over *table*.

    Encodes the prompt with ``prepare_input``, moves it to the model's
    device, runs ``model.generate`` with 10 beams and a 700-token limit, and
    decodes the first returned sequence without special tokens.

    NOTE(review): reads the module-level ``model`` and ``tokenizer`` set by
    the Text2SQL branch. ``top_k`` appears to have no effect without
    ``do_sample=True`` — confirm against the installed transformers version.

    Args:
        question: Natural-language question.
        table: Column names passed to the prompt.

    Returns:
        The decoded generation as a string.
    """
    encoded = prepare_input(question=question, table=table).to(model.device)
    generated = model.generate(
        inputs=encoded,
        num_beams=10,
        top_k=10,
        max_length=700,
    )
    first_sequence = generated[0]
    return tokenizer.decode(token_ids=first_sequence, skip_special_tokens=True)
|
|
|
@st.cache_resource()
def tokmod(tok_md):
    """Load a tokenizer and seq2seq model from the same checkpoint.

    Cached with ``st.cache_resource`` so reruns reuse the loaded objects.

    Args:
        tok_md: Checkpoint name or path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(tok_md)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(tok_md)
    return loaded_tokenizer, loaded_model
|
|
|
|
|
|
|
|
|
# --- Post-processing of the Text2SQL model output -------------------------
# The model emits pseudo-SQL such as
#   SELECT SUM Version FROM table WHERE Class_Code = Antibiotic Stewardship
# which needs the real table name, quoted literals and parenthesised
# aggregates before SQLite (via pandasql) will accept it.

# A single-quoted SQL string literal ('' is an escaped quote inside it).
# The capturing group makes re.split keep literals at the odd indices.
_SQL_LITERAL_RE = re.compile(r"('(?:[^']|'')*')")

# "= value" where value is not already quoted; the value runs up to the next
# standalone AND/OR connector or the end of the text.
_SQL_UNQUOTED_VALUE_RE = re.compile(
    r"(=\s*)(?![\s'=])(.+?)(?=\s+(?:AND|OR)\s+|\s*$)"
)

# The placeholder table name the model emits after FROM.
_SQL_TABLE_PLACEHOLDER_RE = re.compile(r"\b(FROM\s+)table\b", re.IGNORECASE)

# An aggregate keyword followed by a bare argument, optionally DISTINCT.
_SQL_BARE_AGGREGATE_RE = re.compile(
    r"\b(AVG|COUNT|MAX|MIN|SUM)\s+(DISTINCT\s+)?(\w+|\*)"
)


def _map_outside_literals(sql, transform):
    """Apply *transform* to every part of *sql* that is not a '...' literal."""
    parts = _SQL_LITERAL_RE.split(sql)
    return "".join(
        part if index % 2 else transform(part)
        for index, part in enumerate(parts)
    )


def _quote_value(match):
    """Wrap the matched comparison value in single quotes, escaping quotes."""
    value = match.group(2).replace("'", "''")
    return f"{match.group(1)}'{value}'"


def _postprocess_sql(raw_sql, table_name):
    """Turn the model's pseudo-SQL into a query runnable by pandasql.

    - quotes unquoted ``= value`` comparison values (values end at AND/OR);
    - replaces the ``FROM table`` placeholder with *table_name*;
    - rewrites ``SUM col`` / ``COUNT DISTINCT col`` as ``SUM(col)`` /
      ``COUNT(DISTINCT col)``.

    Every step leaves existing '...' literals untouched, so values that
    happen to contain "table" or an aggregate keyword are not rewritten.
    """
    sql = _map_outside_literals(
        raw_sql, lambda text: _SQL_UNQUOTED_VALUE_RE.sub(_quote_value, text)
    )
    sql = _map_outside_literals(
        sql,
        lambda text: _SQL_TABLE_PLACEHOLDER_RE.sub(
            lambda m: m.group(1) + table_name, text
        ),
    )
    sql = _map_outside_literals(
        sql, lambda text: _SQL_BARE_AGGREGATE_RE.sub(r"\1(\2\3)", text)
    )
    return sql


# --- Page layout ---------------------------------------------------------

nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])

if nav == "TAPAS":
    col1, col2, col3 = st.columns(3)
    col2.title("TAPAS")

    col3, col4 = st.columns([3, 12])
    col4.text("Tabular Data Text Extraction using text")

    # The TAPAS pipeline is fed the table with every cell cast to str.
    table = pd.read_csv("data.csv")
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)

    st.title("")  # vertical spacer

    # Sample questions offered in the selectbox (duplicate entry removed).
    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")  # vertical spacer

    sql_txt = st.text_area("TAPAS Input", v2)

    if st.button("Predict"):
        tqa = tapas_model()
        txt_sql = tqa(table=table, query=sql_txt)["answer"]
        st.text("Output - ")
        st.success(f"{txt_sql}")

elif nav == "Text2SQL":
    col1, col2, col3 = st.columns(3)
    col2.title("Text2SQL")

    col3, col4 = st.columns([1, 20])
    col4.text("Text will be converted to SQL Query and can extract the data from DataSet")

    df_qna = pd.read_csv("qnacsv.csv", encoding='unicode_escape')

    st.title("")  # vertical spacer

    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)

    st.title("")  # vertical spacer

    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")  # vertical spacer

    sql_txt = st.text_area("Text for SQL Conversion", v2)

    if st.button("Predict"):
        tok_model = "juierror/flan-t5-text2sql-with-schema"
        # Deliberately assigned at module scope: prepare_input() and
        # inference() read the global `tokenizer` and `model`.
        tokenizer, model = tokmod(tok_model)

        table_name = "df_qna"
        table_col = ["Type", "Class_Code", "Version", "Measure_Indicator_Code",
                     "Measure_Indicator_Name", "Description_Definition",
                     "Source", "Interfaces"]

        txt_sql = inference(question=sql_txt, table=table_col)
        txt_sql = _postprocess_sql(txt_sql, table_name)

        st.success(f"{txt_sql}")

        try:
            # Explicit env: only the dataset is visible to the generated SQL,
            # instead of pandasql scanning the caller's frames for DataFrames.
            result_df = sqldf(txt_sql, {table_name: df_qna})
        except Exception as err:  # UI boundary: report bad model SQL, don't crash
            st.error(f"Could not run the generated SQL: {err}")
        else:
            st.text("Output - ")
            st.write(result_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|