# Spaces:
# Runtime error
# Runtime error
import json | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline | |
from tokenizers import Tokenizer | |
def fake_hash(x):
    """Return the constant ``0`` regardless of the argument.

    The argument is accepted but never inspected, so every input yields the
    same value.
    """
    constant_hash = 0
    return constant_hash
def initialize():
    """Build the text-classification pipeline and load the category tables.

    The tokenizer is loaded from the ``distilbert-base-cased`` checkpoint and
    the classification model from the local ``./final_model`` directory.
    The two lookup tables are read from ``cat_mapping.json`` and
    ``cat_name_mapping.json`` in the current working directory.

    Returns:
        A ``(the_pipeline, cat_mapping, cat_name_mapping)`` tuple holding the
        ``TextClassificationPipeline`` and the two parsed JSON objects.

    Raises:
        OSError: If either JSON file cannot be opened.
        json.JSONDecodeError: If either JSON file is not valid JSON.
    """
    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained('./final_model')
    the_pipeline = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        # Ask for the score of every label, not only the best one; get_top
        # accumulates them until its threshold is reached.
        return_all_scores=True,
        # NOTE(review): -1 is the transformers convention for "run on CPU".
        device=-1,
    )

    # Context managers guarantee the handles are closed even if parsing fails;
    # the explicit encoding avoids depending on the platform default.
    with open('cat_mapping.json', 'r', encoding='utf-8') as cat_mapping_file:
        cat_mapping = json.load(cat_mapping_file)
    with open('cat_name_mapping.json', 'r', encoding='utf-8') as cat_name_mapping_file:
        cat_name_mapping = json.load(cat_name_mapping_file)

    return the_pipeline, cat_mapping, cat_name_mapping
def get_top(the_pipeline, cat_mapping, title, summary, thresh=0.95):
    """Return the highest-scoring categories whose scores sum to ``thresh``.

    The title and summary are joined with ``' || '`` and passed to
    ``the_pipeline``. The returned scores are walked in descending order and
    collected until their running sum reaches ``thresh``.

    Args:
        the_pipeline: Callable taking the query string. The first element of
            its result must be a list of dicts with ``'label'`` and
            ``'score'`` keys.
        cat_mapping: Mapping from a label with its first six characters
            removed (presumably the ``LABEL_`` prefix) to the category code.
        title: Article title; must be non-empty.
        summary: Article summary; must be non-empty.
        thresh: Cumulative score at which collection stops.

    Returns:
        A list of dicts, sorted by descending ``'score'``, whose ``'label'``
        is the mapped category code. A human-readable ``str`` is returned
        instead when an input is missing, the query is longer than 4000
        characters, or classification fails.
    """
    # Truthiness also rejects None, which would otherwise raise on the
    # concatenation below.
    if not title or not summary:
        return 'Not enough data to compute.'

    question = title + ' || ' + summary
    if len(question) > 4000:
        return 'Your input is suspiciously long, try something shorter.'

    label_prefix_len = 6  # number of leading characters dropped from each label

    try:
        ranked = sorted(
            the_pipeline(question)[0],
            key=lambda entry: entry['score'],
            reverse=True,
        )
        selected = []
        current_sum = 0
        for entry in ranked:
            # Translate only the entries that are actually returned, and copy
            # them so the pipeline's own output is left untouched.
            category = cat_mapping[entry['label'][label_prefix_len:]]
            selected.append({**entry, 'label': category})
            current_sum += entry['score']
            if current_sum >= thresh:
                break
        return selected
    except Exception:
        # Exception, not BaseException: KeyboardInterrupt, SystemExit and other
        # control-flow signals must keep propagating to the caller.
        return 'Something unexpected happened, I\'m sorry. Try again.'
# --- Page header -----------------------------------------------------------
st.markdown('## Welcome to the CS article classification page!')
st.markdown('### What\'s below is pretty much self-explanatory.')

# The banner image is embedded as raw HTML so that its width can be set.
img_source = 'https://sun9-55.userapi.com/impg/azBQ_VTvbgEVonbL9hhFEpwyKAhjAtpVl4H2GQ/I4Vq0H6c3UM.jpg'
img_params = 'size=1200x900&quality=96&sign=f42419d9cdbf6fe55016fb002e4e85ae&type=album'
st.markdown(
    f'<img src="{img_source}?{img_params}" width="70%"><br>',
    unsafe_allow_html=True
)

# --- User input ------------------------------------------------------------
title = st.text_input(
    'Please, insert the title of the CS article you are interested in.',
    placeholder='The title (e. g. Incorporating alien technologies in CV)'
)
summary = st.text_area(
    'Now, please, insert the summary of the CS article you are interested in.',
    height=250, placeholder='The summary itself.'
)

# --- Classification and output ----------------------------------------------
the_pipeline, cat_mapping, cat_name_mapping = initialize()
scores = get_top(the_pipeline, cat_mapping, title, summary)

if isinstance(scores, str):
    # get_top reports problems (missing input, failure, ...) as a plain message.
    st.markdown(scores)
else:
    for score in scores:
        percent = round(score['score'] * 100, 2)
        category_short = score['label']
        category_full = cat_name_mapping[category_short]
        # '\\%' emits a literal backslash followed by '%', exactly what the
        # previous '\%' produced, without relying on an invalid escape sequence.
        st.markdown(
            f'I\'m {percent}\\% certain that the article is from the {category_short} category, '
            f'which is "{category_full}"'
        )