Spaces:
Paused
Paused
File size: 11,951 Bytes
76cf4e1 00c9a21 2715575 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 2715575 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 2715575 32bcbdd 00c9a21 32bcbdd 2715575 32bcbdd 2715575 32bcbdd 2715575 32bcbdd 00c9a21 2715575 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 32bcbdd 00c9a21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
# !wget -nc https://raw.githubusercontent.com/baobuiquang/datasets/main/sample.xlsx >& /dev/null
# !pip install gradio==4.21.0 >& /dev/null
# ==============================
# ========== HARDCODE ==========
X_LIST_NAME = "Tên chỉ số"
# ==============================
# ========== PACKAGES ==========
import gradio as gr # gradio==4.21.0
import pandas as pd
import numpy as np
import torch
import time
from transformers import AutoTokenizer, AutoModel
from datetime import datetime
# pd.options.mode.chained_assignment = None # default='warn'
# ===========================
# ========== FILES ==========
FILE_NAME = "data/sample.xlsx"
df_map = pd.read_excel(FILE_NAME, header=None, sheet_name=None)
df_map_sheet_names = pd.ExcelFile(FILE_NAME).sheet_names
# ============================
# ========== MODELS ==========
MODEL_NAME = "baobuiquang/XLM-ROBERTA-ME5-BASE"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# ===============================
# ========== FUNCTIONS ==========
# Text -> Embedding
def text_to_embedding(text):
lower_text = text.lower() # Lowercasing
encoded_input = tokenizer(lower_text, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**encoded_input)
embedding = mean_pooling(model_output, encoded_input['attention_mask'])
return embedding[0]
# List of Texts -> List of Embeddings
def texts_to_embeddings(list_of_texts):
list_of_lower_texts = [t.lower() for t in list_of_texts] # Lowercasing
encoded_input = tokenizer(list_of_lower_texts, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**encoded_input)
list_of_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
return list_of_embeddings
# Mean Pooling
# - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Cosine Similarity between 2 embeddings
def cosine_similarity(a, b):
return np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b))
# Find index of the max similarity when comparing an embedding to a list
def similarity(my_embedding, list_of_embeddings):
list_of_sim = [0] * len(list_of_embeddings)
max_sim = -1.0
max_sim_index = 0
for i in range(len(list_of_embeddings)):
cos_sim = cosine_similarity(my_embedding, list_of_embeddings[i])
list_of_sim[i] = cos_sim
if cos_sim > max_sim:
max_sim = cos_sim
max_sim_index = i
return {"max_index": max_sim_index, "max": max_sim, "list": list_of_sim}
# ===================================
# ========== PREPROCESSING ==========
# preprocessed_df_map ----------------------------------------------------------
# - A list of dataframes (preprocessed), each dataframe contains data from 1 sheet from the XLSX file
preprocessed_df_map = []
for sheet_name in df_map_sheet_names:
# Get sheet data
df = pd.DataFrame(df_map[sheet_name])
# Setup header
header_position = df[df[0] == "#"].index[0]
new_header = []
for e in df.loc[header_position]:
if isinstance(e, datetime):
new_header.append(
f"ngày {e.strftime('%d').lstrip('0')} tháng {e.strftime('%m').lstrip('0')} năm {e.strftime('%Y')} {e.strftime('%d').lstrip('0')}/{e.strftime('%m').lstrip('0')}/{e.strftime('%Y')} {e.strftime('%d')}/{e.strftime('%m')}/{e.strftime('%Y')}"
)
else:
new_header.append(e)
df = df.rename(columns = dict(zip(df.columns, new_header)))
df = df.iloc[header_position+1:]
# # Preprocess column "#" values
# df['#'] = df['#'].replace(to_replace = r'^\d+(\.\d+)?$', value = np.nan, regex=True)
# df['#'] = df['#'].fillna(method = 'ffill')
# df = df.dropna(thresh = df.shape[1] * 0.25, axis = 0) # Keep rows that have at least 25% values are not NaN
# df = df.dropna(thresh = df.shape[1] * 0.25, axis = 1) # Keep cols that have at least 25% values are not NaN
# df = df.rename(columns={'#': 'Nhóm chỉ số'})
# # Move column "#" to the end
# columns = list(df.columns)
# columns.append(columns.pop(0))
# df = df.reindex(columns=columns)
# General Preprocess
df = df.reset_index(drop=True)
df = df.fillna('No data')
df = df.astype(str)
# Return the preprocessed sheet
preprocessed_df_map.append(df)
# ========================================
# ========== FEATURE EXTRACTION ==========
# embeddings_map ---------------------------------------------------------------
# - A list of pre-calculated embeddings (vectors) of x/y axis in the corresponding dataframe in the `preprocessed_df_map`
x_list_embeddings_map = []
y_list_embeddings_map = []
for i in range(len(preprocessed_df_map)):
df = preprocessed_df_map[i]
# HARDCODE
x_list = list(df[X_LIST_NAME])
y_list = list(df.columns)
# Only need to calculate once
x_list_embeddings = texts_to_embeddings(x_list)
y_list_embeddings = texts_to_embeddings(y_list)
# Return the embeddings map
x_list_embeddings_map.append(x_list_embeddings)
y_list_embeddings_map.append(y_list_embeddings)
# ==========================
# ========== MAIN ==========
def chatbot_mechanism(message, history, additional_input_1):
# Clarify namings
question = message
sheet_id = additional_input_1
# Select the right data
df = preprocessed_df_map[sheet_id]
x_list_embeddings = x_list_embeddings_map[sheet_id]
y_list_embeddings = y_list_embeddings_map[sheet_id]
# Find the position of the needed cell
question_embedding = text_to_embedding(question)
x_sim = similarity(question_embedding, x_list_embeddings)
y_sim = similarity(question_embedding, y_list_embeddings)
x_index = x_sim['max_index']
y_index = y_sim['max_index']
x_score = x_sim['max']
y_score = y_sim['max']
x_text = str(df.loc[x_index, 'Tên chỉ số'])
y_text = str(df.columns[y_index])
# Small adjustment for better print
if y_text.count('/') == 4:
y_text = y_text[-10:] # If y_text is preprocessed datetime format, trim it
# Just add some text to warn users
eval_text = ""
eval_text_sub_title = ""
if x_score < 0.85 or y_score < 0.85:
eval_text_sub_title = "Cảnh báo:"
eval_text = "⚠️ Đặc trưng trích xuất không rõ ràng ⚠️"
# Score display
x_score_display = str(round((x_score - 0.8) / (1.0 - 0.8) * 100, 1))
y_score_display = str(round((y_score - 0.8) / (1.0 - 0.8) * 100, 1))
# Cell value
cell_value = df.iloc[x_index, y_index]
# Final print
final_output_message = f"\
<div style='color: gray; font-size: 80%; font-family: courier, monospace;'>\
Kết quả:\
</div>\
<div style='font-weight: bold;'>\
{cell_value}\
</div>\
<div style='color: gray; font-size: 80%; font-family: courier, monospace; margin-top: 6px;'>\
Đặc trưng trích xuất được:\
</div>\
• {x_text}<br>\
• {y_text}<br>\
<div style='color: gray; font-size: 80%; font-family: courier, monospace; margin-top: 6px;'>\
Đánh giá:\
</div>\
Độ tương quan: [x={x_score_display}%, y={y_score_display}%]<br>\
<div style='color: gray; font-size: 80%; font-family: courier, monospace; margin-top: 6px;'>\
{eval_text_sub_title}\
</div>\
<div style='color: red; font-weight: bold;'>\
{eval_text}\
</div>\
"
return final_output_message
# for i in range(len(final_output_message)):
# time.sleep(0.1)
# yield final_output_message[: i+1]
textbox_input = gr.Textbox(
label = "Câu hỏi",
placeholder = "Hãy đặt một câu hỏi",
container = False,
scale = 7,
)
with gr.Blocks(
title = "CHATBOT",
theme = gr.themes.Base(
primary_hue = "stone",
),
css = '\
footer { visibility: hidden; display: none; }\
[data-testid="block-label"] { visibility: hidden; display: none; } \
',
# .gradio-container { max-width: 1000px !important; }\
) as app:
with gr.Row():
with gr.Column(scale=1):
additional_input_1 = gr.Radio(
choices = df_map_sheet_names,
value = "Tư pháp", # Default
type = "index", # Return index instead of value
label = "Dữ liệu",
)
gr.Markdown(
"""
File dữ liệu: [`sample.xlsx`](https://view.officeapps.live.com/op/view.aspx?src=https%3A%2F%2Fraw.githubusercontent.com%2Fbaobuiquang%2Fdatasets%2Fmain%2Fsample.xlsx&wdOrigin=BROWSELINK)
"""
)
with gr.Column(scale=2):
gr.ChatInterface(
fn = chatbot_mechanism,
chatbot = gr.Chatbot(
bubble_full_width = False,
render = False,
height = 450,
),
textbox = textbox_input,
additional_inputs = [
additional_input_1,
],
retry_btn = None,
undo_btn = "Xoá lệnh chat gần nhất",
clear_btn = "Xoá toàn bộ đoạn chat",
submit_btn = "Gửi",
stop_btn = "Dừng",
autofocus = True,
)
with gr.Column(scale=1):
gr.Examples(
label = 'Câu hỏi ví dụ (Dữ liệu "Tư pháp")',
examples_per_page = 100,
examples = [
"Tổng số hồ sơ chứng thực bản sao từ bản chính tới ngày 10/1/2024 là bao nhiêu?", # 100
"15 tháng 1 năm 2024, hãy tìm dữ liệu tổng số hồ sơ chứng thực hợp đồng, giao dịch.", # 219
"Tổng số hồ sơ chứng thực chữ ký vào ngày 12 tháng 1 năm 2024 là bao nhiêu?", # 165
"Có bao nhiêu HS chứng thực việc sửa đổi, bổ sung, hủy bỏ ngày 14/01/2024?", # 194
"Tính đến ngày 11 tháng 1, 2024, số hồ sơ đăng ký kết hôn là bao nhiêu?", # 177
],
inputs = [textbox_input],
)
gr.Markdown(
"""
Câu trả lời đúng cho các ví dụ: 100, 219, 165, 194, 177
"""
)
gr.Examples(
label = 'Câu hỏi ví dụ (Dữ liệu "Công an huyện")',
examples_per_page = 100,
examples = [
"Số vụ phạm tội công nghệ cao ngày 19 tháng 3 năm 2024 là bao nhiêu?", # 121
"Tới ngày 20/3/2024, có mấy vụ án đặc biệt nghiêm trọng?", # 208
"Ngày 22 tháng 3 năm 2024, có bao nhiêu người chết do TNGT", # 273
"Có bao nhiêu vụ cháy cho đến ngày 24/03/2024?", # 437
"Tìm thông tin số vụ tai nạn giao thông tại ngày 18/3 năm 2024.", # 104
],
inputs = [textbox_input],
)
gr.Markdown(
"""
Câu trả lời đúng cho các ví dụ: 121, 208, 273, 437, 104
"""
)
app.launch(debug = False, share = False) |