# -*- coding: utf-8 -*-
"""
Created on Fri May 26 14:07:22 2023

@author: vibin

Streamlit demo with two modes selected from the sidebar:

* TAPAS    -- answers a natural-language question directly over ``data.csv``
              with the ``google/tapas-base-finetuned-wtq`` pipeline.
* Text2SQL -- turns a question into SQL with a FLAN-T5 text-to-SQL model,
              patches the generated SQL so SQLite accepts it, and runs it
              against ``data.csv`` through pandasql.
"""
import re
from typing import List

import pandas as pd
import streamlit as st
from pandasql import sqldf
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Hugging Face checkpoint used by the Text2SQL mode.
TEXT2SQL_MODEL = "juierror/flan-t5-text2sql-with-schema"

# The model writes the literal word ``table`` where the table name belongs;
# word boundaries keep words such as "Stable" intact.
_TABLE_PLACEHOLDER = re.compile(r"\btable\b")
# Right-hand side of an ``=`` comparison (so also ``>=``/``<=``/``!=``) up to
# the next stand-alone AND/OR or the end of the query.  ``(?![\s'])`` skips
# values that are already quoted so they are not quoted twice.
_EQ_VALUE = re.compile(r"(=\s*)(?![\s'])(.+?)(?=\s+(?:AND|OR)\s+|\s*$)")
# Aggregates emitted without parentheses, e.g. ``COUNT Patient_Name`` or
# ``COUNT DISTINCT Country``.
_BARE_AGGREGATE = re.compile(r"\b(AVG|COUNT|MAX|MIN|SUM) (DISTINCT )?(\w+)")


@st.cache_resource()
def tapas_model():
    """Load (once per process) the TAPAS table-question-answering pipeline."""
    return pipeline(task="table-question-answering",
                    model="google/tapas-base-finetuned-wtq")


@st.cache_resource()
def tokmod(tok_md):
    """Load (once per process) the tokenizer and seq2seq model ``tok_md``.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tkn = AutoTokenizer.from_pretrained(tok_md)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(tok_md)
    return tkn, mdl


def prepare_input(question: str, table: List[str]):
    """Build the Text2SQL prompt and tokenize it.

    The prompt has the form ``"question: <q> table: <col1>,<col2>,..."``.

    Args:
        question: natural-language question.
        table: column names of the table the SQL should query.

    Returns:
        Token ids tensor (``return_tensors="pt"``), at most 512 tokens long.
    """
    tokenizer, _ = tokmod(TEXT2SQL_MODEL)
    inputs = f"question: {question} table: {','.join(table)}"
    # truncation=True makes the 512-token cap explicit instead of relying on
    # the tokenizer's warned-about default.
    return tokenizer(inputs, max_length=512, truncation=True,
                     return_tensors="pt").input_ids


@st.cache_data()
def inference(question: str, table: List[str]) -> str:
    """Generate a SQL string for ``question`` over a table with columns ``table``.

    The result is cached per (question, columns) pair; it is a plain string,
    hence ``cache_data`` rather than ``cache_resource``.
    """
    tokenizer, model = tokmod(TEXT2SQL_MODEL)
    input_ids = prepare_input(question=question, table=table).to(model.device)
    # Plain beam search; top_k was dropped because it only applies when sampling.
    outputs = model.generate(inputs=input_ids, num_beams=10, max_length=700)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)


def _quote_value(match: re.Match) -> str:
    """Wrap a comparison value in single quotes, escaping embedded quotes."""
    value = match[2].replace("'", "''")
    return f"{match[1]}'{value}'"


def _wrap_aggregate(match: re.Match) -> str:
    """Rewrite ``AGG [DISTINCT ]col`` as ``AGG([DISTINCT ]col)``."""
    return f"{match[1]}({match[2] or ''}{match[3]})"


def postprocess_sql(txt_sql: str, table_name: str) -> str:
    """Patch model-generated SQL so that SQLite (via pandasql) can run it.

    * replaces the ``table`` placeholder with ``table_name``;
    * quotes the value on the right of every ``=`` comparison (only at that
      position, never elsewhere in the query);
    * adds the parentheses the model omits around aggregate arguments.
    """
    txt_sql = _TABLE_PLACEHOLDER.sub(lambda _: table_name, txt_sql)
    txt_sql = _EQ_VALUE.sub(_quote_value, txt_sql)
    return _BARE_AGGREGATE.sub(_wrap_aggregate, txt_sql)


### Main
nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])

if nav == "TAPAS":
    _, col2, _ = st.columns(3)
    col2.title("TAPAS")
    _, col4 = st.columns([3, 12])
    col4.text("Tabular Data Text Extraction using text")

    # TAPAS expects every cell as a string.
    table = pd.read_csv("data.csv").astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)
    st.title("")

    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)
    st.title("")
    sql_txt = st.text_area("TAPAS Input", v2)

    if st.button("Predict"):
        tqa = tapas_model()
        txt_sql = tqa(table=table, query=sql_txt)["answer"]
        st.text("Output - ")
        st.success(f"{txt_sql}")

elif nav == "Text2SQL":
    _, col2, _ = st.columns(3)
    col2.title("Text2SQL")
    _, col4 = st.columns([1, 20])
    col4.text("Text will be converted to SQL Query and can extract the data from DataSet")

    df_qna = pd.read_csv("data.csv")
    st.title("")
    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)
    st.title("")

    # NOTE(review): these examples mention columns (measure indicator code,
    # version, source, Class_Code, interface) that do not appear in the
    # patient schema of data.csv -- they look left over from an earlier
    # dataset; confirm and update.
    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)
    st.title("")
    sql_txt = st.text_area("Text for SQL Conversion", v2)

    if st.button("Predict"):
        table_name = "df_qna"
        # Give the model the schema of the data that is actually queried, so
        # the prompt cannot drift out of sync with data.csv.
        table_column = list(df_qna.columns)
        txt_sql = postprocess_sql(inference(question=sql_txt, table=table_column),
                                  table_name)
        st.success(f"{txt_sql}")
        try:
            # Pass the table explicitly instead of relying on pandasql's
            # caller-frame inspection.
            all_students = sqldf(txt_sql, {table_name: df_qna})
        except Exception as exc:  # UI boundary: model SQL may be invalid
            st.error(f"Generated SQL could not be executed: {exc}")
        else:
            st.text("Output - ")
            st.write(all_students)