"""Streamlit page that classifies a CS article from its title and summary.

The page loads a fine-tuned sequence-classification model from
``./final_model``, asks the user for a title and a summary, and prints the
most probable categories until their cumulative score reaches a threshold.
"""
import json
import logging

import streamlit as st
from tokenizers import Tokenizer
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)

logger = logging.getLogger(__name__)


def fake_hash(x):
    """Return a constant hash, ignoring *x*.

    Registered in ``hash_funcs`` for ``Tokenizer`` so the cache decorator on
    ``initialize`` does not try to hash tokenizer objects itself.
    """
    return 0


@st.cache(
    hash_funcs={Tokenizer: fake_hash},
    suppress_st_warning=True,
    allow_output_mutation=True,
)
def initialize():
    """Build the classification pipeline and load the category mappings.

    Returns:
        A ``(the_pipeline, cat_mapping, cat_name_mapping)`` tuple, where the
        two mappings are the parsed contents of ``cat_mapping.json`` and
        ``cat_name_mapping.json`` from the working directory.
    """
    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained('./final_model')
    the_pipeline = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        device=-1,
    )
    # Context managers guarantee the file handles are closed; the original
    # left both files open for the lifetime of the process.
    with open('cat_mapping.json', 'r') as cat_mapping_file:
        cat_mapping = json.load(cat_mapping_file)
    with open('cat_name_mapping.json', 'r') as cat_name_mapping_file:
        cat_name_mapping = json.load(cat_name_mapping_file)
    return the_pipeline, cat_mapping, cat_name_mapping


def get_top(the_pipeline, cat_mapping, title, summary, thresh=0.95):
    """Return the top categories whose cumulative score reaches *thresh*.

    Args:
        the_pipeline: Callable that takes the input text and returns a list
            whose first element is a list of ``{'label': ..., 'score': ...}``
            dicts.
        cat_mapping: Mapping from the label with its first six characters
            removed to the category code.
        title: Article title entered by the user.
        summary: Article summary entered by the user.
        thresh: Cumulative score at which to stop collecting categories.

    Returns:
        A list of score dicts (highest score first, ``'label'`` replaced by
        the mapped category), or a ``str`` message when the input is missing,
        too long, or classification fails.
    """
    if not title or not summary:
        return 'Not enough data to compute.'
    question = title + ' || ' + summary
    if len(question) > 4000:
        return 'Your input is supsiciously long, try something shorter.'
    try:
        result = the_pipeline(question)[0]
        result.sort(key=lambda item: -item['score'])
        current_sum = 0
        scores = []
        for entry in result:
            scores.append(entry)
            current_sum += entry['score']
            if current_sum >= thresh:
                break
        # Labels are rewritten in place; ``scores`` holds the same dict
        # objects, so the returned entries carry the mapped labels.
        # The [6:] slice drops a six-character prefix from the raw label
        # (presumably 'LABEL_' -- confirm against the model config).
        for entry in result:
            entry['label'] = cat_mapping[entry['label'][6:]]
        return scores
    except Exception:
        # ``Exception`` rather than ``BaseException`` so KeyboardInterrupt and
        # SystemExit still propagate; the traceback is logged, not discarded.
        logger.exception('Classification failed')
        return 'Something unexpected happened, I\'m sorry. Try again.'


st.markdown('## Welcome to the CS article classification page!')
st.markdown('### What\'s below is pretty much self-explanatory.')

img_source = 'https://sun9-55.userapi.com/impg/azBQ_VTvbgEVonbL9hhFEpwyKAhjAtpVl4H2GQ/I4Vq0H6c3UM.jpg'
img_params = 'size=1200x900&quality=96&sign=f42419d9cdbf6fe55016fb002e4e85ae&type=album'
# NOTE(review): the HTML inside this f-string was missing from the reviewed
# copy of the file. The <img> tag below is a reconstruction from the two
# variables above -- confirm against the original markup before merging.
st.markdown(
    f'<img src="{img_source}?{img_params}">',
    unsafe_allow_html=True
)

title = st.text_input(
    'Please, insert the title of the CS article you are interested in.',
    placeholder='The title (e. g. Incorporating alien technologies in CV)'
)
summary = st.text_area(
    'Now, please, insert the summary of the CS article you are interested in.',
    height=250,
    placeholder='The summary itself.'
)

the_pipeline, cat_mapping, cat_name_mapping = initialize()
scores = get_top(the_pipeline, cat_mapping, title, summary)

if isinstance(scores, str):
    # A string result is a user-facing status or error message.
    st.markdown(scores)
else:
    for score in scores:
        percent = round(score['score'] * 100, 2)
        category_short = score['label']
        category_full = cat_name_mapping[category_short]
        # '\\%' emits the same backslash-percent text as the original '\%'
        # without relying on an invalid escape sequence.
        st.markdown(
            f'I\'m {percent}\\% certain that the article is from the '
            f'{category_short} category, which is "{category_full}"'
        )