Spaces:
Running
Running
SonFox2920
commited on
Commit
•
eee5090
1
Parent(s):
eddc94e
Upload 6 files
Browse files- Mbert.py +57 -0
- Model/classifier.pt +3 -0
- config.py +1 -0
- predictor.py +249 -0
- requirements.txt +20 -0
- utilities.py +103 -0
Mbert.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
|
5 |
+
class SentencePairDataset(Dataset):
|
6 |
+
def __init__(self, sentence_pairs, labels, tokenizer, max_length):
|
7 |
+
self.sentence_pairs = sentence_pairs
|
8 |
+
self.labels = labels
|
9 |
+
self.tokenizer = tokenizer
|
10 |
+
self.max_length = max_length
|
11 |
+
|
12 |
+
def __len__(self):
|
13 |
+
return len(self.sentence_pairs)
|
14 |
+
|
15 |
+
def __getitem__(self, idx):
|
16 |
+
sentence1, sentence2 = self.sentence_pairs[idx]
|
17 |
+
label = self.labels[idx]
|
18 |
+
encoding = self.tokenizer.encode_plus(
|
19 |
+
sentence1,
|
20 |
+
text_pair=sentence2,
|
21 |
+
add_special_tokens=True,
|
22 |
+
max_length=self.max_length,
|
23 |
+
return_token_type_ids=False,
|
24 |
+
padding="max_length",
|
25 |
+
return_attention_mask=True,
|
26 |
+
return_tensors="pt",
|
27 |
+
truncation=True,
|
28 |
+
)
|
29 |
+
return {
|
30 |
+
"input_ids": encoding["input_ids"].flatten(),
|
31 |
+
"attention_mask": encoding["attention_mask"].flatten(),
|
32 |
+
"label": torch.tensor(label, dtype=torch.long),
|
33 |
+
}
|
34 |
+
|
35 |
+
class MBERTClassifier(nn.Module):
|
36 |
+
def __init__(self, mbert, num_classes):
|
37 |
+
super(MBERTClassifier, self).__init__()
|
38 |
+
self.mbert = mbert
|
39 |
+
self.layer_norm = nn.LayerNorm(self.mbert.config.hidden_size)
|
40 |
+
self.dropout = nn.Dropout(0.2)
|
41 |
+
self.batch_norm = nn.BatchNorm1d(self.mbert.config.hidden_size)
|
42 |
+
self.linear = nn.LazyLinear(num_classes)
|
43 |
+
self.activation = nn.ELU()
|
44 |
+
|
45 |
+
def forward(self, input_ids, attention_mask):
|
46 |
+
_, pooled_output = self.mbert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
|
47 |
+
norm_output = self.layer_norm(pooled_output)
|
48 |
+
batch_norm_output = self.batch_norm(norm_output)
|
49 |
+
logits = self.linear(batch_norm_output)
|
50 |
+
activated_output = self.activation(logits)
|
51 |
+
dropout_output = self.dropout(activated_output)
|
52 |
+
return dropout_output
|
53 |
+
|
54 |
+
def predict_proba(self, input_ids, attention_mask):
|
55 |
+
logits = self.forward(input_ids, attention_mask)
|
56 |
+
probabilities = torch.softmax(logits, dim=-1)
|
57 |
+
return probabilities
|
Model/classifier.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:731231da2eeb4ed18488952d6f79a43aae9ce60da24f6ec3922a06ef1e7eb556
|
3 |
+
size 711526350
|
config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
hf_token = "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"
|
predictor.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from utilities import predict
|
4 |
+
|
5 |
+
st.set_page_config(layout="wide")
|
6 |
+
|
7 |
+
def result_form(result, user_label):
|
8 |
+
if 'error' in result:
|
9 |
+
st.error(result['error'])
|
10 |
+
else:
|
11 |
+
st.subheader('Label probabilities:')
|
12 |
+
labels = ['SUPPORTED', 'REFUTED', 'NEI']
|
13 |
+
probabilities = {lbl: result['probabilities'].get(lbl, 0) for lbl in labels}
|
14 |
+
|
15 |
+
df = pd.DataFrame({label: [probabilities[label]] for label in labels})
|
16 |
+
|
17 |
+
def apply_background(val, label):
|
18 |
+
color = ''
|
19 |
+
if label == 'NEI':
|
20 |
+
color = '#FFD700'
|
21 |
+
elif label == 'REFUTED':
|
22 |
+
color = '#DC143C'
|
23 |
+
else: # Supported
|
24 |
+
color = '#7FFF00'
|
25 |
+
return f'background-color: {color}; color: black'
|
26 |
+
|
27 |
+
df_styled = df.style.apply(lambda x: [apply_background(x[name], name) for name in df.columns], axis=1)
|
28 |
+
df_styled = df_styled.format("{:.2%}")
|
29 |
+
|
30 |
+
st.dataframe(df_styled, hide_index=True, use_container_width=True)
|
31 |
+
|
32 |
+
predicted_label = max(probabilities, key=probabilities.get)
|
33 |
+
return probabilities[user_label] < 0.35 or predicted_label != user_label
|
34 |
+
|
35 |
+
def create_expander_with_check_button(label, title, context, predict_func):
|
36 |
+
claim_key = f"{label}_input"
|
37 |
+
evidence_key = f"{label}_evidence_selected"
|
38 |
+
label_e_ops = f"{label}_options"
|
39 |
+
evidence_input_key = f"{label}_evidence_input"
|
40 |
+
|
41 |
+
if label_e_ops not in st.session_state:
|
42 |
+
st.session_state[label_e_ops] = []
|
43 |
+
|
44 |
+
annotated_data = st.session_state['annotated_data']
|
45 |
+
with st.expander(label, expanded=True):
|
46 |
+
claim = st.text_input(f'Claim {label.upper()}', max_chars=500, key=claim_key)
|
47 |
+
if claim:
|
48 |
+
if not annotated_data[((annotated_data['Claim'] == claim) & (annotated_data['Label'] == label) & (annotated_data['Title'] == title))].empty:
|
49 |
+
st.warning(f"This claim with label '{label}' and title '{title}' already exists.")
|
50 |
+
else:
|
51 |
+
result = predict_func(context, claim)
|
52 |
+
if result_form(result, label):
|
53 |
+
if label != 'NEI': # NEI does not require evidence
|
54 |
+
# Display available sentences as options
|
55 |
+
sentences = context.split('.')
|
56 |
+
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
|
57 |
+
st.multiselect(f"Selected evidence for {label}", sentences, default=st.session_state[label_e_ops], key=evidence_key)
|
58 |
+
else:
|
59 |
+
st.warning(f"The predicted probability for label '{label}' is too high or the predicted label is different. Please modify the claim.")
|
60 |
+
else:
|
61 |
+
st.warning("Please enter a claim.")
|
62 |
+
|
63 |
+
|
64 |
+
if 'annotated_data' not in st.session_state:
|
65 |
+
st.session_state['annotated_data'] = pd.DataFrame(columns=['Context', 'Claim', 'Label', 'Evidence', 'Title', 'Link'])
|
66 |
+
|
67 |
+
annotated_data = st.session_state['annotated_data']
|
68 |
+
|
69 |
+
def save_data(context, default_title, default_link):
|
70 |
+
annotated_data = st.session_state['annotated_data']
|
71 |
+
error = 'success'
|
72 |
+
|
73 |
+
for label in ['NEI', 'REFUTED', 'SUPPORTED']:
|
74 |
+
claim_key = f"{label}_input"
|
75 |
+
evidence_key = f"{label}_evidence_selected"
|
76 |
+
|
77 |
+
if st.session_state.get(claim_key, ''):
|
78 |
+
claim = st.session_state[claim_key]
|
79 |
+
evidence = st.session_state.get(evidence_key, [])
|
80 |
+
if label == 'NEI':
|
81 |
+
evidence = [] # No evidence required for NEI
|
82 |
+
if not annotated_data[((annotated_data['Claim'] == claim) & (annotated_data['Label'] == label) & (annotated_data['Title'] == default_title))].empty:
|
83 |
+
error = 'duplicate'
|
84 |
+
else:
|
85 |
+
annotated_data.loc[len(annotated_data)] = [context, claim, label, evidence, default_title, default_link]
|
86 |
+
|
87 |
+
st.session_state['annotated_data'] = annotated_data
|
88 |
+
return error
|
89 |
+
|
90 |
+
def enough_claims_entered(title):
|
91 |
+
annotated_data = st.session_state['annotated_data']
|
92 |
+
nei_claims = annotated_data[(annotated_data['Label'] == 'NEI') & (annotated_data['Title'] == title)].shape[0]
|
93 |
+
refuted_claims = annotated_data[(annotated_data['Label'] == 'REFUTED') & (annotated_data['Title'] == title)].shape[0]
|
94 |
+
supported_claims = annotated_data[(annotated_data['Label'] == 'SUPPORTED') & (annotated_data['Title'] == title)].shape[0]
|
95 |
+
|
96 |
+
return nei_claims >= 2 and refuted_claims >= 2 and supported_claims >= 2
|
97 |
+
|
98 |
+
def predictor_app():
|
99 |
+
tab0, tab1, tab2 = st.tabs(["Mission", "Annotate", "Save"])
|
100 |
+
st.sidebar.title("Dataset Upload")
|
101 |
+
uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=["csv"])
|
102 |
+
with tab0:
|
103 |
+
c2 = st.container(border=True)
|
104 |
+
with c2:
|
105 |
+
st.title("Nhiệm vụ")
|
106 |
+
st.write("""
|
107 |
+
Nhiệm vụ của bạn là tạo các câu nhận định cho các nhãn sau: <span style='color:#7FFF00'>SUPPORTED</span> (Được hỗ trợ), <span style='color:#DC143C'>REFUTED</span> (Bị phủ nhận) hoặc <span style='color:#FFD700'>NEI</span> (Không đủ thông tin) dựa trên đoạn văn bản được cung cấp trước đó. Dưới đây là các bước để thực hiện nhiệm vụ này:
|
108 |
+
|
109 |
+
1. **Đọc đoạn văn bản (context)**: hiểu nội dung, thông tin của đoạn văn bản được cung cấp.
|
110 |
+
|
111 |
+
2. **Nhập câu nhận định**: dựa trên thông tin, nội dung đó, bạn hãy viết câu nhận định cho đoạn văn đó.
|
112 |
+
|
113 |
+
3. **Phân loại câu nhận định**: sau khi đã viết xong câu nhận định, bạn hãy sắp xếp nó vào một trong ba nhãn sau:
|
114 |
+
- <span style='color:#7FFF00'>SUPPORTED</span> (được hỗ trợ): đây là nhãn mà khi câu nhận định của bạn là chính xác theo những thông tin nội dung của đoạn văn bản (context) cung cấp.
|
115 |
+
- <span style='color:#DC143C'>REFUTED</span> (bị bác bỏ): ngược lại với “<span style='color:#7FFF00'>SUPPORTED</span>”, đây là nhãn mà khi câu nhận định của bạn là sai so với những thông tin nội dung của đoạn văn bản (context) đưa ra.
|
116 |
+
- <span style='color:#FFD700'>NEI</span> (không đủ thông tin): khi thông tin mà câu nhận định của bạn đưa ra chưa thể xác định được đúng hoặc sai dựa trên thông tin của đoạn văn bản (context) cung cấp; hoặc ít nhất một thông tin mà bạn đưa ra trong câu nhận định không xuất hiện ở đoạn văn bản (context).
|
117 |
+
|
118 |
+
4. **Chọn bằng chứng (Evidence)**: đối với hai nhãn <span style='color:#7FFF00'>SUPPORTED</span> & <span style='color:#DC143C'>REFUTED</span>, các bạn sẽ chọn bằng chứng (evidence) cho câu nhận định. Nghĩa là các bạn sẽ chọn những thông tin trong đoạn văn bản (context) để dựa theo đó để chứng minh rằng câu nhận định của bạn là đúng (đối với “<span style='color:#7FFF00'>SUPPORTED</span>”) hoặc sai (đối với “<span style='color:#DC143C'>REFUTED</span>”). Các bạn chỉ chọn những thông tin cần thiết (không chọn hết cả câu hoặc cả đoạn văn).
|
119 |
+
|
120 |
+
5. **Lưu dữ liệu**: Sau khi đã nhập đủ 2 câu (mỗi nhãn một câu), bạn nhấn vào nút “Save” bên dưới để lưu lại các câu đó. Sau khi lưu hoàn tất, thông báo sẽ hiện và các câu đã viết trước đó sẽ được xóa để bạn có thể nhập câu mới.
|
121 |
+
|
122 |
+
6. **Di chuyển đến đoạn văn bản (context) khác**: bạn có thể di chuyển qua lại giữa các context nhưng chỉ khi bạn đã tạo tối thiểu 6 câu nhận định (mỗi nhãn tối thiểu 2 câu).
|
123 |
+
|
124 |
+
Xem chi tiết hướng dẫn cách đặt câu nhận định [tại đây](https://docs.google.com/document/d/121GHPAOFa4_fhmXDGJFYCrmsStcXYc7H/edit).
|
125 |
+
|
126 |
+
Lấy các đoạn văn bản (context): [tại đây](https://drive.google.com/drive/folders/1bbW7qiglBZHvGs5oNF-s_eac09t5oWOW).
|
127 |
+
""", unsafe_allow_html=True)
|
128 |
+
|
129 |
+
if uploaded_file is None:
|
130 |
+
st.sidebar.warning("Please upload a CSV file.")
|
131 |
+
else:
|
132 |
+
df = pd.read_csv(uploaded_file)
|
133 |
+
require_columns = ['Summary', 'ID', 'Title', 'URL']
|
134 |
+
|
135 |
+
if not set(require_columns).issubset(df.columns):
|
136 |
+
st.error("Error: Upload Dataset is missing required columns.")
|
137 |
+
st.stop()
|
138 |
+
else:
|
139 |
+
max_index = len(df) - 1
|
140 |
+
current_index = st.session_state.get("current_index", 0)
|
141 |
+
current_row = df.iloc[current_index]
|
142 |
+
|
143 |
+
default_context = current_row['Summary']
|
144 |
+
default_ID = current_row['ID']
|
145 |
+
default_title = current_row['Title']
|
146 |
+
default_link = current_row['URL']
|
147 |
+
|
148 |
+
with tab1:
|
149 |
+
if uploaded_file is None:
|
150 |
+
st.error("Dataset not found")
|
151 |
+
else:
|
152 |
+
st.title("Fact Checking annotation app")
|
153 |
+
c1 = st.container(border=True)
|
154 |
+
with c1:
|
155 |
+
ten_file, id_cau, chu_de, link = st.columns(4)
|
156 |
+
with ten_file:
|
157 |
+
st.text_input("Tên File:", value=uploaded_file.name, disabled=True)
|
158 |
+
with id_cau:
|
159 |
+
st.text_input("ID Context: ", value=default_ID, disabled=True)
|
160 |
+
with chu_de:
|
161 |
+
st.text_input("Chủ đề:", value=default_title, disabled=True)
|
162 |
+
with link:
|
163 |
+
st.text_input("Link:", value=default_link, disabled=True)
|
164 |
+
|
165 |
+
c3 = st.container(border=True)
|
166 |
+
with c3:
|
167 |
+
left_column, right_column = st.columns([0.45, 0.55])
|
168 |
+
with left_column:
|
169 |
+
st.title("Context")
|
170 |
+
c3_1 = st.container(border=True, height=770)
|
171 |
+
with c3_1:
|
172 |
+
st.write(f'{default_context}')
|
173 |
+
|
174 |
+
with right_column:
|
175 |
+
st.title("Claim")
|
176 |
+
c3_2 = st.container(border=True, height=650)
|
177 |
+
with c3_2:
|
178 |
+
create_expander_with_check_button("SUPPORTED", default_title, default_context, predict)
|
179 |
+
create_expander_with_check_button("REFUTED", default_title, default_context, predict)
|
180 |
+
create_expander_with_check_button("NEI", default_title, default_context, predict)
|
181 |
+
|
182 |
+
# Update session state to track if claims and evidence are entered
|
183 |
+
st.session_state["NEI_claim_entered"] = bool(st.session_state.get("NEI_input", ''))
|
184 |
+
st.session_state["REFUTED_claim_entered"] = bool(st.session_state.get("REFUTED_input", ''))
|
185 |
+
st.session_state["SUPPORTED_claim_entered"] = bool(st.session_state.get("SUPPORTED_input", ''))
|
186 |
+
st.session_state["REFUTED_evidence_entered"] = bool(st.session_state.get("REFUTED_evidence_selected", []))
|
187 |
+
st.session_state["SUPPORTED_evidence_entered"] = bool(st.session_state.get("SUPPORTED_evidence_selected", []))
|
188 |
+
|
189 |
+
all_claims_entered = st.session_state["NEI_claim_entered"] and \
|
190 |
+
st.session_state["REFUTED_claim_entered"] and \
|
191 |
+
st.session_state["SUPPORTED_claim_entered"]
|
192 |
+
|
193 |
+
all_evidence_selected = st.session_state["REFUTED_evidence_entered"] and \
|
194 |
+
st.session_state["SUPPORTED_evidence_entered"]
|
195 |
+
|
196 |
+
previous, next_, save, close = st.columns(4)
|
197 |
+
error = ''
|
198 |
+
with previous:
|
199 |
+
pr = st.button("Previous")
|
200 |
+
if pr:
|
201 |
+
if enough_claims_entered(default_title):
|
202 |
+
if current_index > 0:
|
203 |
+
st.session_state["current_index"] = current_index - 1
|
204 |
+
st.experimental_rerun()
|
205 |
+
else:
|
206 |
+
st.session_state["current_index"] = max_index
|
207 |
+
st.experimental_rerun()
|
208 |
+
else:
|
209 |
+
error = 'n_enough'
|
210 |
+
|
211 |
+
with next_:
|
212 |
+
next_b = st.button("Next")
|
213 |
+
if next_b:
|
214 |
+
if enough_claims_entered(default_title):
|
215 |
+
if current_index < max_index:
|
216 |
+
st.session_state["current_index"] = current_index + 1
|
217 |
+
st.experimental_rerun()
|
218 |
+
else:
|
219 |
+
st.session_state["current_index"] = 0
|
220 |
+
st.experimental_rerun()
|
221 |
+
else:
|
222 |
+
error = 'n_enough'
|
223 |
+
|
224 |
+
with save:
|
225 |
+
save_button = st.button("Save")
|
226 |
+
if save_button:
|
227 |
+
if all_claims_entered and all_evidence_selected:
|
228 |
+
error = save_data(default_context, default_title, default_link)
|
229 |
+
else:
|
230 |
+
error = 'save_fail'
|
231 |
+
|
232 |
+
if error == 'success':
|
233 |
+
st.success("Data saved successfully.")
|
234 |
+
elif error == 'duplicate':
|
235 |
+
st.warning(f"Maybe one of these claims with title '{default_title}' already exists.")
|
236 |
+
elif error == 'n_enough':
|
237 |
+
st.warning("Enter at least two claims for each label for this title before navigating.")
|
238 |
+
else:
|
239 |
+
st.warning("Please enter all claims and select all evidence before saving.")
|
240 |
+
|
241 |
+
with tab2:
|
242 |
+
st.title("Saved Annotations")
|
243 |
+
if annotated_data.empty:
|
244 |
+
st.info("No annotations saved yet.")
|
245 |
+
else:
|
246 |
+
st.dataframe(annotated_data)
|
247 |
+
|
248 |
+
if __name__ == '__main__':
|
249 |
+
predictor_app()
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
pandas
|
3 |
+
numpy
|
4 |
+
scikit-learn
|
5 |
+
nltk
|
6 |
+
pyvi
|
7 |
+
torch
|
8 |
+
transformers
|
9 |
+
Flask
|
10 |
+
requests
|
11 |
+
# streamlit==1.33.0
|
12 |
+
# pandas==2.0.3
|
13 |
+
# numpy==1.23.5
|
14 |
+
# scikit-learn
|
15 |
+
# nltk==3.8.1
|
16 |
+
# pyvi==0.1.1
|
17 |
+
# torch==2.2.2
|
18 |
+
# transformers==4.39.3
|
19 |
+
# Flask==3.0.3
|
20 |
+
# requests==2.31.0
|
utilities.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
warnings.filterwarnings('ignore')
|
3 |
+
import logging
|
4 |
+
logging.disable(logging.WARNING)
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from transformers import AutoModel, AutoTokenizer
|
10 |
+
from Mbert import MBERTClassifier, SentencePairDataset
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
#from evidence_retrieval import evidence_top_n, similarities
|
13 |
+
from config import hf_token
|
14 |
+
import pandas as pd
|
15 |
+
|
16 |
+
# Thiết lập seed cố định
|
17 |
+
def set_seed(seed):
|
18 |
+
random.seed(seed)
|
19 |
+
np.random.seed(seed)
|
20 |
+
torch.manual_seed(seed)
|
21 |
+
if torch.cuda.is_available():
|
22 |
+
torch.cuda.manual_seed(seed)
|
23 |
+
torch.cuda.manual_seed_all(seed)
|
24 |
+
torch.backends.cudnn.deterministic = True
|
25 |
+
torch.backends.cudnn.benchmark = False
|
26 |
+
|
27 |
+
# Gọi hàm set_seed với seed cố định, ví dụ: 42
|
28 |
+
set_seed(42)
|
29 |
+
device = torch.device("cpu")
|
30 |
+
modelname = "SonFox2920/MBert_FC"
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(modelname, token=hf_token)
|
32 |
+
mbert = AutoModel.from_pretrained(modelname, token=hf_token).to(device)
|
33 |
+
model = MBERTClassifier(mbert, num_classes=3).to(device)
|
34 |
+
model.load_state_dict(torch.load('Model/classifier.pt', map_location=device))
|
35 |
+
|
36 |
+
def predict(context, claim):
|
37 |
+
data = pd.DataFrame([{'context': context, 'claim': claim}])
|
38 |
+
|
39 |
+
# list_evidence_top5 = []
|
40 |
+
# list_evidence_top1 = []
|
41 |
+
|
42 |
+
# for i in range(len(data)):
|
43 |
+
# statement = data.claim[i]
|
44 |
+
# context = data.context[i]
|
45 |
+
# evidence_top5, top5_consine = evidence_top_n(context, statement)
|
46 |
+
# evidence_top1, top1_consine, rank_5 = similarities(evidence_top5, statement, top5_consine)
|
47 |
+
# evidence_top1 = "".join(evidence_top1)
|
48 |
+
# list_evidence_top5.append(rank_5)
|
49 |
+
# list_evidence_top1.append(evidence_top1)
|
50 |
+
|
51 |
+
# data['evidence_top5'] = list_evidence_top5
|
52 |
+
# data['evidence'] = list_evidence_top1
|
53 |
+
|
54 |
+
X1_pub_test = data['claim']
|
55 |
+
X2_pub_test = data['context']
|
56 |
+
X_pub_test = [(X1_pub_test, X2_pub_test) for (X1_pub_test, X2_pub_test) in zip(X1_pub_test, X2_pub_test)]
|
57 |
+
y_pub_test = [1]
|
58 |
+
|
59 |
+
test_dataset = SentencePairDataset(X_pub_test, y_pub_test, tokenizer, 256)
|
60 |
+
test_loader_pub = DataLoader(test_dataset, batch_size=1)
|
61 |
+
|
62 |
+
model.eval()
|
63 |
+
predictions = []
|
64 |
+
probabilities = []
|
65 |
+
|
66 |
+
for batch in test_loader_pub:
|
67 |
+
input_ids = batch["input_ids"].to(device)
|
68 |
+
attention_mask = batch["attention_mask"].to(device)
|
69 |
+
with torch.no_grad():
|
70 |
+
outputs = model(input_ids, attention_mask)
|
71 |
+
probs = torch.nn.functional.softmax(outputs, dim=1)
|
72 |
+
predicted = torch.argmax(outputs, dim=1)
|
73 |
+
predictions.extend(predicted.cpu().numpy().tolist())
|
74 |
+
probabilities.extend(probs.cpu().numpy().tolist())
|
75 |
+
|
76 |
+
data['verdict'] = predictions
|
77 |
+
data['verdict'] = data['verdict'].replace(0, "SUPPORTED")
|
78 |
+
data['verdict'] = data['verdict'].replace(1, "REFUTED")
|
79 |
+
data['verdict'] = data['verdict'].replace(2, "NEI")
|
80 |
+
|
81 |
+
result = {
|
82 |
+
'verdict': data['verdict'][0],
|
83 |
+
'probabilities': {
|
84 |
+
'SUPPORTED': probabilities[0][0],
|
85 |
+
'REFUTED': probabilities[0][1],
|
86 |
+
'NEI': probabilities[0][2]
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
return result
|
91 |
+
|
92 |
+
# # Set default context and claim
|
93 |
+
# context = "Trái Đất là hành tinh duy nhất trong Hệ Mặt Trời được biết đến là nơi có sự sống tồn tại. Nó là hành tinh lớn thứ ba trong hệ này về kích thước và khối lượng. Trái Đất hình cầu với bề mặt gồm nước và đất liền, được bao phủ bởi lớp khí quyển. Khí quyển của Trái Đất chủ yếu bao gồm nitơ và oxy, cùng với các khí nhà kính như hơi nước và carbon dioxide. Trái Đất quay quanh Mặt Trời theo một quỹ đạo hình ellip, hoàn thành một vòng quay trong khoảng 365 ngày, gây ra sự luân phiên của các mùa."
|
94 |
+
# claim = "Trái Đất là hành tinh duy nhất trong Hệ Mặt Trời được biết đến là nơi có sự sống tồn tại."
|
95 |
+
|
96 |
+
# verdict, probabilities = predict(context, claim)
|
97 |
+
|
98 |
+
# print(f"Verdict: {verdict}")
|
99 |
+
|
100 |
+
# # Display percentages with colors
|
101 |
+
# labels = ["SUPPORTED", "REFUTED", "NEI"]
|
102 |
+
# for i, label in enumerate(labels):
|
103 |
+
# print(f'{label}: {probabilities[0][i]*100:.2f}%')
|