TRACES committed on
Commit
a17012f
1 Parent(s): 9ec462e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -36
main.py CHANGED
@@ -11,21 +11,15 @@ from sklearn.feature_extraction.text import TfidfVectorizer
11
  def load_models():
12
  st.session_state.loaded = True
13
 
14
- # with open('models/tfidf_vectorizer_svm_model_2_classes_gpt_chatgpt_detection_tfidf_bg_0.886_F1_score.pkl', 'rb') as f:
15
- # st.session_state.tfidf_vectorizer_disinformation = pickle.load(f)
16
-
17
  with open('models/tfidf_vectorizer_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
18
  st.session_state.tfidf_vectorizer_untrue_inf = pickle.load(f)
19
 
20
- # with open('models/svm_model_2_classes_gpt_chatgpt_detection_tfidf_bg_0.886_F1_score.pkl', 'rb') as f:
21
- # st.session_state.gpt_detector = pickle.load(f)
22
-
23
  with open('models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
24
  st.session_state.untrue_detector = pickle.load(f)
25
 
26
  st.session_state.bert_disinfo = pipeline(task="text-classification",
27
- model=BertForSequenceClassification.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN'], num_labels=2),
28
- tokenizer=AutoTokenizer.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN']))
29
  st.session_state.bert_gpt = pipeline(task="text-classification",
30
  model=BertForSequenceClassification.from_pretrained("usmiva/bert-deepfake-bg", num_labels=2),
31
  tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-deepfake-bg"))
@@ -52,16 +46,13 @@ if all([
52
  'untrue_detector_result' not in st.session_state,
53
  'bert_disinfo_result' not in st.session_state
54
  ]):
55
- # st.session_state.gpt_detector_result = ''
56
- # st.session_state.gpt_detector_probability = [1, 0]
57
-
58
 
59
  st.session_state.untrue_detector_result = ''
60
  st.session_state.untrue_detector_probability = 1
61
 
62
  st.session_state.bert_disinfo_result = [{'label': '', 'score': 1}]
63
 
64
- st.session_state.bert_gpt_result = [{'label': '', 'score': 1}]
65
 
66
  content = load_content()
67
  if 'loaded' not in st.session_state:
@@ -98,10 +89,7 @@ if st.session_state.agree:
98
  content['text_placeholder'][st.session_state.lang]).strip('\n')
99
 
100
  if st.button(content['analyze_button'][st.session_state.lang]):
101
- # user_tfidf_disinformation = st.session_state.tfidf_vectorizer_disinformation.transform([user_input])
102
- # st.session_state.gpt_detector_result = st.session_state.gpt_detector.predict(user_tfidf_disinformation)[0]
103
- # st.session_state.gpt_detector_probability = st.session_state.gpt_detector.predict_proba(user_tfidf_disinformation)[0]
104
-
105
 
106
  user_tfidf_untrue_inf = st.session_state.tfidf_vectorizer_untrue_inf.transform([user_input])
107
  st.session_state.untrue_detector_result = st.session_state.untrue_detector.predict(user_tfidf_untrue_inf)[0]
@@ -110,19 +98,16 @@ if st.session_state.agree:
110
 
111
  st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(user_input)
112
 
113
- st.session_state.bert_gpt_result = st.session_state.bert_gpt(user_input)
114
-
115
 
116
-
117
- # if st.session_state.gpt_detector_result == 1:
118
- # st.warning(content['gpt_getect_yes'][st.session_state.lang] +
119
- # str(round(st.session_state.gpt_detector_probability[1] * 100, 2)) +
120
- # content['gpt_yes_proba'][st.session_state.lang], icon="⚠️")
121
- # else:
122
- # st.success(content['gpt_getect_no'][st.session_state.lang] +
123
- # str(round(st.session_state.gpt_detector_probability[0] * 100, 2)) +
124
- # content['gpt_no_proba'][st.session_state.lang], icon="✅")
125
 
 
 
 
 
 
 
 
 
126
 
127
  if st.session_state.untrue_detector_result == 0:
128
  st.warning(content['untrue_getect_yes'][st.session_state.lang] +
@@ -142,15 +127,6 @@ if st.session_state.agree:
142
  str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
143
  content['bert_no_2'][st.session_state.lang], icon="✅")
144
 
145
- if st.session_state.bert_gpt_result[0]['label'] == 'LABEL_1':
146
- st.warning(content['bert_gpt'][st.session_state.lang] +
147
- str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
148
- content['bert_gpt_prob'][st.session_state.lang], icon = "⚠️")
149
- else:
150
- st.success(content['bert_human'][st.session_state.lang] +
151
- str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
152
- content['bert_human_prob'][st.session_state.lang], icon="✅")
153
-
154
 
155
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
156
 
 
11
  def load_models():
12
  st.session_state.loaded = True
13
 
 
 
 
14
  with open('models/tfidf_vectorizer_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
15
  st.session_state.tfidf_vectorizer_untrue_inf = pickle.load(f)
16
 
 
 
 
17
  with open('models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
18
  st.session_state.untrue_detector = pickle.load(f)
19
 
20
  st.session_state.bert_disinfo = pipeline(task="text-classification",
21
+ model=BertForSequenceClassification.from_pretrained("usmiva/bert-desinform-bg", num_labels=2),
22
+ tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-desinform-bg"))
23
  st.session_state.bert_gpt = pipeline(task="text-classification",
24
  model=BertForSequenceClassification.from_pretrained("usmiva/bert-deepfake-bg", num_labels=2),
25
  tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-deepfake-bg"))
 
46
  'untrue_detector_result' not in st.session_state,
47
  'bert_disinfo_result' not in st.session_state
48
  ]):
49
+ st.session_state.bert_gpt_result = [{'label': '', 'score': 1}]
 
 
50
 
51
  st.session_state.untrue_detector_result = ''
52
  st.session_state.untrue_detector_probability = 1
53
 
54
  st.session_state.bert_disinfo_result = [{'label': '', 'score': 1}]
55
 
 
56
 
57
  content = load_content()
58
  if 'loaded' not in st.session_state:
 
89
  content['text_placeholder'][st.session_state.lang]).strip('\n')
90
 
91
  if st.button(content['analyze_button'][st.session_state.lang]):
92
+ st.session_state.bert_gpt_result = st.session_state.bert_gpt(user_input)
 
 
 
93
 
94
  user_tfidf_untrue_inf = st.session_state.tfidf_vectorizer_untrue_inf.transform([user_input])
95
  st.session_state.untrue_detector_result = st.session_state.untrue_detector.predict(user_tfidf_untrue_inf)[0]
 
98
 
99
  st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(user_input)
100
 
 
 
101
 
 
 
 
 
 
 
 
 
 
102
 
103
+ if st.session_state.bert_gpt_result[0]['label'] == 'LABEL_1':
104
+ st.warning(content['bert_gpt'][st.session_state.lang] +
105
+ str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
106
+ content['bert_gpt_prob'][st.session_state.lang], icon = "⚠️")
107
+ else:
108
+ st.success(content['bert_human'][st.session_state.lang] +
109
+ str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
110
+ content['bert_human_prob'][st.session_state.lang], icon="✅")
111
 
112
  if st.session_state.untrue_detector_result == 0:
113
  st.warning(content['untrue_getect_yes'][st.session_state.lang] +
 
127
  str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
128
  content['bert_no_2'][st.session_state.lang], icon="✅")
129
 
 
 
 
 
 
 
 
 
 
130
 
131
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
132