from flask import Flask, render_template, request, jsonify, render_template_string
from flask_cors import CORS
from newspaper import Article
from transformers import (
    pipeline,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
import torch
from sklearn.preprocessing import LabelEncoder
import joblib
import mysql.connector
from flask import send_file
from reportlab.pdfgen import canvas
import io
from reportlab.lib.pagesizes import letter
from reportlab.lib import colors
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, PageBreak, Paragraph, Spacer, KeepTogether
from nltk.tokenize import sent_tokenize
import datetime
from xml.sax.saxutils import escape  # used to sanitize text passed to reportlab Paragraph



app = Flask(__name__, template_folder='templates')
CORS(app)
chat_history = []
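
# Placeholder MySQL connection settings. These values are assumptions for
# illustration only (the original code passed 'db.sql' to
# mysql.connector.connect(), which is not a valid call); adjust them to match
# your own database before running.
DB_CONFIG = {
    "host": "localhost",
    "user": "root",
    "password": "",
    "database": "supplychain",
}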


cls_model = AutoModelForSequenceClassification.from_pretrained("riskclassification_finetuned_xlnet_model_ld")
tokenizer_cls = AutoTokenizer.from_pretrained("xlnet-base-cased")
label_encoder_path = "riskclassification_finetuned_xlnet_model_ld/encoder_labels.pkl"
label_encoder = LabelEncoder()

# Fit the label encoder on the three classification labels used by the model
# and persist it alongside the fine-tuned weights.
label_column_values = ["risks", "opportunities", "neither"]

label_encoder.fit(label_column_values)

joblib.dump(label_encoder, label_encoder_path)


model_summ = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer_summ = T5Tokenizer.from_pretrained("t5-small")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
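# Device used for the classification and summarization models; they are moved
# onto it lazily when they are first used below.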


model_name = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)
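
# At this point three models are loaded: the fine-tuned XLNet classifier that
# labels an article as "risks", "opportunities", or "neither"; the T5-small
# summarizer that produces the matching summary; and the tinyroberta-squad2
# question-answering pipeline used for follow-up questions about stored articles.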

def insert_question_and_answer(question, answer, timestamp):
    try:
        # Connect to the MySQL database (placeholder settings in DB_CONFIG above)
        connection = mysql.connector.connect(**DB_CONFIG)
        cursor = connection.cursor()

        # SQL query to insert a new record into the 'supplychain143' table
        query = "INSERT INTO supplychain143 (question, answer, timestamp) VALUES (%s, %s, %s);"
        values = (question, answer, timestamp)

        # Execute the query
        cursor.execute(query, values)

        # Commit the changes
        connection.commit()

        # Close the cursor and connection
        cursor.close()
        connection.close()

        print("Record inserted successfully!")
    except Exception as e:
        print("Error inserting record:", str(e))

def retrieve_article_content(timestamp):
    try:
        # Connect to the MySQL database (placeholder settings in DB_CONFIG above)
        connection = mysql.connector.connect(**DB_CONFIG)
        cursor = connection.cursor()

        # SQL query to retrieve article content based on the question
        query = "SELECT question, answer FROM supplychain143 WHERE timestamp = %s;"
        values = (timestamp,)

        # Execute the query
        cursor.execute(query, values)

        # Fetch the results
        results = cursor.fetchall()
        # Close the cursor and connection
        cursor.close()
        connection.close()

        return results
    except Exception as e:
        print("Error retrieving article content:", str(e))
        return None

def scrape_news_content(url):
    try:
        article = Article(url)
        article.download()
        article.parse()

        # Return the article body with leading/trailing whitespace removed
        return article.text.strip()
    except Exception as e:
        return "Error: " + str(e)


def summarize_with_t5(article_content, classification, model, tokenizer, device):
    article_content = str(article_content)
    if not article_content or article_content == "nan":
        return "", ""

    # Only articles classified as risks or opportunities are summarized
    if classification == "risks":
        prompt = "summarize the key supply chain risks: "
    elif classification == "opportunities":
        prompt = "summarize the key supply chain opportunities: "
    else:
        return "None", "None"

    input_text = prompt + article_content
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    model = model.to(device)  # Move the model to the correct device
    summary_ids = model.generate(input_ids, max_length=150, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    if classification == "risks":
        return summary, "None"
    return "None", summary


def classify_and_summarize(input_text, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device):
    results = []
    request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # The input may be a comma-separated list of URLs or raw article text
    input_text = input_text.split(",")
    for url in input_text:
        if url.startswith("http"):
            # If the input starts with "http", assume it's a URL and extract content
            article_content = scrape_news_content(url)
        else:
            # If the input is not a URL, assume it's the content
            article_content = url

        # Perform sentiment classification
        inputs_cls = tokenizer_cls(article_content, return_tensors="pt", max_length=512, truncation=True, padding=True)
        inputs_cls = {key: value.to(device) for key, value in inputs_cls.items()}

        # Move cls_model to the specified device
        cls_model = cls_model.to(device)

        outputs_cls = cls_model(**inputs_cls)
        logits_cls = outputs_cls.logits
        predicted_class = torch.argmax(logits_cls, dim=1).item()
        classification = label_encoder.inverse_transform([predicted_class])[0]

        # Perform summarization based on the classification
        summary_risk, summary_opportunity = summarize_with_t5(article_content, classification, model_summ, tokenizer_summ, device)

        if summary_risk is None:
            summary_risk = "No risk summary available"
        if summary_opportunity is None:
            summary_opportunity = "No opportunity summary available"
        answer = article_content
        # Keep a 200-word preview of the article for display
        article_content_words = article_content.split()[:200]
        short_article_content = ' '.join(article_content_words)
        insert_question_and_answer(url, answer, request_timestamp)
        results.append({
            "Question": url,
            "Article content": article_content,
            "Short Article content": short_article_content,
            "Classification": classification,
            "Summary risk": summary_risk,
            "Opportunity Summary": summary_opportunity,
        })
    return results

def generate_sentence_from_keywords(keywords):
    # Concatenate keywords into a single string
    keyword_sentence = ' '.join(keywords)

    # Tokenize the concatenated keywords into sentences
    sentences = sent_tokenize(keyword_sentence)

    # If there are sentences, return the first one; otherwise, return a default message
    return sentences[0] if sentences else "Unable to generate a sentence."

def is_question(input_text):
    questioning_words = ["who", "what", "when", "where", "why", "how"]
    return any(input_text.lower().startswith(q) for q in questioning_words)


def process_question(user_question, articlecontent):
    # Build the QA context from the stored (question, answer) rows
    answers = [item[1] for item in articlecontent]
    context_string = ' '.join(map(str, answers))
    QA_input = {'question': user_question, 'context': context_string}
    print("Debug - QA_input:", QA_input)
    res = nlp(QA_input)
    print("Debug - res:", res)
    print(res['answer'])
    return res["answer"]

def generate_pdf(chat_history):
    # Create a PDF document using ReportLab
    buffer = io.BytesIO()

    # Adjust the page size and margins as needed
    pdf = SimpleDocTemplate(buffer, pagesize=letter)

    # List to store the content for the PDF
    pdf_content = []

    # Get sample styles for formatting
    styles = getSampleStyleSheet()

    # Maximum characters shown per entry (longer text is truncated with '...')
    max_chars_per_line = 100

    # Write chat history to the PDF. Text is escaped because Paragraph parses
    # its argument as mini-HTML and would otherwise fail on characters such as
    # '&' or '<' that can appear in scraped article text.
    for message in chat_history:
        if isinstance(message, dict):
            for key, value in message.items():
                value = str(value)
                formatted_value = value[:max_chars_per_line] + ('...' if len(value) > max_chars_per_line else '')
                pdf_content.append(Paragraph(f"<strong>{escape(key)}:</strong> {escape(formatted_value)}", styles['Normal']))
        else:
            message = str(message)
            formatted_message = message[:max_chars_per_line] + ('...' if len(message) > max_chars_per_line else '')
            pdf_content.append(Paragraph(escape(formatted_message), styles['Normal']))
        pdf_content.append(Spacer(1, 10))  # Add space between messages

    # Build PDF document
    pdf.build(pdf_content)

    buffer.seek(0)
    return buffer.getvalue()

@app.route('/download_pdf', methods=['GET'])
def download_pdf():
    # Generate a PDF document based on chat history
    pdf_buffer = generate_pdf(chat_history)
    
    # Provide the PDF as a download
    return send_file(
        io.BytesIO(pdf_buffer),
        as_attachment=True,
        download_name='chat_history.pdf',
        mimetype='application/pdf'
    )

# Timestamp of the most recent URL submission; follow-up questions are answered
# against the article content stored under this timestamp.
current_request_timestamp = None

@app.route('/', methods=['GET', 'POST'])
def home():
    global current_request_timestamp
    classification = None
    summary_risk = None
    summary_opportunity = None
    article_content = None 
    input_submitted = False

    if request.method == 'POST':
        url_input = request.form['userInput']
        print("Form Data:", request.form)
        input_submitted = True
        print(url_input)

        if url_input.startswith("http"):
            current_request_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            # If the input starts with "http", assume it's a URL and extract content
            totalresult = classify_and_summarize(
                url_input, cls_model, tokenizer_cls, label_encoder, model_summ, tokenizer_summ, device
            )
            chat_history.extend(totalresult)
            '''first={"Classification":classification}
            second={"Summary risk":summary_risk}
            opp={"Opportunity Summary":summary_opportunity}
            third={"Article content":article_content}
            chat_history.extend([{"Question":url_input}])
            chat_history.extend([first])
            chat_history.extend([second])
            chat_history.extend([opp])
            chat_history.extend([third])
            chat_history.extend([{"Short Article content":short_article_content}]) ''' # Display only the first 200 words
            '''return render_template('index.html', classification=classification, summary_risk=summary_risk,
                               summary_opportunity=summary_opportunity, article_content=article_content,
                               input_submitted=input_submitted, chat_history=chat_history)'''
        elif is_question(url_input):
            # If the input starts with a questioning word, answer it against the
            # article content stored for the most recent URL submission
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            if current_request_timestamp is not None:
                articlecontent = retrieve_article_content(current_request_timestamp)
                answer = process_question(url_input, articlecontent)
                insert_question_and_answer(url_input, answer, timestamp)
                chat_history.extend([{"User Question": url_input}, {"Model Answer": answer}])
    print("chat history",chat_history)
    return render_template('index.html', chat_history=chat_history,classification=classification, summary_risk=summary_risk, summary_opportunity=summary_opportunity, article_content=article_content, input_submitted=input_submitted)

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)
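
# Minimal sketch of exercising the running app from another process. This is
# illustrative only: it assumes the `requests` package is available, and the
# article URL and question below are placeholders.
#
#   import requests
#   base = "http://localhost:7860"
#   requests.post(base + "/", data={"userInput": "https://example.com/supply-chain-article"})
#   requests.post(base + "/", data={"userInput": "what are the main risks?"})
#   open("chat_history.pdf", "wb").write(requests.get(base + "/download_pdf").content)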