File size: 2,763 Bytes
e3d46c8 8158997 e3d46c8 0788ae6 56b99e8 0788ae6 e3d46c8 56b99e8 0788ae6 a85495d 68d6aa3 093cd61 68d6aa3 8158997 68d6aa3 8158997 68d6aa3 9299174 bceabb4 68d6aa3 bceabb4 68d6aa3 bceabb4 0788ae6 bceabb4 68d6aa3 0788ae6 8158997 0788ae6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt
#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
#text = st.text_area('Please type/copy/paste the Dutch article')
#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis'
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech']
#if text:
# out = pipe(text)
# st.json(out)
# load tokenizer and model, create trainer
#model_name = "RuudVelo/dutch_news_classifier_bert_finetuned"
#tokenizer = AutoTokenizer.from_pretrained(model_name)
#model = AutoModelForSequenceClassification.from_pretrained(model_name)
#trainer = Trainer(model=model)
#print(filename, type(filename))
#print(filename.name)
from transformers import BertForSequenceClassification, BertTokenizer

# Hugging Face Hub id of the fine-tuned checkpoint; the classifier and its
# tokenizer are both loaded from this one repository.
MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Streamlit re-executes this script from the top on every widget interaction,
# so an uncached ``from_pretrained`` would reload the weights on every click.
# ``st.cache_resource`` is the current API for caching such objects; older
# Streamlit releases only offer ``st.cache``, where ``allow_output_mutation``
# skips hashing of the (unhashable, mutable) model object.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_classifier():
    """Load the fine-tuned BERT classifier and its tokenizer once per process.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from ``MODEL_NAME``.
    """
    classifier = BertForSequenceClassification.from_pretrained(MODEL_NAME)
    bert_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return classifier, bert_tokenizer


# Module-level names used by the rest of the script.
model, tokenizer = _load_classifier()

# Title
st.title("Dutch news article classification")
#text = st.text_area('Please type/copy/paste text of the Dutch article')
#if text:
# encoding = tokenizer(text, return_tensors="pt")
# outputs = model(**encoding)
# predictions = outputs.logits.argmax(-1)
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
## fig = plt.figure()
# ax = fig.add_axes([0,0,1,1])
# labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis',
# 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech']
# probs_plot = probabilities[0].cpu().detach().numpy()
# ax.barh(labels_plot,probs_plot )
# st.pyplot(fig)
# Category names shown on the bar chart, one per output logit.
# NOTE(review): the order is assumed to match the fine-tuned model's label
# ids -- confirm against the model's config before relying on the chart.
labels_plot = [
    'Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie', 'Koningshuis',
    'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech',
]

# Free-text field holding the article to classify (named so that it does not
# shadow the ``input`` builtin).
article_text = st.text_input('Context')

if st.button('Submit'):
    if not article_text.strip():
        # An empty field would make the model score nothing but the special
        # tokens and plot a meaningless distribution.
        st.warning('Please enter the text of an article first.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the encoder's position-embedding limit so that long
            # articles are classified on their opening part instead of
            # crashing the forward pass.
            encoding = tokenizer(
                article_text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                outputs = model(**encoding)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Probabilities of the first (and only) sequence in the batch.
            probs_plot = probabilities[0].cpu().numpy()

            # Horizontal bar chart with one bar per category.
            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(labels_plot, probs_plot)
            st.pyplot(fig)
            # pyplot keeps every figure it creates registered until it is
            # closed; close it so figures do not accumulate across reruns.
            plt.close(fig)
# output = genQuestion(option, input)
# print(output)
# st.write(output)
#encoding = tokenizer(text, return_tensors="pt")
#import numpy as np
#arr = np.random.normal(1, 1, size=100)
#fig, ax = plt.subplots()
#ax.hist(arr, bins=20)
#st.pyplot(fig)
# forward pass
#outputs = model(**encoding)
#predictions = outputs.logits.argmax(-1)
|