MINHCT committed on
Commit
c77a94b
•
1 Parent(s): c2af0b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -114
app.py CHANGED
@@ -92,14 +92,12 @@ def process_api(text):
92
  SVM_Predicted = SVM_model.predict(processed_text).tolist() # SVC Model
93
  Seq_Predicted = Seq_model.predict(padded_sequence)
94
  predicted_label_index = np.argmax(Seq_Predicted)
95
- print(int(predicted_label_index))
96
 
97
  # ----------- Proba -----------
98
  Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
99
- #print(float(np.max(Logistic_Predicted_proba)))
100
  svm_new_probs = SVM_model.decision_function(processed_text)
101
  svm_probs = svm_model.predict_proba(svm_new_probs)
102
- #print(float(np.max(svm_probs)))
103
 
104
  # ----------- Debug Logs -----------
105
  logistic_debug = decodedLabel(int(Logistic_Predicted[0]))
@@ -114,122 +112,16 @@ def process_api(text):
114
 
115
  'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
116
  'probability_svm': f"{int(float(np.max(svm_probs))*10000//100)}%",
117
- 'LSTM': decodedLabel(int(predicted_label_index)),
 
 
 
118
  'Article_Content': text
119
  }
120
 
121
  # Init web crawling, process article content by Model and return result as JSON
122
- def categorize(url):
123
- try:
124
- article_content = crawURL(url)
125
- result = process_api(article_content)
126
- return result
127
- except Exception as error:
128
- if hasattr(error, 'message'):
129
- return {"error_message": error.message}
130
- else:
131
- return {"error_message": error}
132
-
133
-
134
- # Main App
135
- st.title('Instant Category Classification')
136
- st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")
137
-
138
- # Define category information (modify content and bullet points as needed)
139
- categories = {
140
- "Business": [
141
- "Analyze market trends and investment opportunities.",
142
- "Gain insights into company performance and industry news.",
143
- "Stay informed about economic developments and regulations."
144
- ],
145
- "Health": [
146
- "Discover healthy recipes and exercise tips.",
147
- "Learn about the latest medical research and advancements.",
148
- "Find resources for managing chronic conditions and improving well-being."
149
- ],
150
- "Sport": [
151
- "Follow your favorite sports teams and athletes.",
152
- "Explore news and analysis from various sports categories.",
153
- "Stay updated on upcoming games and competitions."
154
- ],
155
- "Politics": [
156
- "Get informed about current political events and policies.",
157
- "Understand different perspectives on political issues.",
158
- "Engage in discussions and debates about politics."
159
- ],
160
- "Entertainment": [
161
- "Find recommendations for movies, TV shows, and music.",
162
- "Explore reviews and insights from entertainment critics.",
163
- "Stay updated on celebrity news and cultural trends."
164
- ]
165
- }
166
-
167
- # Define model information (modify descriptions as needed)
168
- models = {
169
- "Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
170
- "SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
171
- "LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
172
- "BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification."
173
- }
174
-
175
- # Create expanders containing list of categories can be classified
176
- with st.expander("Category List"):
177
- # Title for each category
178
- st.subheader("Available Categories:")
179
- for category in categories.keys():
180
- st.write(f"- {category}")
181
- # Content for each category (separated by a horizontal line)
182
- st.write("---")
183
- for category, content in categories.items():
184
- st.subheader(category)
185
- for item in content:
186
- st.write(f"- {item}")
187
-
188
-
189
- # Create expanders containing list of models used in this project
190
- with st.expander("Available Models"):
191
- st.subheader("List of Models:")
192
- for model_name in models.keys():
193
- st.write(f"- {model_name}")
194
- st.write("---")
195
- for model_name, description in models.items():
196
- st.subheader(model_name)
197
- st.write(description)
198
-
199
- # Explain to user why this project is only worked for CNN domain
200
- with st.expander("Tips", expanded=True):
201
- st.write(
202
- '''
203
- This project works best with CNN articles right now.
204
- Our web crawler is like a special tool for CNN's website.
205
- It can't quite understand other websites because they're built differently
206
- '''
207
- )
208
-
209
- st.divider() # 👈 Draws a horizontal rule
210
-
211
- st.title('Dive in! See what category your CNN story belongs to 😉.')
212
- # Paste URL Input
213
- url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')
214
-
215
- if url:
216
- st.divider() # 👈 Draws a horizontal rule
217
- result = categorize(url)
218
- article_content = result.get('Article_Content')
219
- st.title('Article Content Fetched')
220
- st.text_area("", value=article_content, height=400) # render the article content as textarea element
221
- st.divider() # 👈 Draws a horizontal rule
222
- st.title('Predicted Results')
223
- st.json({
224
- "Logistic": {
225
- "predicted_label": result.get("predicted_label_logistic"),
226
- "probability": result.get("probability_logistic")
227
- },
228
- "SVC": {
229
- "predicted_label": result.get("predicted_label_svm"),
230
- "probability": result.get("probability_svm")
231
  },
232
- "LSTM": result.get("LSTM")
233
  })
234
 
235
  st.divider() # 👈 Draws a horizontal rule
 
92
  SVM_Predicted = SVM_model.predict(processed_text).tolist() # SVC Model
93
  Seq_Predicted = Seq_model.predict(padded_sequence)
94
  predicted_label_index = np.argmax(Seq_Predicted)
 
95
 
96
  # ----------- Proba -----------
97
  Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
 
98
  svm_new_probs = SVM_model.decision_function(processed_text)
99
  svm_probs = svm_model.predict_proba(svm_new_probs)
100
+ predicted_label_index = np.argmax(Seq_Predicted)
101
 
102
  # ----------- Debug Logs -----------
103
  logistic_debug = decodedLabel(int(Logistic_Predicted[0]))
 
112
 
113
  'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
114
  'probability_svm': f"{int(float(np.max(svm_probs))*10000//100)}%",
115
+
116
+ 'predicted_label_lstm': int(predicted_label_index),
117
+ 'probability_lstm': f"{int(float(np.max(Seq_Predicted))*10000//100)}%",
118
+
119
  'Article_Content': text
120
  }
121
 
122
  # Init web crawling, process article content by Model and return result as JSON
123
+ lstm")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  },
 
125
  })
126
 
127
  st.divider() # 👈 Draws a horizontal rule