DataAIDemo / pages /tweet_classification.py
themeetjani's picture
Upload 10 files
6060e42 verified
raw
history blame
1.1 kB
import streamlit as st
from streamlit import session_state
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
# set_page_config must be the first Streamlit command on the page; the cached
# loader below can render a spinner element, so the page config goes first.
st.set_page_config(page_title="Classification", page_icon="📈")

MODEL_ID = "themeetjani/tweet-classification"


@st.cache_resource
def _load_classifier(model_id=MODEL_ID):
    """Load the tokenizer, model and text-classification pipeline for *model_id*.

    Streamlit re-executes this whole script on every widget interaction, so
    the load is wrapped in ``st.cache_resource`` to run it once per server
    process instead of on every rerun.

    Returns:
        A ``(tokenizer, model, classifier)`` tuple, where ``classifier`` is a
        ``text-classification`` pipeline that truncates inputs to 512 tokens.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_id)
    clf = pipeline(
        "text-classification",
        model=mdl,
        tokenizer=tok,
        truncation=True,
        max_length=512,
    )
    return tok, mdl, clf


# Keep the original module-level names so the rest of the page is unchanged.
tokenizer, model, classifier = _load_classifier()

# Slot holding the most recent predicted label; shown in the "result" box.
if 'tweet_class' not in session_state:
    session_state['tweet_class'] = ""
def classify(tweet):
    """Classify *tweet* and store the top predicted label in session state.

    Used as the ``on_click`` callback of the "Classify" button. The label is
    written to ``session_state['tweet_class']``, which the "result" text area
    displays on the next rerun.

    Args:
        tweet: The text to classify. An empty or whitespace-only value clears
            the previous result instead of running the model.
    """
    if not tweet or not tweet.strip():
        # Nothing meaningful to classify; don't report a label for blank input.
        session_state['tweet_class'] = ""
        return
    # With top_k=1 the pipeline returns a one-element list of
    # {'label': ..., 'score': ...} dicts for a single input string.
    predictions = classifier(tweet, top_k=1)
    session_state['tweet_class'] = predictions[0]['label']
st.title("Tweet Classifier")

# The key lets the button callback read the *current* text from session state.
# Streamlit runs callbacks before the rerun, so ``args=[tweet]`` would pass the
# value captured on the previous run and could be stale.
tweet = st.text_area(
    label="Please write the tweet below",
    placeholder="What does the tweet say?",
    key="tweet_input",
)

st.text_area("result", value=session_state['tweet_class'])

st.button(
    "Classify",
    on_click=lambda: classify(session_state.get("tweet_input", "")),
)