amir7d0 commited on
Commit
f9ce5cf
β€’
1 Parent(s): 6312bdc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import numpy as np
4
+ import pandas as pd
5
+ import json
6
+ import base64
7
+ import uuid
8
+ from pandas import DataFrame
9
+ import time
10
+ import re
11
+
12
+ def download_button(object_to_download, download_filename, button_text):
13
+
14
+ if isinstance(object_to_download, bytes):
15
+ pass
16
+
17
+ elif isinstance(object_to_download, pd.DataFrame):
18
+ object_to_download = object_to_download.to_csv(index=False)
19
+ # Try JSON encode for everything else
20
+ else:
21
+ object_to_download = json.dumps(object_to_download)
22
+
23
+ try:
24
+ # some strings <-> bytes conversions necessary here
25
+ b64 = base64.b64encode(object_to_download.encode()).decode()
26
+ except AttributeError as e:
27
+ b64 = base64.b64encode(object_to_download).decode()
28
+
29
+ button_uuid = str(uuid.uuid4()).replace("-", "")
30
+ button_id = re.sub("\d+", "", button_uuid)
31
+
32
+ custom_css = f"""
33
+ <style>
34
+ #{button_id} {{
35
+ display: inline-flex;
36
+ align-items: center;
37
+ justify-content: center;
38
+ background-color: rgb(255, 255, 255);
39
+ color: rgb(38, 39, 48);
40
+ padding: .25rem .75rem;
41
+ position: relative;
42
+ text-decoration: none;
43
+ border-radius: 4px;
44
+ border-width: 1px;
45
+ border-style: solid;
46
+ border-color: rgb(230, 234, 241);
47
+ border-image: initial;
48
+ }}
49
+ #{button_id}:hover {{
50
+ border-color: rgb(246, 51, 102);
51
+ color: rgb(246, 51, 102);
52
+ }}
53
+ #{button_id}:active {{
54
+ box-shadow: none;
55
+ background-color: rgb(246, 51, 102);
56
+ color: white;
57
+ }}
58
+ </style> """
59
+
60
+ dl_link = (
61
+ custom_css
62
+ + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br><br>'
63
+ )
64
+ # dl_link = f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}"><input type="button" kind="primary" value="{button_text}"></a><br></br>'
65
+
66
+ st.markdown(dl_link, unsafe_allow_html=True)
67
+
68
+
69
+
70
+ class c_model:
71
+ def __init__(self):
72
+ # st.write('my model')
73
+ pass
74
+
75
+ @st.cache
76
+ def load_model(self, name_or_path):
77
+ time.sleep(3)
78
+ return None
79
+
80
+ def predict(self, texts):
81
+ return np.random.randint(2), np.random.rand()
82
+
83
+
84
+ st.title('Sentiment Analysis')
85
+
86
+
87
+ # Load classification model
88
+ with st.spinner('Loading classification model...'):
89
+ from transformers import pipeline
90
+
91
+ checkpoint = "amir7d0/distilbert-base-uncased-finetuned-amazon-reviews"
92
+ classifier = pipeline("text-classification", model=checkpoint)
93
+
94
+
95
+ tab1, tab2 = st.tabs(["Single Comment", "Multiple Comment"])
96
+
97
+ with tab1:
98
+ st.subheader('Single comment classification')
99
+ text_input = st.text_area(label='Paste your text below (max 256 words)',
100
+ value='Hiiiiiiiii')
101
+ MAX_WORDS = 256
102
+ res = len(re.findall(r"\w+", text_input))
103
+ if res > MAX_WORDS:
104
+ st.warning(
105
+ "⚠️ Your text contains "
106
+ + str(res)
107
+ + " words."
108
+ + " Only the first 256 words will be reviewed! 😊"
109
+ )
110
+ text_input = text_input[:MAX_WORDS]
111
+
112
+ submit_button = st.button(label='Submit comment')
113
+ if submit_button:
114
+ with st.spinner('Predicting ...'):
115
+ start_time = time.time()
116
+ time.sleep(2)
117
+ preds = classifier([text_input])[0]
118
+ end_time = time.time()
119
+ p_time = round(end_time-start_time, 2)
120
+ st.success(f'Prediction finished in {p_time}s!')
121
+
122
+ st.write(f'Label: {preds["label"]}, with certainty: {preds["score"]}')
123
+
124
+
125
+ with tab2:
126
+ st.subheader('Multiple comment classification')
127
+ file_input = st.file_uploader(label='Choose a file:', type='csv')
128
+ if file_input:
129
+ try:
130
+ df = pd.read_csv(file_input)
131
+ texts = df['text'].to_list()
132
+ except:
133
+ st.write('Bad File Error...')
134
+
135
+ st.write(f"First 5 rows of {file_input.name} texts")
136
+ st.write(texts[:5])
137
+
138
+ submit_button = st.button(label='Submit file')
139
+ if submit_button:
140
+ with st.spinner('Predicting ...'):
141
+ start_time = time.time()
142
+ time.sleep(2)
143
+ preds = classifier(texts)
144
+ end_time = time.time()
145
+ p_time = round(end_time-start_time, 2)
146
+ st.success(f'Prediction finished in {p_time}s!')
147
+
148
+ c1, c2 = st.columns([3, 1])
149
+ with c1:
150
+ st.subheader("🎈 Check & download results")
151
+ with c2:
152
+ CSVButton2 = download_button(results, "Data.csv", "πŸ“₯ Download (.csv)")
153
+
154
+ st.header("")
155
+
156
+ for text, pred in zip(texts, preds):
157
+ pred['text'] = text
158
+
159
+ df = pd.DataFrame(preds, columns=['text', 'label', 'score'])
160
+
161
+ import seaborn as sns
162
+ # Add styling
163
+ cmGreen = sns.light_palette("green", as_cmap=True)
164
+ cmRed = sns.light_palette("red", as_cmap=True)
165
+ df = df.style.background_gradient(
166
+ cmap=cmGreen,
167
+ subset=["score"],
168
+ )
169
+
170
+ st.table(df)