import json
import os
import streamlit as st
import pickle
from transformers import AutoTokenizer, BertForSequenceClassification, pipeline
from sklearn.feature_extraction.text import TfidfVectorizer


def load_models():
    """Load every model used by the app into ``st.session_state``.

    Populates ``tfidf_vectorizer_disinformation``, ``tfidf_vectorizer_untrue_inf``,
    ``gpt_detector`` and ``untrue_detector`` (unpickled from ``models/``), plus
    ``bert``, a Hugging Face text-classification pipeline built from
    ``TRACES/private-bert``.

    ``st.session_state.loaded`` is set only after everything has loaded. A
    failure part-way through is therefore retried on the next rerun, instead of
    leaving the session marked as loaded but without models.

    Raises:
        KeyError: if the ``ACCESS_TOKEN`` environment variable is not set.
        OSError: if one of the pickled model files cannot be opened.
    """
    # session_state attribute -> pickle file that holds it.
    pickled_models = {
        'tfidf_vectorizer_disinformation':
            'models/tfidf_vectorizer_svm_model_2_classes_gpt_chatgpt_detection_tfidf_bg_0.886_F1_score.pkl',
        'tfidf_vectorizer_untrue_inf':
            'models/tfidf_vectorizer_untrue_inform_detection_tfidf_bg_0.96_F1_score.pkl',
        'gpt_detector':
            'models/svm_model_2_classes_gpt_chatgpt_detection_tfidf_bg_0.886_F1_score.pkl',
        'untrue_detector':
            'models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score.pkl',
    }
    for attr, path in pickled_models.items():
        # NOTE: pickle.load can execute arbitrary code; these files must come
        # from a trusted source and never from user uploads.
        with open(path, 'rb') as f:
            st.session_state[attr] = pickle.load(f)

    access_token = os.environ['ACCESS_TOKEN']
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token`; kept as-is for the installed version — confirm before upgrading.
    st.session_state.bert = pipeline(
        task="text-classification",
        model=BertForSequenceClassification.from_pretrained(
            "TRACES/private-bert", use_auth_token=access_token, num_labels=2),
        tokenizer=AutoTokenizer.from_pretrained(
            "TRACES/private-bert", use_auth_token=access_token),
    )

    # Only mark the session as loaded once every model is in place.
    st.session_state.loaded = True


def load_content():
    """Return the localized page text parsed from ``resource/page_content.json``.

    The UI indexes the result as ``content[key][lang]``, so each key presumably
    maps to a ``{'bg': ..., 'en': ...}`` dict — confirm against the JSON file.
    """
    with open('resource/page_content.json', encoding='utf8') as json_file:
        return json.load(json_file)


def switch_lang(lang):
    """Button callback: select Bulgarian for ``'bg'``, English for anything else.

    Does nothing if the session has no language set yet.
    """
    if 'lang' in st.session_state:
        st.session_state.lang = 'bg' if lang == 'bg' else 'en'


# Default UI language is Bulgarian.
if 'lang' not in st.session_state:
    st.session_state.lang = 'bg'

# Placeholder results so the result panel can render before the first analysis.
if all(key not in st.session_state
       for key in ('gpt_detector_result', 'untrue_detector_result', 'bert_result')):
    st.session_state.gpt_detector_result = ''
    st.session_state.gpt_detector_probability = [1, 0]
    st.session_state.untrue_detector_result = ''
    st.session_state.untrue_detector_probability = 1
    st.session_state.bert_result = [{'label': '', 'score': 1}]

content = load_content()

if 'loaded' not in st.session_state:
    load_models()
#######################################################################################################################
# Page layout: title, language switch, terms-of-use gate, and the analysis tool.

state = st.session_state
# Language is fixed for the whole run: switch_lang only fires as an on_click
# callback, which Streamlit runs before the script reruns.
ui_lang = state.lang


def _show_verdict(flagged, probability, yes_keys, no_keys):
    """Render one detector's verdict.

    Shows a warning box built from ``yes_keys`` when ``flagged`` is true, and a
    success box built from ``no_keys`` otherwise. Each message is
    ``content[head] + <probability as a percentage, 2 decimals> + content[tail]``.
    """
    head_key, tail_key = yes_keys if flagged else no_keys
    message = (content[head_key][ui_lang]
               + str(round(probability * 100, 2))
               + content[tail_key][ui_lang])
    if flagged:
        st.warning(message, icon="⚠️")
    else:
        st.success(message, icon="✅")


st.title(content['title'][ui_lang])

col_en, col_bg, _ = st.columns([1, 1, 10])
with col_en:
    st.button(label='EN', key='en', on_click=switch_lang, args=['en'])
with col_bg:
    st.button(label='BG', key='bg', on_click=switch_lang, args=['bg'])

if 'agree' not in state:
    state.agree = False

if not state.agree:
    # Terms gate: the tool is shown only after the user accepts the disclaimer.
    st.write(content['disclaimer_title'][ui_lang])
    st.write(content['disclaimer'][ui_lang])
    if st.button(content['disclaimer_agree_text'][ui_lang]):
        state.agree = True
        st.experimental_rerun()
else:
    tab_tool, tab_terms = st.tabs([content['tab_tool'][ui_lang],
                                   content['tab_terms'][ui_lang]])
    with tab_tool:
        # NOTE(review): the second positional argument of st.text_area is the
        # initial *value*, not a placeholder, so the 'text_placeholder' text is
        # real input until the user edits it — confirm this is intended.
        textbox_label = content['textbox_title'][ui_lang]
        textbox_default = content['text_placeholder'][ui_lang]
        user_input = st.text_area(textbox_label, textbox_default).strip('\n')

        if st.button(content['analyze_button'][ui_lang]):
            # GPT-generated-text detector (TF-IDF + SVM).
            gpt_features = state.tfidf_vectorizer_disinformation.transform([user_input])
            state.gpt_detector_result = state.gpt_detector.predict(gpt_features)[0]
            state.gpt_detector_probability = state.gpt_detector.predict_proba(gpt_features)[0]

            # Untrue-information detector (TF-IDF + SVM); keep the confidence
            # of the predicted class only.
            untrue_features = state.tfidf_vectorizer_untrue_inf.transform([user_input])
            state.untrue_detector_result = state.untrue_detector.predict(untrue_features)[0]
            untrue_proba = state.untrue_detector.predict_proba(untrue_features)[0]
            state.untrue_detector_probability = max(untrue_proba[0], untrue_proba[1])

            # BERT disinformation classifier.
            state.bert_result = state.bert(user_input)

        # Results persist in session_state across reruns.
        # NOTE(review): before the first analysis, these show the placeholder
        # defaults (e.g. "100%") — confirm that is the desired UX.
        gpt_flagged = state.gpt_detector_result == 1
        gpt_probabilities = state.gpt_detector_probability
        _show_verdict(
            gpt_flagged,
            gpt_probabilities[1] if gpt_flagged else gpt_probabilities[0],
            ('gpt_getect_yes', 'gpt_yes_proba'),
            ('gpt_getect_no', 'gpt_no_proba'),
        )

        _show_verdict(
            state.untrue_detector_result == 1,
            state.untrue_detector_probability,
            ('untrue_getect_yes', 'untrue_yes_proba'),
            ('untrue_getect_no', 'untrue_no_proba'),
        )

        top_bert = state.bert_result[0]
        _show_verdict(
            top_bert['label'] == 'LABEL_1',
            top_bert['score'],
            ('bert_yes_1', 'bert_yes_2'),
            ('bert_no_1', 'bert_no_2'),
        )

        st.info(content['disinformation_definition'][ui_lang], icon="ℹ️")

    with tab_terms:
        st.write(content['disclaimer'][ui_lang])