Spaces:
Runtime error
Runtime error
import os | |
import json | |
import random | |
import streamlit as st | |
from transformers import TextClassificationPipeline, pipeline | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertTokenizerFast, DistilBertForSequenceClassification | |
# Hugging Face hub ids offered in the sentiment tab's model picker.  The
# default model selected on a fresh session is the first entry.
emotion_model_names = (
    "cardiffnlp/twitter-roberta-base-sentiment",        # parsed from LABEL_* codes
    "finiteautomata/beto-sentiment-analysis",           # parsed from NEG / POS codes
    "bhadresh-savani/distilbert-base-uncased-emotion",  # emotion names (joy, anger, ...)
    "siebert/sentiment-roberta-large-english",          # NEGATIVE / POSITIVE labels
)
class ModelImplementation(object):
    """Bundle a pretrained model, its tokenizer, a classifier pipeline and a result parser.

    Args:
        transformer_model_name: id passed to ``model_transformer.from_pretrained``.
        model_transformer: class exposing ``from_pretrained`` that builds the model.
        tokenizer_model_name: id passed to ``tokenizer_func.from_pretrained``.
        tokenizer_func: class exposing ``from_pretrained`` that builds the tokenizer.
        pipeline_func: callable that builds the classifier; it is called with
            ``model=``, ``tokenizer=``, ``padding=True``, ``truncation=True`` plus
            every entry of ``classifier_args``.
        parser_func: plain function ``(implementation, raw_result)`` applied to
            the classifier output by :meth:`predict`.
        classifier_args: extra keyword arguments for ``pipeline_func``
            (defaults to none).
        placeholders: example inputs shown in the UI (defaults to ``[""]``).
    """

    def __init__(
        self,
        transformer_model_name,
        model_transformer,
        tokenizer_model_name,
        tokenizer_func,
        pipeline_func,
        parser_func,
        classifier_args=None,
        placeholders=None,
    ):
        # ``None`` sentinels replace the former ``{}`` / ``[""]`` defaults: those
        # objects were created once at definition time and shared by every
        # instance, so mutating one instance's placeholders leaked into all others.
        if classifier_args is None:
            classifier_args = {}
        if placeholders is None:
            placeholders = [""]
        self.transformer_model_name = transformer_model_name
        self.tokenizer_model_name = tokenizer_model_name
        self.placeholders = placeholders
        self.model = model_transformer.from_pretrained(self.transformer_model_name)
        self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
        self.classifier = pipeline_func(
            model=self.model,
            tokenizer=self.tokenizer,
            padding=True,
            truncation=True,
            **classifier_args,
        )
        # Stored as a plain instance attribute, so it is NOT bound as a method;
        # predict() therefore passes ``self`` explicitly.
        self.parser = parser_func

    def predict(self, val):
        """Run the classifier on ``val`` and return the parser's view of the result."""
        result = self.classifier(val)
        return self.parser(self, result)
def ParseEmotionOutput(self, result):
    """Turn a sentiment/emotion pipeline result into ``(label, score, output_func)``.

    Used as the ``parser_func`` of the sentiment ``ModelImplementation``, which
    calls it as ``parser(implementation, raw_result)``.

    Args:
        self: the ``ModelImplementation``; only ``transformer_model_name`` is read,
            to pick the label vocabulary of the active model.
        result: raw pipeline output; only ``result[0]['label']`` and
            ``result[0]['score']`` are read.

    Returns:
        ``(label, score, output_func)``: ``label`` is mapped to a readable name
        for the known models (otherwise passed through), ``score`` is returned
        unchanged, and ``output_func`` is ``st.error``, ``st.success`` or
        ``st.info`` for the caller to render the result with.
    """
    top_prediction = result[0]
    raw_label = top_prediction['label']
    score = top_prediction['score']
    model_name = self.transformer_model_name

    # Three-way sentiment models: (negative code, positive code).  Any other
    # code from these models is reported as NEUTRAL.
    three_way_codes = {
        "cardiffnlp/twitter-roberta-base-sentiment": ("LABEL_0", "LABEL_2"),
        "finiteautomata/beto-sentiment-analysis": ("NEG", "POS"),
    }
    if model_name in three_way_codes:
        negative_code, positive_code = three_way_codes[model_name]
        if raw_label == negative_code:
            return "NEGATIVE", score, st.error
        if raw_label == positive_code:
            return "POSITIVE", score, st.success
        return "NEUTRAL", score, st.info

    if model_name == "bhadresh-savani/distilbert-base-uncased-emotion":
        # Emotion names are shown upper-cased; joy/love render as success,
        # anger/surprise as error, and everything else as info.
        if raw_label in ("joy", "love"):
            renderer = st.success
        elif raw_label in ("anger", "surprise"):
            renderer = st.error
        else:
            renderer = st.info
        return raw_label.upper(), score, renderer

    if model_name == "siebert/sentiment-roberta-large-english":
        if raw_label == "NEGATIVE":
            return raw_label, score, st.error
        if raw_label == "POSITIVE":
            return raw_label, score, st.success

    # Unknown model or unrecognised label: pass through with neutral styling.
    return raw_label, score, st.info
def ParsePatentOutput(self, result):
    """Return the patent classifier's raw output untouched.

    Matches the ``parser_func(implementation, raw_result)`` calling convention
    of ``ModelImplementation``; ``self`` is accepted but not used.  The patent
    tab indexes the raw per-label score lists directly, so no post-processing
    happens here.
    """
    unchanged = result
    return unchanged
def emotion_model_change():
    """(Re)build the sentiment model named by ``st.session_state.emotion_model_name``.

    Called once when a session starts and again as the ``on_change`` callback
    of the model selectbox; the new ``ModelImplementation`` replaces
    ``st.session_state.emotion_model``.
    """
    chosen_name = st.session_state.emotion_model_name
    sample_tweet = "@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."
    st.session_state.emotion_model = ModelImplementation(
        transformer_model_name=chosen_name,
        model_transformer=AutoModelForSequenceClassification,
        tokenizer_model_name=chosen_name,
        tokenizer_func=AutoTokenizer,
        pipeline_func=pipeline,
        parser_func=ParseEmotionOutput,
        classifier_args={"task": "sentiment-analysis"},
        placeholders=[sample_tweet],
    )
# First run of a browser session: choose the default sentiment model and load it.
if "emotion_model_name" not in st.session_state:
    default_emotion_model = "cardiffnlp/twitter-roberta-base-sentiment"
    st.session_state["emotion_model_name"] = default_emotion_model
    emotion_model_change()
def _load_patent_model(model_name):
    """Build a DistilBERT patent classifier that reports a score for every label.

    Both patent models share the ``distilbert-base-uncased`` tokenizer and pass
    their raw pipeline output through ``ParsePatentOutput``; only the hub id of
    the fine-tuned weights differs.
    """
    return ModelImplementation(
        model_name,
        DistilBertForSequenceClassification,
        'distilbert-base-uncased',
        DistilBertTokenizerFast,
        TextClassificationPipeline,
        ParsePatentOutput,
        classifier_args={"return_all_scores": True},
    )


# First run of a browser session: load the validation patents and both models.
if "patent_data" not in st.session_state:
    # Context manager so the file is closed even if parsing fails; explicit
    # UTF-8 so patent text decodes the same regardless of the host locale.
    with open('./data/val.json', encoding='utf-8') as f:
        valData = json.load(f)
    # Index records by patent number; the four lists are consumed in lockstep.
    patent_data = {
        num: {"patent_number": num, "label": label, "abstract": abstract, "claim": claim}
        for num, label, abstract, claim in zip(
            valData["patent_numbers"],
            valData["labels"],
            valData["abstracts"],
            valData["claims"],
        )
    }
    if not patent_data:
        # Fail with a clear message instead of an IndexError on the first key.
        raise ValueError("./data/val.json contains no patents to pre-populate the form with")
    st.session_state.patent_data = patent_data
    # Pre-select the first patent in file order.
    st.session_state.patent_num = next(iter(patent_data))
    st.session_state.weight = 0.5
    st.session_state.patent_abstract_model = _load_patent_model('rk2546/uspto-patents-abstracts')
    print("Patent abstracts model initialized")
    st.session_state.patent_claim_model = _load_patent_model('rk2546/uspto-patents-claims')
    print("Patent claims model initialized")
# Page header shown above both tabs.
st.title("CSGY-6613 Project")
st.markdown("_**Ryan Kim (rk2546)**_")

# One tab per course milestone; the tab handles are filled in below.
_milestone_tab_labels = [
    "Emotion Analysis [Milestone #2]",
    "Patent Prediction [Milestone #3]",
]
sentimentTab, patentTab = st.tabs(_milestone_tab_labels)
# Sentiment tab: pick a model, enter text, show the predicted label and score.
with sentimentTab:
    st.subheader("Sentiment Analysis")
    if "emotion_model" not in st.session_state:
        st.write("Loading model...")
    else:
        # Bound to session state through ``key``; changing the selection fires
        # emotion_model_change, so the widget's return value is not needed.
        st.selectbox(
            "What sentiment analysis model do you want to use? NOTE: Lag may occur when loading a new model!",
            emotion_model_names,
            on_change=emotion_model_change,
            key="emotion_model_name"
        )
        emotion_model = st.session_state.emotion_model
        form = st.form(key='sentiment-analysis-form')
        text_input = form.text_area(
            "Enter some text for sentiment analysis! If you just want to test it out without entering anything, just press the \"Submit\" button and the model will look at the placeholder.",
            placeholder=emotion_model.placeholders[0]
        )
        submit = form.form_submit_button('Submit')
        if submit:
            # Empty or whitespace-only input falls back to the model's placeholder.
            to_eval = (text_input or "").strip() or emotion_model.placeholders[0]
            label, score, output_func = emotion_model.predict(to_eval)
            output_func("**{}**: {}".format(label, score))
def _aggregate_patent_scores(abstract_scores, claim_scores, weight):
    """Blend the abstract and claim classifier scores into one ranked verdict.

    Args:
        abstract_scores: per-label ``{'label', 'score'}`` dicts from the abstract
            model; this code treats index 0 as REJECTED and index 1 as ACCEPTED.
        claim_scores: the same structure from the claim model.
        weight: slider value in [-1, 1]. -1 uses only the abstract, 1 uses only
            the claims, and 0 weighs both equally.

    Returns:
        Two ``{'label', 'score'}`` dicts (REJECTED, ACCEPTED) sorted by
        descending score; on a tie REJECTED stays first.
    """
    claim_weight = (1 + weight) / 2
    abstract_weight = 1 - claim_weight
    aggregate = [
        {
            'label': label_name,
            'score': abstract_scores[idx]['score'] * abstract_weight
            + claim_scores[idx]['score'] * claim_weight,
        }
        for idx, label_name in enumerate(('REJECTED', 'ACCEPTED'))
    ]
    return sorted(aggregate, key=lambda d: d['score'], reverse=True)


# Patent tab: score an abstract and a claim list, then show a weighted verdict.
with patentTab:
    st.subheader("USPTO Patent Evaluation")
    st.markdown("Below are two inputs - one for an **ABSTRACT** and another for a list of **CLAIMS**. Enter both and select the \"Submit\" button to evaluate the patentability of your idea.")
    # Bound to ``st.session_state.patent_num`` through ``key``; read from there below.
    st.selectbox(
        "Want to pre-populate with an existing patent? Select the index number below.",
        list(st.session_state.patent_data.keys()),
        key="patent_num",
    )
    if "patent_abstract_model" not in st.session_state or "patent_claim_model" not in st.session_state:
        st.write("Loading models...")
    else:
        selected_patent = st.session_state.patent_data[st.session_state.patent_num]
        with st.form(key='patent-form'):
            col1, col2 = st.columns(2)
            with col1:
                abstract_input = st.text_area(
                    "Enter the abstract of the patent below",
                    placeholder=selected_patent["abstract"],
                    height=200
                )
            with col2:
                claim_input = st.text_area(
                    "Enter the claims of the patent below",
                    placeholder=selected_patent["claim"],
                    height=200
                )
            weight_val = st.slider(
                "How much do the abstract and claims weight when aggregating a total softmax score?",
                min_value=-1.0,
                max_value=1.0,
                value=0.5,
            )
            submit = st.form_submit_button('Submit')
        if submit:
            # Empty or whitespace-only inputs fall back to the selected patent's text.
            abstract_to_eval = (abstract_input or "").strip() or selected_patent["abstract"].strip()
            claim_to_eval = (claim_input or "").strip() or selected_patent["claim"].strip()
            abstract_response = st.session_state.patent_abstract_model.predict(abstract_to_eval)
            claim_response = st.session_state.patent_claim_model.predict(claim_to_eval)
            # ``[0]`` selects the score list for the single input string.
            best = _aggregate_patent_scores(abstract_response[0], claim_response[0], weight_val)[0]
            answerCol1, answerCol2, answerCol3 = st.columns(3)
            with answerCol1:
                # A disabled slider is used as a read-only gauge for the ACCEPTED score.
                st.slider(
                    "Abstract Acceptance Likelihood",
                    min_value=0.0,
                    max_value=100.0,
                    value=abstract_response[0][1]["score"] * 100.0,
                    disabled=True
                )
            with answerCol2:
                output_func = st.error if best["label"] == "REJECTED" else st.success
                output_func("\n**Final Rating: {}**\n{}%\n".format(best["label"], best["score"] * 100.0))
            with answerCol3:
                st.slider(
                    "Claim Acceptance Likelihood",
                    min_value=0.0,
                    max_value=100.0,
                    value=claim_response[0][1]["score"] * 100.0,
                    disabled=True
                )
st.write("")