vbzvibin committed on
Commit
b072237
1 Parent(s): b8d0ad5

Upload 4 files

Browse files
Files changed (4) hide show
  1. data.csv +11 -0
  2. pasta.py +178 -0
  3. qnacsv.csv +0 -0
  4. requirements.txt +13 -0
data.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Patient_Name,Country,Disease,CUI,Snomed,Oxygen_Rate,Med_Type,Admission_Date
2
+ Rahul,India,Diabetes,CUI00234,SNO34672,90,Medicare,23-09-2022
3
+ Kumar ,Sri Lanka,Severe Fever,CUI00235,SNO34673,91,Medicaid,04-03-2022
4
+ Ricky ,Australia,Edema,CUI00236,SNO34674,92,Commercial,02-09-2022
5
+ Jayasuriya,Sri Lanka,Cardiac Arrest,CUI00237,SNO34675,93,Medicare,13-01-2022
6
+ Mahela,Sri Lanka,Alzheimer,CUI00238,SNO34676,94,Medicaid,07-03-2022
7
+ Kohli,India,Cancer,CUI00239,SNO34677,95,Commercial,05-05-2022
8
+ Inzamam,Pakistan,Pneumonia,CUI00240,SNO34678,96,Medicare,04-03-2022
9
+ Jacques,South Africa,Severe Fever,CUI00241,SNO34679,97,Medicaid,02-09-2022
10
+ Saurav,India,Edema,CUI00242,SNO34680,98,Medicare,13-01-2022
11
+ David,India,Cardiac Arrest,CUI00243,SNO34681,99,Medicaid,07-03-2022
pasta.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri May 26 14:07:22 2023
4
+
5
+ @author: vibin
6
+ """
7
+
8
+ import streamlit as st
9
+ from pandasql import sqldf
10
+ import pandas as pd
11
+ import re
12
+ from typing import List
13
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
+ import re
15
+
16
+
17
@st.cache_resource()
def tapas_model():
    """Build the TAPAS table-question-answering pipeline.

    The function is wrapped in ``st.cache_resource`` so the pipeline object
    is created through Streamlit's resource cache rather than directly.

    Returns:
        The ``transformers`` pipeline created for the
        ``"table-question-answering"`` task with the
        ``google/tapas-base-finetuned-wtq`` checkpoint.
    """
    tqa_pipeline = pipeline(
        task="table-question-answering",
        model="google/tapas-base-finetuned-wtq",
    )
    return tqa_pipeline
20
+
21
@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Encode a question plus table schema as model input ids.

    The prompt has the form ``question: <question> table: <c1>,<c2>,...``.
    It is tokenized with the module-level ``tokenizer``, which must be
    assigned before this function is called (the Text2SQL branch of the
    script assigns it from ``tokmod``).

    Args:
        question: Natural-language question to convert.
        table: Column names of the table the question refers to.

    Returns:
        The ``input_ids`` of the tokenized prompt, requested in PyTorch
        format (``return_tensors="pt"``).
    """
    schema = ",".join(table)
    prompt = f"question: {question} table: {schema}"
    # ``max_length`` only takes effect together with truncation; ask for it
    # explicitly instead of relying on the tokenizer's implicit fallback,
    # which also logs a warning.
    encoded = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
29
+
30
@st.cache_resource()
def inference(question: str, table: List[str]) -> str:
    """Generate text for ``question`` against the given column list.

    Relies on the module-level ``model`` and ``tokenizer`` being assigned
    before the call.

    Args:
        question: Natural-language question.
        table: Column names forwarded to ``prepare_input``.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    # Encode the prompt, then move it to the device the model lives on.
    token_ids = prepare_input(question=question, table=table).to(model.device)
    generated = model.generate(
        inputs=token_ids,
        num_beams=10,
        top_k=10,
        max_length=700,
    )
    return tokenizer.decode(token_ids=generated[0], skip_special_tokens=True)
37
+
38
@st.cache_resource()
def tokmod(tok_md):
    """Load the tokenizer and seq2seq model for one checkpoint.

    Args:
        tok_md: Checkpoint identifier passed unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(tok_md),
        AutoModelForSeq2SeqLM.from_pretrained(tok_md),
    )
43
+
44
+
45
+ ### Main
46
+
47
# ``KEYWORD column`` (or ``KEYWORD *``) in the generated text, which has to
# become ``KEYWORD(column)`` before the query can run.
_SQL_FUNCTION_CALL = re.compile(r"\b(AVG|COUNT|DISTINCT|MAX|MIN|SUM)\s+(\w+|\*)")
# The value of a comparison: everything after ``=`` up to the next `` AND``
# or the end of the query.
_SQL_COMPARISON_VALUE = re.compile(r"=\s*(.+?)(?=\s+AND\b|\s*$)")
# A single-quoted SQL literal; ``''`` is an escaped quote inside it.
_SQL_STRING_LITERAL = re.compile(r"('(?:[^']|'')*')")


def _quote_comparison_value(match):
    """Return ``= '<value>'`` for one ``_SQL_COMPARISON_VALUE`` match."""
    value = match.group(1)
    if len(value) >= 2 and value.startswith("'") and value.endswith("'"):
        # Already a quoted literal; do not wrap it a second time.
        return "= " + value
    # Double embedded quotes so the value cannot terminate the literal.
    return "= '" + value.replace("'", "''") + "'"


def _to_executable_sql(generated_sql, table_name):
    """Turn the model's pseudo-SQL into a query that can be executed.

    Three rewrites are applied, in this order:

    1. the placeholder word ``table`` becomes ``table_name``;
    2. every value following ``=`` is wrapped in single quotes (numbers
       included);
    3. ``AVG col`` / ``COUNT col`` / ... become ``AVG(col)`` / ``COUNT(col)``.

    Each rewrite is applied at the position where the pattern matched, so
    quoting one value never alters identical text elsewhere in the query.

    Args:
        generated_sql: Text produced by the text-to-SQL model.
        table_name: Name to use in place of the ``table`` placeholder.

    Returns:
        The rewritten SQL string.
    """
    # A callable replacement keeps ``table_name`` from being read as a
    # regex replacement template.
    sql = re.sub(r"\btable\b", lambda _m: table_name, generated_sql)
    sql = _SQL_COMPARISON_VALUE.sub(_quote_comparison_value, sql)

    # Splitting on a capturing pattern puts the literals at odd indices;
    # only the text outside literals (even indices) is rewritten.
    pieces = _SQL_STRING_LITERAL.split(sql)
    pieces[::2] = [
        _SQL_FUNCTION_CALL.sub(r"\1(\2)", piece) for piece in pieces[::2]
    ]
    return "".join(pieces)


nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])
if nav == "TAPAS":

    col1, col2, col3 = st.columns(3)
    col2.title("TAPAS")

    col3, col4 = st.columns([3, 12])
    col4.text("Tabular Data Text Extraction using text")

    table = pd.read_csv("data.csv")
    # Every cell is cast to str before the table is handed to the pipeline.
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)

    st.title("")

    # Sample questions offered in the drop-down (duplicate entry removed).
    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]

    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    # The selected sample pre-fills the text area; the user may edit it.
    tapas_query = st.text_area("TAPAS Input", v2)

    if st.button("Predict"):
        tqa = tapas_model()
        tapas_answer = tqa(table=table, query=tapas_query)["answer"]
        st.text("Output - ")
        st.success(f"{tapas_answer}")

elif nav == "Text2SQL":

    col1, col2, col3 = st.columns(3)
    col2.title("Text2SQL")

    col3, col4 = st.columns([1, 20])
    col4.text("Text will be converted to SQL Query and can extract the data from DataSet")

    # Import Data
    df_qna = pd.read_csv("qnacsv.csv", encoding='unicode_escape')

    st.title("")

    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)

    st.title("")

    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    sql_txt = st.text_area("Text for SQL Conversion", v2)

    if st.button("Predict"):

        tok_model = "juierror/flan-t5-text2sql-with-schema"
        # prepare_input() and inference() read these two module-level
        # names, so they must keep exactly these identifiers.
        tokenizer, model = tokmod(tok_model)

        table_name = "df_qna"
        table_col = [
            "Type",
            "Class_Code",
            "Version",
            "Measure_Indicator_Code",
            "Measure_Indicator_Name",
            "Description_Definition",
            "Source",
            "Interfaces",
        ]

        txt_sql = inference(question=sql_txt, table=table_col)

        ### SQL Modification
        txt_sql = _to_executable_sql(txt_sql, table_name)

        st.success(f"{txt_sql}")
        # Pass the frame explicitly instead of letting pandasql look it up
        # in the caller's variables.
        query_result = sqldf(txt_sql, {table_name: df_qna})

        st.text("Output - ")
        st.write(query_result)
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+
176
+
177
+
178
+
qnacsv.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip
2
+ Cmake
3
+ wheel
4
+ pandas
5
+ jinja2==3.1.2
6
+ pandasql
7
+ Cython
8
+ datasets
9
+ huggingface-hub
10
+ tapas
11
+ torch
12
+ transformers
13
+ streamlit