dbleek commited on
Commit
8114970
2 Parent(s): fa4d16c efd1c85

Merge branch 'main' into milestone-3

Browse files
Files changed (1) hide show
  1. milestone_2.py +26 -0
milestone_2.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import (AutoTokenizer, TFAutoModelForSequenceClassification,
3
+ pipeline)
4
+
5
+ st.title("CS-GY-6613 Project Milestone 2")
6
+ model_choices = (
7
+ "distilbert-base-uncased-finetuned-sst-2-english",
8
+ "j-hartmann/emotion-english-distilroberta-base",
9
+ "joeddav/distilbert-base-uncased-go-emotions-student",
10
+ )
11
+
12
+ with st.form("Input Form"):
13
+ text = st.text_area("Write your text here:", "CS-GY-6613 is a great course!")
14
+ model_name = st.selectbox("Select a model:", model_choices)
15
+ submitted = st.form_submit_button("Submit")
16
+
17
+ if submitted:
18
+ model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
19
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+ classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
21
+ res = classifier(text)
22
+ label = res[0]["label"].upper()
23
+ score = res[0]["score"]
24
+ st.markdown(
25
+ f"This text was classified as **{label}** with a confidence score of **{score}**."
26
+ )