manideepvemula committed on
Commit
3ecebd4
1 Parent(s): d5a16f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -8
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import torch
3
  import numpy as np
4
  import pandas as pd
@@ -6,9 +5,11 @@ from newsfetch.news import newspaper
6
  from transformers import pipeline
7
  from transformers import T5Tokenizer, T5ForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
9
  from newspaper import Article
10
  from sklearn.preprocessing import LabelEncoder
11
  import joblib
 
12
 
13
 
14
  # Example usage:
@@ -24,6 +25,11 @@ def main():
24
  try:
25
  news_article = newspaper(url)
26
  print("scraped: ",news_article)
 
 
 
 
 
27
  return news_article.article
28
  except Exception as e:
29
  return "Error: " + str(e)
@@ -58,15 +64,35 @@ def main():
58
  return None,None
59
  else:
60
  st.write("This article is not classified as related to the supply chain.")
61
-
62
  def classify_and_summarize(input_text, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device):
63
  if input_text.startswith("http"):
64
  # If the input starts with "http", assume it's a URL and extract content
65
  article_content = scrape_news_content(input_text)
 
66
  else:
67
  # If the input is not a URL, assume it's the content
68
  article_content = input_text
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Perform sentiment classification
71
  inputs_cls = tokenizer_cls(article_content, return_tensors="pt", max_length=512, truncation=True)
72
  inputs_cls = {key: value.to(device) for key, value in inputs_cls.items()}
@@ -91,11 +117,13 @@ def main():
91
  print("No opportunity summary generated.")
92
  summary_opportunity = "No opportunity summary available" # Provide a default value or handle accordingly
93
 
94
- return classification, summary_risk, summary_opportunity
95
 
96
 
97
  print(url_input)
98
- cls_model =AutoModelForSequenceClassification.from_pretrained("/content/drive/MyDrive/riskclassification_finetuned_xlnet_model_ld")
 
 
99
  tokenizer_cls = AutoTokenizer.from_pretrained("xlnet-base-cased")
100
  label_encoder = LabelEncoder()
101
 
@@ -110,7 +138,7 @@ def main():
110
  print("Label encoder values")
111
 
112
  # Replace the original column with the encoded values
113
- label_encoder_path = "/content/drive/MyDrive/riskclassification_finetuned_xlnet_model_ld/encoder_labels.pkl"
114
  joblib.dump(label_encoder, label_encoder_path)
115
 
116
  model_summ = T5ForConditionalGeneration.from_pretrained("t5-small")
@@ -119,7 +147,7 @@ def main():
119
 
120
 
121
 
122
- classification, summary_risk, summary_opportunity = classify_and_summarize(url_input, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device)
123
 
124
  print("Classification:", classification)
125
  print("Risk Summary:", summary_risk)
@@ -127,10 +155,45 @@ def main():
127
 
128
 
129
  # Display the entered URL
130
- st.write("Entered URL:", url_input)
131
  st.write("Classification:",classification)
132
  st.write("Risk Summary:",summary_risk)
133
  st.write("Opportunity Summary:",summary_opportunity)
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  if __name__ == "__main__":
136
- main()
 
 
1
  import torch
2
  import numpy as np
3
  import pandas as pd
 
5
  from transformers import pipeline
6
  from transformers import T5Tokenizer, T5ForConditionalGeneration
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
9
  from newspaper import Article
10
  from sklearn.preprocessing import LabelEncoder
11
  import joblib
12
+ from datetime import datetime
13
 
14
 
15
  # Example usage:
 
25
  try:
26
  news_article = newspaper(url)
27
  print("scraped: ",news_article)
28
+ print("Attributes of the newspaper object:", dir(news_article))
29
+ # Print the methods of the newspaper object
30
+ print("Methods of the newspaper object:", [method for method in dir(news_article) if callable(getattr(news_article, method))])
31
+ # Try to print some specific attributes
32
+ print("Authors:", news_article.authors)
33
  return news_article.article
34
  except Exception as e:
35
  return "Error: " + str(e)
 
64
  return None,None
65
  else:
66
  st.write("This article is not classified as related to the supply chain.")
67
+
68
  def classify_and_summarize(input_text, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device):
69
  if input_text.startswith("http"):
70
  # If the input starts with "http", assume it's a URL and extract content
71
  article_content = scrape_news_content(input_text)
72
+ st.write("Entered URL:", url_input)
73
  else:
74
  # If the input is not a URL, assume it's the content
75
  article_content = input_text
76
 
77
+
78
+ # Get the number of lines in the text.
79
+ truncated_content = " ".join(article_content.split()[:150])
80
+ st.markdown(f"Truncated Content:\n{truncated_content}", unsafe_allow_html=True)
81
+
82
+ # Add a button to toggle between truncated and full content
83
+ if st.button("Read More"):
84
+ # Display the full content when the button is clicked
85
+ full_content = " ".join(article_content.split())
86
+ st.markdown(f"Full Content:\n{full_content}", unsafe_allow_html=True)
87
+ # Remove the truncated content when the full content is displayed
88
+ st.markdown(
89
+ """
90
+ <script>
91
+ document.getElementById("truncated-content").style.display = "none";
92
+ </script>
93
+ """,
94
+ unsafe_allow_html=True
95
+ )
96
  # Perform sentiment classification
97
  inputs_cls = tokenizer_cls(article_content, return_tensors="pt", max_length=512, truncation=True)
98
  inputs_cls = {key: value.to(device) for key, value in inputs_cls.items()}
 
117
  print("No opportunity summary generated.")
118
  summary_opportunity = "No opportunity summary available" # Provide a default value or handle accordingly
119
 
120
+ return classification, summary_risk, summary_opportunity,article_content
121
 
122
 
123
  print(url_input)
124
+ cls_model =AutoModelForSequenceClassification.from_pretrained("riskclassification_finetuned_xlnet_model_ld")
125
+ print(type(cls_model))
126
+
127
  tokenizer_cls = AutoTokenizer.from_pretrained("xlnet-base-cased")
128
  label_encoder = LabelEncoder()
129
 
 
138
  print("Label encoder values")
139
 
140
  # Replace the original column with the encoded values
141
+ label_encoder_path = "riskclassification_finetuned_xlnet_model_ld/encoder_labels.pkl"
142
  joblib.dump(label_encoder, label_encoder_path)
143
 
144
  model_summ = T5ForConditionalGeneration.from_pretrained("t5-small")
 
147
 
148
 
149
 
150
+ classification, summary_risk, summary_opportunity,article_content = classify_and_summarize(url_input, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device)
151
 
152
  print("Classification:", classification)
153
  print("Risk Summary:", summary_risk)
 
155
 
156
 
157
  # Display the entered URL
 
158
  st.write("Classification:",classification)
159
  st.write("Risk Summary:",summary_risk)
160
  st.write("Opportunity Summary:",summary_opportunity)
161
 
162
+
163
+ def process_question():
164
+ # Use session_state to persist variables across sessions
165
+ if 'qa_history' not in st.session_state:
166
+ st.session_state.qa_history = []
167
+
168
+ # Input box for user's question
169
+ user_question_key = st.session_state.question_counter if 'question_counter' in st.session_state else 0
170
+ user_question = st.text_input("Ask a question about the article content:", key=user_question_key)
171
+
172
+ # Check if "Send" button is clicked
173
+ send_button_key = f"send_button_{user_question_key}"
174
+ if st.button("Send", key=send_button_key) and user_question:
175
+ # Use a question-answering pipeline to generate a response
176
+ model_name = "deepset/tinyroberta-squad2"
177
+ nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)
178
+ QA_input = {'question': user_question, 'context': article_content}
179
+ res = nlp(QA_input)
180
+
181
+ # Display the user's question and the model's answer
182
+ st.write(f"You asked: {user_question}")
183
+ st.write("Model's Answer:", res["answer"])
184
+
185
+ # Add the question and answer to the history
186
+ st.session_state.qa_history.append((user_question, res["answer"]))
187
+
188
+ # Clear the input box
189
+
190
+ # Display the history
191
+ st.write("Question-Answer History:")
192
+ for q, a in st.session_state.qa_history:
193
+ st.write(f"Q: {q}")
194
+ st.write(f"A: {a}")
195
+
196
+ # Run the function to process questions
197
+ process_question()
198
  if __name__ == "__main__":
199
+ main()