Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import pipeline | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Streamlit re-executes this script from top to bottom on every widget
# interaction, so an uncached load would re-read the model weights on each
# click. Prefer the dedicated resource cache and fall back to the legacy
# decorator on Streamlit releases that do not provide it yet.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned classifier and its tokenizer once per process.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from ``MODEL_NAME``.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
    loaded_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return loaded_model, loaded_tokenizer


# Same module-level names as before so the rest of the script is unaffected.
model, tokenizer = _load_model_and_tokenizer()
# Page header: title, short description, cover image and the input box.
# The user-facing strings are named up front so the widget calls read as
# plain layout.
APP_TITLE = "Dutch news article classification"
APP_INTRO = "This app classifies a Dutch news article into one of 9 pre-defined* article categories"
COVER_IMAGE = 'dataset-cover_articles.jpeg'
INPUT_PROMPT = 'Please type/copy/paste text of the Dutch article and click Submit'

st.title(APP_TITLE)
st.write(APP_INTRO)
st.image(COVER_IMAGE, width=150)

# `text` holds whatever is currently in the text area; the Submit handler
# further down reads it.
text = st.text_area(INPUT_PROMPT)
# Category names in the order used for the bar chart and the result line.
# NOTE(review): assumes the model emits exactly these 9 classes in this
# order - confirm against the model's id2label mapping.
labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie', 'Koningshuis',
               'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech']


def _predict_probabilities(article_text):
    """Return the class probabilities (in percent) for one article.

    Args:
        article_text: Raw article text entered by the user.

    Returns:
        NumPy array holding the softmax of the model logits for the single
        input, multiplied by 100.
    """
    # Truncate to the encoder's position-embedding limit; without this an
    # article longer than the limit makes the forward pass raise.
    encoding = tokenizer(
        article_text,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoding).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return probabilities[0].cpu().numpy() * 100


def _plot_probabilities(labels, percentages):
    """Build the horizontal bar chart of per-category probabilities.

    Args:
        labels: Category names, one per bar.
        percentages: Probability of each category in percent.

    Returns:
        The Matplotlib figure containing the chart.
    """
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.barh(labels, percentages)
    ax.set_title("Predicted article category probability", fontsize=20)
    ax.set_xlabel("Probability (%)", fontsize=16)
    ax.set_ylabel("Predicted category", fontsize=16)
    # Resize the existing category tick labels instead of overwriting them
    # with set_yticklabels, which expects fixed tick positions.
    ax.tick_params(axis="y", labelsize=14)
    return fig


if st.button('Submit'):
    if not text.strip():
        # Nothing to classify; do not run the model on an empty string.
        st.warning('Please enter the text of a Dutch news article before clicking Submit.')
    else:
        with st.spinner('Generating a response...'):
            probs_plot = _predict_probabilities(text)
            fig = _plot_probabilities(labels_plot, probs_plot)
        st.pyplot(fig)
        # Release the figure; the script is re-run on every interaction and
        # unclosed figures would otherwise pile up in pyplot's registry.
        plt.close(fig)

        # Plain Python scalars keep list indexing and "{:.1f}" formatting
        # independent of tensor/array-to-index conversions.
        number = int(probs_plot.argmax())
        st.write('The predicted category is: **{}** with a probability of: **{:.1f}%**'.format(
            labels_plot[number], float(probs_plot[number])))
# Footer: the category footnote, the dataset source and a pointer to the
# model card, written out in this order.
FOOTER_NOTES = (
    "The pre-defined categories are Binnenland, Buitenland, Cultuur & Media, Economie , Koningshuis, Opmerkelijk, Politiek, Regionaal nieuws en Tech",
    "The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles",
    'Model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_clf_bert_finetuned',
)

for footer_note in FOOTER_NOTES:
    st.write(footer_note)