Upload 6 files
Browse files- Dockerfile.txt +14 -0
- app.py +321 -0
- db.sql +0 -0
- requirements.txt +12 -0
Dockerfile.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# The application object is the Flask (WSGI) instance `app` defined in app.py,
# and uvicorn is not listed in requirements.txt, so serve it with Flask's own
# CLI on the port this image is expected to listen on.
CMD ["flask", "--app", "app", "run", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify, render_template_string
|
2 |
+
from flask_cors import CORS
|
3 |
+
from newspaper import Article
|
4 |
+
from transformers import pipeline
|
5 |
+
import torch
|
6 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
|
7 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
8 |
+
from sklearn.preprocessing import LabelEncoder
|
9 |
+
import joblib
|
10 |
+
import mysql.connector
|
11 |
+
from flask import send_file
|
12 |
+
from reportlab.pdfgen import canvas
|
13 |
+
import io
|
14 |
+
from reportlab.lib.pagesizes import letter
|
15 |
+
from reportlab.lib import colors
|
16 |
+
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, PageBreak, Paragraph
|
17 |
+
from nltk.tokenize import sent_tokenize
|
18 |
+
from reportlab.platypus import Spacer
|
19 |
+
from reportlab.platypus.flowables import KeepTogether
|
20 |
+
from reportlab.lib.styles import getSampleStyleSheet
|
21 |
+
import datetime
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
import os

app = Flask(__name__, template_folder='templates')
CORS(app)

# In-memory transcript rendered by the index template and exported as a PDF.
# It is a module-level list, so every request handled by this process shares it.
chat_history = []

# MySQL connection settings. Each value can be overridden through the
# environment so credentials do not have to live in source control; the
# defaults are the values this file originally hard-coded.
mysql_config = {
    'host': os.environ.get('MYSQL_HOST', 'localhost'),
    'user': os.environ.get('MYSQL_USER', 'root'),
    # NOTE(review): this fallback password is committed to the repository.
    # Rotate it and set MYSQL_PASSWORD in the deployment instead.
    'password': os.environ.get('MYSQL_PASSWORD', '9553641651'),
    'database': os.environ.get('MYSQL_DATABASE', 'articles'),
}
|
35 |
+
|
36 |
+
def insert_question_and_answer(question, answer, timestamp):
    """Insert one (question, answer, timestamp) row into ``supplychain143``.

    The operation is best-effort: any failure is printed and swallowed, and
    the function returns ``None`` either way.

    Args:
        question: Value stored in the ``question`` column.
        answer: Value stored in the ``answer`` column.
        timestamp: Value stored in the ``timestamp`` column.
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**mysql_config)
        cursor = connection.cursor()

        # Parameterized query: the driver handles quoting of the values.
        query = "INSERT INTO supplychain143 (question, answer, timestamp) VALUES (%s, %s, %s);"
        cursor.execute(query, (question, answer, timestamp))
        connection.commit()

        print("Record inserted successfully!")
    except Exception as e:
        print("Error inserting record:", str(e))
    finally:
        # Release the handles even when execute/commit failed; a failing
        # close() is reported rather than raised to keep this best-effort.
        for handle in (cursor, connection):
            if handle is not None:
                try:
                    handle.close()
                except Exception as close_error:
                    print("Error closing database handle:", str(close_error))
|
59 |
+
|
60 |
+
def retrieve_article_content(timestamp):
    """Fetch the (question, answer) rows stored under ``timestamp``.

    Args:
        timestamp: Value matched against the ``timestamp`` column of
            ``supplychain143``.

    Returns:
        The list of ``(question, answer)`` rows returned by the cursor, or
        ``None`` when connecting or querying fails (the error is printed).
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**mysql_config)
        cursor = connection.cursor()

        query = "SELECT question, answer FROM supplychain143 WHERE timestamp = %s;"
        cursor.execute(query, (timestamp,))
        return cursor.fetchall()
    except Exception as e:
        print("Error retrieving article content:", str(e))
        return None
    finally:
        # Release the handles even when the query failed; a failing close()
        # is reported rather than raised so the computed result still returns.
        for handle in (cursor, connection):
            if handle is not None:
                try:
                    handle.close()
                except Exception as close_error:
                    print("Error closing database handle:", str(close_error))
|
83 |
+
|
84 |
+
def scrape_news_content(url):
    """Download and parse the article at ``url`` and return its body text.

    Args:
        url: Address handed to ``newspaper.Article``.

    Returns:
        The parsed article text. On any failure the function does not raise;
        it returns the string ``"Error: "`` followed by the exception message,
        so callers receive a string in both cases.
    """
    try:
        article = Article(url)
        article.download()
        article.parse()
        return article.text
    except Exception as e:
        return "Error: " + str(e)
|
98 |
+
|
99 |
+
|
100 |
+
def summarize_with_t5(article_content, classification, model, tokenizer, device, max_input_tokens=512):
    """Summarize an article with a prompt chosen by its classification.

    Args:
        article_content: Article text; it is converted with ``str()`` first.
        classification: ``"risks"``, ``"opportunities"`` or ``"neither"``.
        model: Seq2seq model exposing ``to()`` and ``generate()``.
        tokenizer: Tokenizer exposing ``encode()`` and ``decode()``.
        device: Device the model and input ids are moved to.
        max_input_tokens: Upper bound on the encoded prompt length; longer
            inputs are truncated, matching the limit the classifier call in
            this file uses.

    Returns:
        Always a ``(risk_summary, opportunity_summary)`` pair:
        ``("", "")`` for empty or ``"nan"`` content, ``("None", "None")`` for
        ``"neither"``, ``(summary, "None")`` for risks, ``("None", summary)``
        for opportunities, and ``(None, None)`` for any other classification.
    """
    article_content = str(article_content)
    if not article_content or article_content == "nan":
        return "", ""

    if classification == "risks":
        prompt = "summarize the key supply chain risks: "
    elif classification == "opportunities":
        prompt = "summarize the key supply chain opportunities: "
    elif classification == "neither":
        print("Nooo")
        return "None", "None"
    else:
        # Unknown label: callers unpack two values, so return a pair (not a
        # message string) and skip the expensive generation step entirely.
        return None, None

    input_ids = tokenizer.encode(
        prompt + article_content,
        return_tensors="pt",
        max_length=max_input_tokens,
        truncation=True,
    ).to(device)

    model = model.to(device)  # Move the model to the same device as the inputs
    summary_ids = model.generate(input_ids, max_length=150, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    print(summary)

    if classification == "risks":
        return summary, "None"
    return "None", summary
|
130 |
+
|
131 |
+
|
132 |
+
def classify_and_summarize(input_text, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device):
    """Classify and summarize every comma-separated item in ``input_text``.

    Each item is stripped of surrounding whitespace; empty items are skipped.
    An item starting with ``"http"`` is scraped with ``scrape_news_content``,
    anything else is used as the article text itself. Every processed item is
    stored through ``insert_question_and_answer`` under one shared timestamp,
    which is also published in the module-level ``current_request_timestamp``.

    NOTE(review): plain-text input that itself contains commas is split into
    several items by this function — confirm that is intended.

    Args:
        input_text: Comma-separated URLs and/or article texts.
        cls_model: Sequence-classification model whose output has ``logits``.
        tokenizer_cls: Tokenizer for ``cls_model``.
        label_encoder: Object whose ``inverse_transform`` maps a class index
            to its label.
        model_summ: Summarization model passed to ``summarize_with_t5``.
        tokenizer_summ: Tokenizer passed to ``summarize_with_t5``.
        device: Device the models and inputs are moved to.

    Returns:
        A list with one dict per processed item, keyed by ``"Question"``,
        ``"Article content"``, ``"Short Article content"``,
        ``"Classification"``, ``"Summary risk"`` and ``"Opportunity Summary"``.
    """
    global current_request_timestamp

    results = []
    request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    entries = [part.strip() for part in input_text.split(",")]
    entries = [part for part in entries if part]
    if not entries:
        return results

    # Publish the exact timestamp the rows are stored under so that a later
    # lookup by current_request_timestamp finds them.
    current_request_timestamp = request_timestamp

    # Loop-invariant: move the classifier to the target device once.
    cls_model = cls_model.to(device)

    for url in entries:
        if url.startswith("http"):
            # Input that starts with "http" is treated as a URL to scrape.
            article_content = scrape_news_content(url)
        else:
            # Anything else is treated as the article text itself.
            article_content = url

        inputs_cls = tokenizer_cls(article_content, return_tensors="pt", max_length=512, truncation=True, padding=True)
        inputs_cls = {key: value.to(device) for key, value in inputs_cls.items()}

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits_cls = cls_model(**inputs_cls).logits
        predicted_class = torch.argmax(logits_cls, dim=1).item()
        classification = label_encoder.inverse_transform([predicted_class])[0]

        summary_risk, summary_opportunity = summarize_with_t5(
            article_content, classification, model_summ, tokenizer_summ, device
        )
        if summary_risk is None:
            summary_risk = "No risk summary available"
        if summary_opportunity is None:
            summary_opportunity = "No opportunity summary available"

        # Preview limited to the first 200 whitespace-separated words.
        short_article_content = ' '.join(article_content.split()[:200])

        insert_question_and_answer(url, article_content, request_timestamp)
        results.append({"Question": url, "Article content":article_content,"Short Article content":short_article_content,"Classification": classification, "Summary risk": summary_risk, "Opportunity Summary": summary_opportunity})

    print("Result",results)
    return results
|
172 |
+
|
173 |
+
def generate_sentence_from_keywords(keywords):
    """Return the first sentence found in the space-joined ``keywords``.

    The keywords are joined with single spaces and split into sentences with
    ``sent_tokenize``; when no sentence is produced, a fixed fallback message
    is returned instead.
    """
    detected = sent_tokenize(' '.join(keywords))
    if not detected:
        return "Unable to generate a sentence."
    return detected[0]
|
182 |
+
|
183 |
+
def is_question(input_text):
    """Return True when ``input_text`` begins with a question word.

    Leading whitespace and letter case are ignored. The question word must be
    a whole word: it may be followed by anything except another letter, so
    "what's" and "How?" match while "However" and "whatever" do not.
    """
    questioning_words = ("who", "whom", "whose", "what", "when", "where", "which", "why", "how")
    text = input_text.strip().lower()
    return any(
        # The one-character slice is "" at end of input, and "".isalpha() is False.
        text.startswith(word) and not text[len(word):len(word) + 1].isalpha()
        for word in questioning_words
    )
|
186 |
+
|
187 |
+
|
188 |
+
_QA_MODEL_NAME = "deepset/tinyroberta-squad2"

# Pipelines already built in this process, keyed by model name.
_QA_PIPELINES = {}


def _get_qa_pipeline(model_name=_QA_MODEL_NAME):
    """Return the question-answering pipeline for ``model_name``, building it on first use."""
    if model_name not in _QA_PIPELINES:
        _QA_PIPELINES[model_name] = pipeline('question-answering', model=model_name, tokenizer=model_name)
    return _QA_PIPELINES[model_name]


def process_question(user_question, articlecontent):
    """Answer ``user_question`` from previously stored article rows.

    Args:
        user_question: The question text.
        articlecontent: Iterable of rows whose second element (index 1) is the
            stored text; the texts are joined with spaces to form the context.
            ``None`` is accepted and treated as "no rows".

    Returns:
        The ``"answer"`` field of the pipeline result, or a fixed explanatory
        message when there is no non-blank context to search.
    """
    answers = [item[1] for item in (articlecontent or [])]
    context_string = ' '.join(map(str, answers))
    if not context_string.strip():
        # Without context there is nothing to extract an answer from, so skip
        # the model call instead of failing inside it.
        return "No article content is available to answer this question."

    nlp = _get_qa_pipeline()
    QA_input = {'question': user_question, 'context': context_string}
    print("Debug - QA_input:", QA_input)
    res = nlp(QA_input)
    print("Debug - res:", res)
    return res["answer"]
|
199 |
+
|
200 |
+
def _format_pdf_text(value, max_chars):
    """Stringify ``value``, cut it to ``max_chars`` (appending '...'), and XML-escape it."""
    from xml.sax.saxutils import escape

    text = str(value)
    if len(text) > max_chars:
        text = text[:max_chars] + '...'
    # Paragraph text is parsed as markup, so &, < and > must be escaped.
    return escape(text)


def generate_pdf(chat_history):
    """Render ``chat_history`` into a letter-size PDF and return its bytes.

    Args:
        chat_history: Iterable of messages. A dict message produces one
            paragraph per ``key: value`` pair with the key in bold; any other
            message produces a single paragraph of its string form. Every
            text is truncated to 100 characters and followed by '...' when
            it was longer.

    Returns:
        The PDF document as ``bytes``.
    """
    buffer = io.BytesIO()
    pdf = SimpleDocTemplate(buffer, pagesize=letter)
    normal_style = getSampleStyleSheet()['Normal']

    # Maximum characters kept from each text before it is cut with '...'.
    max_chars_per_line = 100

    pdf_content = []
    for message in chat_history:
        if isinstance(message, dict):
            for key, value in message.items():
                label = _format_pdf_text(key, max_chars_per_line)
                formatted_value = _format_pdf_text(value, max_chars_per_line)
                pdf_content.append(Paragraph(f"<strong>{label}:</strong> {formatted_value}", normal_style))
        else:
            pdf_content.append(Paragraph(_format_pdf_text(message, max_chars_per_line), normal_style))
        pdf_content.append(Spacer(1, 10))  # Add space between messages

    pdf.build(pdf_content)
    return buffer.getvalue()
|
235 |
+
|
236 |
+
@app.route('/download_pdf', methods=['GET'])
def download_pdf():
    """Serve the current chat history as a downloadable ``chat_history.pdf``."""
    pdf_bytes = generate_pdf(chat_history)
    attachment = io.BytesIO(pdf_bytes)
    return send_file(
        attachment,
        mimetype='application/pdf',
        as_attachment=True,
        download_name='chat_history.pdf',
    )
|
248 |
+
|
249 |
+
# Timestamp of the most recent article request; questions are answered from
# the rows stored under it. None until the first URL has been submitted.
current_request_timestamp = None

# Models and helpers loaded on first use and reused by later requests.
_MODEL_CACHE = {}


def _load_models():
    """Load the classifier, summarizer, tokenizers, label encoder and device once and return them."""
    if not _MODEL_CACHE:
        cls_model = AutoModelForSequenceClassification.from_pretrained("riskclassification_finetuned_xlnet_model_ld")
        tokenizer_cls = AutoTokenizer.from_pretrained("xlnet-base-cased")

        # The encoder is fitted on the three known labels and written next to
        # the model, as the original request handler did on every POST.
        label_encoder = LabelEncoder()
        label_encoder.fit(["risks", "opportunities", "neither"])
        joblib.dump(label_encoder, "riskclassification_finetuned_xlnet_model_ld/encoder_labels.pkl")

        model_summ = T5ForConditionalGeneration.from_pretrained("t5-small")
        tokenizer_summ = T5Tokenizer.from_pretrained("t5-small")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Filled in one step so a failed load never leaves a partial cache.
        _MODEL_CACHE.update(
            cls_model=cls_model,
            tokenizer_cls=tokenizer_cls,
            label_encoder=label_encoder,
            model_summ=model_summ,
            tokenizer_summ=tokenizer_summ,
            device=device,
        )
    return _MODEL_CACHE


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the chat page and handle a submitted URL list or question.

    On POST the ``userInput`` form field is read and stripped. Input starting
    with ``"http"`` is passed to ``classify_and_summarize`` and its results
    are appended to ``chat_history``. Input that ``is_question`` accepts is
    answered from the rows stored under ``current_request_timestamp``; when no
    such rows exist a fixed explanatory answer is shown instead. Any other
    input only re-renders the page.

    Returns:
        The rendered ``index.html`` template.
    """
    global current_request_timestamp
    # These are passed to the template unchanged; this handler never sets them.
    classification = None
    summary_risk = None
    summary_opportunity = None
    article_content = None
    input_submitted = False

    if request.method == 'POST':
        url_input = request.form['userInput'].strip()
        print("Form Data:", request.form)
        input_submitted = True
        print(url_input)

        if url_input.startswith("http"):
            # Only the article branch needs the models, so questions do not
            # pay for loading them.
            models = _load_models()
            current_request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            totalresult = classify_and_summarize(
                url_input,
                models["cls_model"],
                models["tokenizer_cls"],
                models["label_encoder"],
                models["model_summ"],
                models["tokenizer_summ"],
                models["device"],
            )
            chat_history.extend(totalresult)
        elif is_question(url_input):
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            articlecontent = None
            if current_request_timestamp is not None:
                articlecontent = retrieve_article_content(current_request_timestamp)

            if articlecontent:
                answer = process_question(url_input, articlecontent)
                insert_question_and_answer(url_input, answer, timestamp)
            else:
                # No article submitted yet, or the lookup failed or was empty:
                # answer explicitly instead of leaving `answer` unbound.
                answer = "No article content is available yet. Please submit an article URL first."

            chat_history.append({"User Question": url_input})
            chat_history.append({"Model Answer": answer})

    print("chat history",chat_history)
    return render_template('index.html', chat_history=chat_history,classification=classification, summary_risk=summary_risk, summary_opportunity=summary_opportunity, article_content=article_content, input_submitted=input_submitted)
|
319 |
+
|
320 |
+
if __name__ == '__main__':
    import os

    # The Werkzeug debugger lets anyone who can reach the server execute
    # code, so debug mode is opt-in (FLASK_DEBUG=1) instead of always on.
    app.run(debug=os.environ.get('FLASK_DEBUG', '0') == '1')
|
db.sql
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Flask==3.0.2
flask_cors==4.0.0
torch==2.2.1
news-fetch==0.2.8
transformers==4.37.2
newspaper3k==0.2.8
nltk==3.8.1
reportlab==4.1.0
# NOTE(review): PyPI appears to publish this release only as 1.4.1.post1
# (an exact "==1.4.1" pin would then fail to resolve) - confirm.
scikit-learn==1.4.1.post1
joblib==1.3.2
sentencepiece==0.2.0
mysql-connector-python==8.3.0
|