Lauraayu committed (verified)
Commit f0e521d · Parent(s): 5c1719d

Update app.py

Files changed (1): app.py (+40 -17)
app.py CHANGED
@@ -1,22 +1,45 @@
- import streamlit as st
- from transformers import pipeline

- # Load the summariztion model pipeline
- summarizer_ntg = pipeline("text2text-generation", model="mrm8488/t5-base-finetuned-summarize-news")
- classifier = pipeline("text-classification", model='Lauraayu/News_Classi_Model', return_all_scores=True)

- # Streamlit application title
- st.title("News Classification")
- st.write("Classification for different News types")

- # Text input for user to enter the text to classify
- text = st.text_area("Enter the News to classify","")
-
- # Perform text classification when the user clicks the "Classify" button
- if st.button("Classify"):

-     # Perform text classification on the input text
-     result0 = summarizer_ntg(text)
-     result = classifier(result0)

-     st.write(result)
 
+ import torch
+ import gradio as gr
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

+ # Step 1: Define Summarization Pipeline
+ summarizer_ntg = pipeline("summarization", model="mrm8488/t5-base-finetuned-summarize-news")

+ # Step 2: Define Classification Pipeline
+ tokenizer_bb = AutoTokenizer.from_pretrained("Lauraayu/News_Classi_Model")
+ model_bb = AutoModelForSequenceClassification.from_pretrained("Lauraayu/News_Classi_Model")

+ def summarize_and_classify(text):
+     # Summarize the article
+     summary = summarizer_ntg(text)[0]['summary_text']
+
+     # Tokenize the summarized text
+     inputs = tokenizer_bb(summary, return_tensors="pt", truncation=True, padding=True, max_length=512)
+
+     # Move inputs and model to the same device (GPU or CPU)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     model_bb.to(device)

+     # Perform classification
+     with torch.no_grad():
+         outputs = model_bb(**inputs)
+
+     # Get the predicted label
+     predicted_label_id = torch.argmax(outputs.logits, dim=-1).item()
+     label_mapping = model_bb.config.id2label
+     predicted_label = label_mapping[predicted_label_id]
+
+     return summary, predicted_label
+
+ # Create Gradio Interface
+ iface = gr.Interface(
+     fn=summarize_and_classify,
+     inputs=gr.inputs.Textbox(lines=10, placeholder="Enter news article text here..."),
+     outputs=[gr.outputs.Textbox(label="Summary"), gr.outputs.Textbox(label="Category")],
+     title="News Article Summarizer and Classifier",
+     description="Enter a news article text and get its summary and category."
+ )

+ # Launch the interface
+ iface.launch()
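
One note on the Gradio API added here: gr.inputs.Textbox and gr.outputs.Textbox belong to the legacy Gradio 3.x namespace and were removed in Gradio 4. If the Space later pins a newer Gradio, the same interface can be written with the top-level component classes. A minimal sketch, assuming Gradio 4.x and the summarize_and_classify function defined in app.py above:

import gradio as gr

# Sketch only (assumes Gradio 4.x): same wiring as the committed app.py,
# but with gr.Textbox passed directly instead of gr.inputs / gr.outputs.
iface = gr.Interface(
    fn=summarize_and_classify,  # the function defined in app.py above
    inputs=gr.Textbox(lines=10, placeholder="Enter news article text here..."),
    outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Category")],
    title="News Article Summarizer and Classifier",
    description="Enter a news article text and get its summary and category.",
)

iface.launch()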
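
The summarize-then-classify flow can also be checked without launching the UI, which helps confirm that both checkpoints load before the Space restarts. A rough standalone sketch; it does not import app.py (that would also trigger iface.launch() at module level), and the sample text is placeholder content:

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Mirror the two models set up in app.py.
summarizer = pipeline("summarization", model="mrm8488/t5-base-finetuned-summarize-news")
tokenizer = AutoTokenizer.from_pretrained("Lauraayu/News_Classi_Model")
model = AutoModelForSequenceClassification.from_pretrained("Lauraayu/News_Classi_Model")

# Placeholder article text, used only to exercise the pipelines.
sample = (
    "The central bank raised interest rates by a quarter point on Wednesday, "
    "citing persistent inflation and a tight labour market."
)

summary = summarizer(sample)[0]["summary_text"]
inputs = tokenizer(summary, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
    logits = model(**inputs).logits
print("Summary:", summary)
print("Predicted category:", model.config.id2label[int(logits.argmax(dim=-1))])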