Spaces:
Sleeping
Sleeping
liamvbetts
commited on
Commit
•
8df474c
1
Parent(s):
bc1a0a8
new changes
Browse files- app.py +42 -22
- requirements.txt +1 -1
app.py
CHANGED
@@ -1,22 +1,38 @@
|
|
1 |
import gradio as gr
|
2 |
import random
|
3 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
4 |
from datasets import load_dataset
|
5 |
import requests
|
6 |
from bs4 import BeautifulSoup
|
7 |
-
|
8 |
-
tokenizer = AutoTokenizer.from_pretrained("liamvbetts/bart-large-cnn-v4")
|
9 |
-
model = AutoModelForSeq2SeqLM.from_pretrained("liamvbetts/bart-large-cnn-v4")
|
10 |
|
11 |
dataset = load_dataset("cnn_dailymail", "3.0.0")
|
12 |
|
13 |
-
NEWS_API_KEY =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
def
|
16 |
-
|
17 |
-
outputs = model.generate(inputs, max_new_tokens=128, do_sample=False)
|
18 |
-
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
19 |
-
return summary
|
20 |
|
21 |
def get_random_article():
|
22 |
random.seed()
|
@@ -61,28 +77,32 @@ def scrape_article(url):
|
|
61 |
|
62 |
text = ' '.join([p.get_text() for p in article_content])
|
63 |
words = text.split()
|
64 |
-
truncated_text = ' '.join(words[:
|
65 |
|
66 |
return truncated_text, title
|
67 |
except Exception as e:
|
68 |
return "Error scraping article: " + str(e), ""
|
69 |
|
70 |
-
# Using Gradio Blocks
|
71 |
with gr.Blocks() as demo:
|
72 |
-
gr.Markdown("
|
73 |
gr.Markdown("Enter a news text and get its summary, or load a random article.")
|
|
|
74 |
with gr.Row():
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
load_dataset_article_button.click(fn=load_article, inputs=[], outputs=input_text)
|
83 |
load_news_article_button.click(fn=get_news_article, inputs=[], outputs=[input_text, article_title])
|
84 |
-
|
85 |
-
summarize_button = gr.Button("Summarize")
|
86 |
-
summarize_button.click(fn=summarize, inputs=input_text, outputs=output_text)
|
87 |
|
88 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import random
|
|
|
3 |
from datasets import load_dataset
|
4 |
import requests
|
5 |
from bs4 import BeautifulSoup
|
6 |
+
import os
|
|
|
|
|
7 |
|
8 |
dataset = load_dataset("cnn_dailymail", "3.0.0")
|
9 |
|
10 |
+
NEWS_API_KEY = os.environ['NEWS_API_KEY']
|
11 |
+
HF_TOKEN = os.environ['HF_TOKEN']
|
12 |
+
|
13 |
+
def summarize(model_name, article):
|
14 |
+
API_URL = f"https://api-inference.huggingface.co/models/{model_name}"
|
15 |
+
headers = {"Authorization": "Bearer {HF_TOKEN}"}
|
16 |
+
|
17 |
+
payload = {"inputs": article}
|
18 |
+
response = requests.post(API_URL, headers=headers, json=payload)
|
19 |
+
|
20 |
+
# Check if the response is successful
|
21 |
+
if response.status_code == 200:
|
22 |
+
# Assuming the response structure has a 'generated_text' field
|
23 |
+
return format(response.json())
|
24 |
+
else:
|
25 |
+
# Handle different types of errors
|
26 |
+
if response.status_code == 401:
|
27 |
+
return "Error: Unauthorized. Check your API token."
|
28 |
+
elif response.status_code == 503:
|
29 |
+
return "Error: Service unavailable or model is currently loading."
|
30 |
+
else:
|
31 |
+
return f"{response} - Error: Encountered an issue (status code: {response.status_code}). Please try again."
|
32 |
+
return format(response.json())
|
33 |
|
34 |
+
def format(response):
|
35 |
+
return response[0]['generated_text']
|
|
|
|
|
|
|
36 |
|
37 |
def get_random_article():
|
38 |
random.seed()
|
|
|
77 |
|
78 |
text = ' '.join([p.get_text() for p in article_content])
|
79 |
words = text.split()
|
80 |
+
truncated_text = ' '.join(words[:512]) # Truncate to first 1024 words
|
81 |
|
82 |
return truncated_text, title
|
83 |
except Exception as e:
|
84 |
return "Error scraping article: " + str(e), ""
|
85 |
|
86 |
+
# Using Gradio Blocks with improved layout and styling
|
87 |
with gr.Blocks() as demo:
|
88 |
+
gr.Markdown("# News Summary App", elem_id="header")
|
89 |
gr.Markdown("Enter a news text and get its summary, or load a random article.")
|
90 |
+
|
91 |
with gr.Row():
|
92 |
+
with gr.Column():
|
93 |
+
with gr.Row():
|
94 |
+
load_dataset_article_button = gr.Button("Load Random Article from Val Dataset")
|
95 |
+
load_news_article_button = gr.Button("Pull Random News Article from NewsAPI")
|
96 |
+
article_title = gr.Label() # Component to display the article title
|
97 |
+
input_text = gr.Textbox(lines=10, label="Input Text")
|
98 |
+
with gr.Column():
|
99 |
+
with gr.Row():
|
100 |
+
summarize_button = gr.Button("Summarize")
|
101 |
+
model_name = gr.Dropdown(label="Model Name", choices=["liamvbetts/bart-news-summary-v1", "liamvbetts/bart-base-cnn-v1", "liamvbetts/bart-large-cnn-v2", "liamvbetts/bart-large-cnn-v4"], value="liamvbetts/bart-news-summary-v1")
|
102 |
+
output_text = gr.Textbox(label="Summary")
|
103 |
|
104 |
load_dataset_article_button.click(fn=load_article, inputs=[], outputs=input_text)
|
105 |
load_news_article_button.click(fn=get_news_article, inputs=[], outputs=[input_text, article_title])
|
106 |
+
summarize_button.click(fn=summarize, inputs=[model_name, input_text], outputs=output_text)
|
|
|
|
|
107 |
|
108 |
demo.launch()
|
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
gradio
|
2 |
-
transformers
|
3 |
datasets
|
4 |
evaluate
|
5 |
accelerate
|
|
|
1 |
gradio
|
2 |
+
#transformers
|
3 |
datasets
|
4 |
evaluate
|
5 |
accelerate
|