hitz02 commited on
Commit
927e0eb
1 Parent(s): a0eaac5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.compat.v1 as tf
2
+ import os
3
+ import shutil
4
+ import csv
5
+ import pandas as pd
6
+ import numpy as np
7
+ import IPython
8
+ import streamlit as st
9
+ import subprocess
10
+ from itertools import islice
11
+ import random
12
+ from transformers import TapasTokenizer, TapasForQuestionAnswering
13
+
14
+ tf.get_logger().setLevel('ERROR')
15
+
16
+ model_name = 'google/tapas-base-finetuned-wtq'
17
+ model = TapasForQuestionAnswering.from_pretrained(model_name, local_files_only=False)
18
+ tokenizer = TapasTokenizer.from_pretrained(model_name)
19
+
20
+
21
+ st.set_option('deprecation.showfileUploaderEncoding', False)
22
+
23
+ st.title('Query your Table')
24
+
25
+ st.header('Upload CSV file')
26
+
27
+
28
+ uploaded_file = st.file_uploader("Choose your CSV file",type = 'csv')
29
+
30
+ placeholder = st.empty()
31
+
32
+ if uploaded_file is not None:
33
+ data = pd.read_csv(uploaded_file)
34
+ data.replace(',','', regex=True, inplace=True)
35
+ if st.checkbox('Want to see the data?'):
36
+ placeholder.dataframe(data)
37
+
38
+ st.header('Enter your queries')
39
+
40
+ input_queries = st.text_input('Type your queries separated by comma(,)',value='')
41
+ input_queries = input_queries.split(',')
42
+
43
+ colors1 = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(len(input_queries))]
44
+ colors2 = ['background-color:'+str(color)+'; color: black' for color in colors1]
45
+
46
+ def styling_specific_cell(x,tags,colors):
47
+ df_styler = pd.DataFrame('', index=x.index, columns=x.columns)
48
+ for idx,tag in enumerate(tags):
49
+ for r,c in tag:
50
+ df_styler.iloc[r, c] = colors[idx]
51
+ return df_styler
52
+
53
+ if st.button('Predict Answers'):
54
+ with st.spinner('It will take approx a minute'):
55
+ data = data.astype(str)
56
+ inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
57
+ outputs = model(**inputs)
58
+ predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions( inputs, outputs.logits.detach(), outputs.logits_aggregation.detach())
59
+
60
+ id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
61
+ aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
62
+
63
+ answers = []
64
+
65
+ for coordinates in predicted_answer_coordinates:
66
+ if len(coordinates) == 1:
67
+ # only a single cell:
68
+ answers.append(table.iat[coordinates[0]])
69
+ else:
70
+ # multiple cells
71
+ cell_values = []
72
+ for coordinate in coordinates:
73
+ cell_values.append(table.iat[coordinate])
74
+ answers.append(", ".join(cell_values))
75
+
76
+ st.success('Done! Please check below the answers and its cells highlighted in table above')
77
+
78
+ placeholder.dataframe(data.style.apply(styling_specific_cell,tags=predicted_answer_coordinates,colors=colors2,axis=None))
79
+
80
+ for query, answer, predicted_agg, c in zip(queries, answers, aggregation_predictions_string, colors1):
81
+ st.write('\n')
82
+ st.markdown('<font color={} size=4>**{}**</font>'.format(c,query), unsafe_allow_html=True)
83
+ st.write('\n')
84
+
85
+ if predicted_agg == "NONE" or predicted_agg == 'COUNT':
86
+ st.markdown('**>** '+str(answer))
87
+ else:
88
+ if predicted_agg == 'SUM':
89
+ st.markdown('**>** '+str(sum(answer.split(','))))
90
+ else:
91
+ st.markdown('**>** '+str(np.round(np.mean(answer.split(',')),2)))