RamAI123 committed on
Commit
1630a37
1 Parent(s): d003d76

Upload 6 files

Browse files
Files changed (4) hide show
  1. Dockerfile.txt +14 -0
  2. app.py +321 -0
  3. db.sql +0 -0
  4. requirements.txt +12 -0
Dockerfile.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# app.py defines a Flask (WSGI) application object named `app`; there is no
# `app.main` module and uvicorn is not listed in requirements.txt, so the
# server is started with the Flask CLI on the same host/port as before.
CMD ["flask", "--app", "app", "run", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify, render_template_string
2
+ from flask_cors import CORS
3
+ from newspaper import Article
4
+ from transformers import pipeline
5
+ import torch
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
7
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
8
+ from sklearn.preprocessing import LabelEncoder
9
+ import joblib
10
+ import mysql.connector
11
+ from flask import send_file
12
+ from reportlab.pdfgen import canvas
13
+ import io
14
+ from reportlab.lib.pagesizes import letter
15
+ from reportlab.lib import colors
16
+ from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, PageBreak, Paragraph
17
+ from nltk.tokenize import sent_tokenize
18
+ from reportlab.platypus import Spacer
19
+ from reportlab.platypus.flowables import KeepTogether
20
+ from reportlab.lib.styles import getSampleStyleSheet
21
+ import datetime
22
+
23
+
24
+
25
import os

app = Flask(__name__, template_folder='templates')
CORS(app)

# In-memory transcript of results, questions and answers. It is a module-level
# list, so every request handled by this process reads and extends the same one.
chat_history = []

# MySQL connection settings. Each value can be overridden with an environment
# variable; the fallbacks are the values that used to be hard-coded here.
mysql_config = {
    'host': os.environ.get('MYSQL_HOST', 'localhost'),
    'user': os.environ.get('MYSQL_USER', 'root'),
    # NOTE(review): this fallback password is committed to source control.
    # Set MYSQL_PASSWORD in the environment and rotate/remove this default.
    'password': os.environ.get('MYSQL_PASSWORD', '9553641651'),
    'database': os.environ.get('MYSQL_DATABASE', 'articles'),
}
35
+
36
def insert_question_and_answer(question, answer, timestamp):
    """Insert one (question, answer, timestamp) row into ``supplychain143``.

    This is a best-effort write: any failure is printed and swallowed, and
    the function always returns ``None``.

    Args:
        question: Value stored in the ``question`` column.
        answer: Value stored in the ``answer`` column.
        timestamp: Value stored in the ``timestamp`` column.
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**mysql_config)
        cursor = connection.cursor()

        # Parameterized query: values are passed separately from the SQL text.
        query = "INSERT INTO supplychain143 (question, answer, timestamp) VALUES (%s, %s, %s);"
        cursor.execute(query, (question, answer, timestamp))
        connection.commit()

        print("Record inserted successfully!")
    except Exception as e:
        print("Error inserting record:", str(e))
    finally:
        # Release the cursor and connection even when execute/commit failed;
        # previously they were only closed on the success path.
        for resource in (cursor, connection):
            if resource is not None:
                try:
                    resource.close()
                except Exception as e:
                    print("Error closing database resource:", str(e))
59
+
60
def retrieve_article_content(timestamp):
    """Fetch the (question, answer) rows stored under *timestamp*.

    Args:
        timestamp: Value matched against the ``timestamp`` column of
            ``supplychain143``.

    Returns:
        The rows returned by ``cursor.fetchall()``, or ``None`` when the
        lookup fails for any reason (the error is printed).
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**mysql_config)
        cursor = connection.cursor()

        query = "SELECT question, answer FROM supplychain143 WHERE timestamp = %s;"
        cursor.execute(query, (timestamp,))
        return cursor.fetchall()
    except Exception as e:
        print("Error retrieving article content:", str(e))
        return None
    finally:
        # Release the cursor and connection on the error path as well;
        # previously they leaked whenever the query raised.
        for resource in (cursor, connection):
            if resource is not None:
                try:
                    resource.close()
                except Exception as e:
                    print("Error closing database resource:", str(e))
83
+
84
def scrape_news_content(url):
    """Download and parse the article at *url* and return its body text.

    Args:
        url: Address of the article to fetch.

    Returns:
        The parsed article text. If downloading or parsing raises, the
        string ``"Error: <message>"`` is returned instead of raising, so
        callers always receive a string.
    """
    try:
        article = Article(url)
        article.download()
        article.parse()
        return article.text
    except Exception as e:
        return "Error: " + str(e)
98
+
99
+
100
def summarize_with_t5(article_content, classification, model, tokenizer, device):
    """Summarize an article according to its classification.

    Args:
        article_content: Article text; it is converted with ``str()``.
        classification: Label for the article. ``"risks"`` and
            ``"opportunities"`` select a summarization prompt; ``"neither"``
            skips summarization.
        model: Generation model exposing ``to(device)`` and ``generate(...)``.
        tokenizer: Tokenizer exposing ``encode(...)`` and ``decode(...)``.
        device: Device the model and input ids are moved to.

    Returns:
        Always a ``(risk_summary, opportunity_summary)`` pair:
        ``("", "")`` for empty or ``"nan"`` content, ``("None", "None")`` for
        ``"neither"``, ``(summary, "None")`` for ``"risks"``,
        ``("None", summary)`` for ``"opportunities"``, and ``(None, None)``
        for any other label.
    """
    article_content = str(article_content)
    if not article_content or article_content == "nan":
        return "", ""
    if classification == "neither":
        print("Nooo")
        return "None", "None"

    prompts = {
        "risks": "summarize the key supply chain risks: ",
        "opportunities": "summarize the key supply chain opportunities: ",
    }
    prompt = prompts.get(classification)
    if prompt is None:
        # Unknown label: the old code ran the model and then returned a single
        # string, which broke callers that unpack two values. Return a pair
        # and skip the generation, whose result was never used.
        return None, None

    input_ids = tokenizer.encode(prompt + article_content, return_tensors="pt").to(device)
    model = model.to(device)  # keep the model on the same device as the input ids
    summary_ids = model.generate(
        input_ids,
        max_length=150,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    print(summary)

    if classification == "risks":
        return summary, "None"
    return "None", summary
130
+
131
+
132
def classify_and_summarize(input_text, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device):
    """Classify and summarize every comma-separated entry of *input_text*.

    Entries starting with ``"http"`` are scraped with ``scrape_news_content``;
    any other entry is used as the article text itself. Each processed entry
    is stored with ``insert_question_and_answer`` under one timestamp shared
    by the whole request.

    Args:
        input_text: One or more URLs / texts separated by commas.
        cls_model: Sequence-classification model.
        tokenizer_cls: Tokenizer matching ``cls_model``.
        label_encoder: Encoder whose ``inverse_transform`` maps the predicted
            class index to a label.
        model_summ: Summarization model passed to ``summarize_with_t5``.
        tokenizer_summ: Tokenizer passed to ``summarize_with_t5``.
        device: Device used for inference.

    Returns:
        A list with one dict per processed entry, with keys ``"Question"``,
        ``"Article content"``, ``"Short Article content"``,
        ``"Classification"``, ``"Summary risk"`` and ``"Opportunity Summary"``.
    """
    # The original assigned this name without `global`, which only created a
    # dead local. Rows are stored under request_timestamp, so the module-level
    # value used for later lookups has to be exactly this one.
    global current_request_timestamp

    results = []
    request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    current_request_timestamp = request_timestamp

    # Loop-invariant: move the classifier to the device once.
    cls_model = cls_model.to(device)

    for raw_entry in input_text.split(","):
        # "http://a, http://b" leaves a leading space on the second entry,
        # which previously made it look like plain text rather than a URL.
        url = raw_entry.strip()
        if not url:
            continue

        if url.startswith("http"):
            article_content = scrape_news_content(url)
        else:
            article_content = url

        inputs_cls = tokenizer_cls(article_content, return_tensors="pt", max_length=512, truncation=True, padding=True)
        inputs_cls = {key: value.to(device) for key, value in inputs_cls.items()}

        # Inference only: no gradients are needed for argmax over the logits.
        with torch.no_grad():
            outputs_cls = cls_model(**inputs_cls)
        predicted_class = torch.argmax(outputs_cls.logits, dim=1).item()
        classification = label_encoder.inverse_transform([predicted_class])[0]

        summary_risk, summary_opportunity = summarize_with_t5(
            article_content, classification, model_summ, tokenizer_summ, device
        )
        if summary_risk is None:
            summary_risk = "No risk summary available"
        if summary_opportunity is None:
            summary_opportunity = "No opportunity summary available"

        # Preview limited to the first 200 whitespace-separated words.
        short_article_content = ' '.join(article_content.split()[:200])

        insert_question_and_answer(url, article_content, request_timestamp)
        results.append({
            "Question": url,
            "Article content": article_content,
            "Short Article content": short_article_content,
            "Classification": classification,
            "Summary risk": summary_risk,
            "Opportunity Summary": summary_opportunity,
        })

    print("Result", results)
    return results
172
+
173
def generate_sentence_from_keywords(keywords):
    """Join *keywords* with spaces and return the first tokenized sentence.

    Returns ``"Unable to generate a sentence."`` when ``sent_tokenize``
    yields no sentences for the joined text.
    """
    joined_keywords = ' '.join(keywords)
    tokenized = sent_tokenize(joined_keywords)
    if tokenized:
        return tokenized[0]
    return "Unable to generate a sentence."
182
+
183
def is_question(input_text):
    """Return True when *input_text* begins with an interrogative word.

    The check is a case-insensitive prefix match against "who", "what",
    "when", "where", "why" and "how", so a word such as "However" matches too.
    """
    interrogatives = ("who", "what", "when", "where", "why", "how")
    lowered = input_text.lower()
    return lowered.startswith(interrogatives)
186
+
187
+
188
def process_question(user_question, articlecontent):
    """Answer *user_question* from previously stored article rows.

    Args:
        user_question: The question text.
        articlecontent: Iterable of rows whose second item (index 1) is the
            stored answer/article text, or ``None`` when nothing was
            retrieved.

    Returns:
        The ``"answer"`` string produced by the question-answering pipeline,
        or a fixed message when there is no context text to search.
    """
    # None (failed lookup) or no rows: treat both as "no context".
    answers = [row[1] for row in (articlecontent or [])]
    context_string = ' '.join(map(str, answers))
    if not context_string.strip():
        return "No article content is available to answer this question."

    # Build the pipeline once and reuse it; it used to be reloaded on every call.
    nlp = getattr(process_question, "_qa_pipeline", None)
    if nlp is None:
        model_name = "deepset/tinyroberta-squad2"
        nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)
        process_question._qa_pipeline = nlp

    QA_input = {'question': user_question, 'context': context_string}
    print("Debug - QA_input:", QA_input)
    res = nlp(QA_input)
    print("Debug - res:", res)
    print(res['answer'])
    return res["answer"]
199
+
200
def _truncate_for_pdf(value, limit):
    """Return ``str(value)`` cut to *limit* characters and XML-escaped.

    An ellipsis is appended when the text was cut. Escaping keeps ``<``,
    ``>`` and ``&`` in the text from being read as Paragraph markup.
    """
    from xml.sax.saxutils import escape

    text = str(value)
    clipped = text[:limit] + ('...' if len(text) > limit else '')
    return escape(clipped)


def generate_pdf(chat_history):
    """Render *chat_history* as a PDF document and return its bytes.

    Args:
        chat_history: Iterable of messages. A dict is rendered as one
            ``key: value`` paragraph per item; anything else is rendered via
            ``str()``. Each value is cut to 100 characters.

    Returns:
        The bytes of the generated PDF.
    """
    buffer = io.BytesIO()
    pdf = SimpleDocTemplate(buffer, pagesize=letter)
    styles = getSampleStyleSheet()

    # Maximum characters kept from each value
    max_chars_per_line = 100

    pdf_content = []
    for message in chat_history:
        if isinstance(message, dict):
            for key, value in message.items():
                # Values are converted with str() first, so non-string values
                # no longer fail on slicing.
                label = _truncate_for_pdf(key, len(str(key)))
                formatted_value = _truncate_for_pdf(value, max_chars_per_line)
                pdf_content.append(Paragraph(f"<strong>{label}:</strong> {formatted_value}", styles['Normal']))
        else:
            formatted_message = _truncate_for_pdf(message, max_chars_per_line)
            pdf_content.append(Paragraph(formatted_message, styles['Normal']))
        pdf_content.append(Spacer(1, 10))  # Add space between messages

    pdf.build(pdf_content)
    return buffer.getvalue()
235
+
236
@app.route('/download_pdf', methods=['GET'])
def download_pdf():
    """Serve the current chat history as a downloadable PDF attachment."""
    pdf_bytes = generate_pdf(chat_history)
    pdf_stream = io.BytesIO(pdf_bytes)
    return send_file(
        pdf_stream,
        as_attachment=True,
        download_name='chat_history.pdf',
        mimetype='application/pdf',
    )
248
+
249
# Timestamp string ('%Y-%m-%d %H:%M:%S') of the most recent URL submission.
# home() uses it to look up the stored article rows when a question is asked.
current_request_timestamp = None

# Filled on first use by _get_models() so the models are loaded only once.
_loaded_models = {}


def _get_models():
    """Load the classifier, summarizer, tokenizers and label encoder once."""
    if not _loaded_models:
        cls_model = AutoModelForSequenceClassification.from_pretrained("riskclassification_finetuned_xlnet_model_ld")
        tokenizer_cls = AutoTokenizer.from_pretrained("xlnet-base-cased")

        # The encoder is fitted on the three labels and written next to the
        # classifier, as the per-request code did before.
        label_encoder = LabelEncoder()
        label_encoder.fit(["risks", "opportunities", "neither"])
        joblib.dump(label_encoder, "riskclassification_finetuned_xlnet_model_ld/encoder_labels.pkl")

        model_summ = T5ForConditionalGeneration.from_pretrained("t5-small")
        tokenizer_summ = T5Tokenizer.from_pretrained("t5-small")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Populate the cache only after every load succeeded, so a failed
        # load is retried on the next request.
        _loaded_models.update(
            cls_model=cls_model,
            tokenizer_cls=tokenizer_cls,
            label_encoder=label_encoder,
            model_summ=model_summ,
            tokenizer_summ=tokenizer_summ,
            device=device,
        )
    return _loaded_models


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the main page and handle submitted URLs or questions.

    On POST, the ``userInput`` form field is read. Input starting with
    ``"http"`` is classified and summarized and the results are appended to
    ``chat_history``. Input that ``is_question`` accepts is answered from the
    rows stored for the most recent URL submission. Any other input only
    re-renders the page.

    Returns:
        The rendered ``index.html`` template.
    """
    global current_request_timestamp
    classification = None
    summary_risk = None
    summary_opportunity = None
    article_content = None
    input_submitted = False

    if request.method == 'POST':
        url_input = request.form['userInput']
        print("Form Data:", request.form)
        input_submitted = True
        print(url_input)

        if url_input.startswith("http"):
            # Models are needed only on this branch and are cached after the
            # first load, rather than being reloaded on every POST.
            models = _get_models()
            current_request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            totalresult = classify_and_summarize(
                url_input,
                models["cls_model"],
                models["tokenizer_cls"],
                models["label_encoder"],
                models["model_summ"],
                models["tokenizer_summ"],
                models["device"],
            )
            chat_history.extend(totalresult)
        elif is_question(url_input):
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

            # Without a previous URL submission there is nothing to look up;
            # the old code reached process_question with an unbound variable.
            articlecontent = None
            if current_request_timestamp:
                articlecontent = retrieve_article_content(current_request_timestamp)

            if articlecontent:
                answer = process_question(url_input, articlecontent)
                insert_question_and_answer(url_input, answer, timestamp)
            else:
                answer = "No article content is available yet. Please submit an article URL first."

            chat_history.append({"User Question": url_input})
            chat_history.append({"Model Answer": answer})

    print("chat history", chat_history)
    return render_template(
        'index.html',
        chat_history=chat_history,
        classification=classification,
        summary_risk=summary_risk,
        summary_opportunity=summary_opportunity,
        article_content=article_content,
        input_submitted=input_submitted,
    )
319
+
320
if __name__ == '__main__':
    import os

    # Debug mode enables the interactive Werkzeug debugger, which lets anyone
    # who can reach the server run arbitrary code. It is therefore opt-in:
    # set FLASK_DEBUG=1 for local development.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')
db.sql ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask==3.0.2
2
+ flask_cors==4.0.0
3
+ torch==2.2.1
4
+ news-fetch==0.2.8
5
+ transformers==4.37.2
6
+ newspaper3k==0.2.8
7
+ nltk==3.8.1
8
+ reportlab==4.1.0
9
+ scikit-learn==1.4.1
10
+ joblib==1.3.2
11
+ sentencepiece==0.2.0
12
+ mysql-connector-python==8.3.0