# Spaces: Sleeping
import time | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
import torch | |
import faiss | |
from sklearn.preprocessing import normalize | |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
from sentence_transformers import SentenceTransformer, util | |
from pythainlp import Tokenizer | |
import pickle | |
import re | |
from pythainlp.tokenize import sent_tokenize | |
from unstructured.partition.html import partition_html | |
# Short alias of the default extractive-QA checkpoint (must be a key of MODEL_DICT).
DEFAULT_MODEL = 'wangchanberta'
# Sentence-embedding model used to embed passages and questions for retrieval.
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
# Pickle of precomputed embeddings read by load_embeddings().
EMBEDDINGS_PATH = 'data/embeddings.pkl'

# Alias -> repo id passed to `from_pretrained` for each fine-tuned QA checkpoint.
MODEL_DICT = {
    'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
    'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
}
def load_model(model_name=DEFAULT_MODEL):
    """Load an extractive question-answering model together with its tokenizer.

    Args:
        model_name: Alias looked up in ``MODEL_DICT``; an unknown alias raises
            ``KeyError``.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the same checkpoint.
    """
    # Resolve the alias once so model and tokenizer always come from the same repo.
    checkpoint = MODEL_DICT[model_name]
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    print('Load model done')
    return qa_model, qa_tokenizer
def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
    """Load the sentence-embedding model, pinning it to CUDA when PyTorch sees a GPU.

    Without a GPU no ``device`` argument is passed at all, leaving device
    selection to SentenceTransformer's own default.

    Args:
        model_name: Name or path accepted by ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    device_kwargs = {'device': 'cuda'} if torch.cuda.is_available() else {}
    encoder = SentenceTransformer(model_name, **device_kwargs)
    print('Load sentence embedding model done')
    return encoder
def set_index(vector):
    """Build an exact (brute-force) L2 FAISS index over the rows of ``vector``.

    Args:
        vector: 2-D array of shape ``(n, d)``; ``d`` sets the index dimension.
            In this file it comes from ``prepare_sentences_vector`` (float32,
            L2-normalised rows).

    Returns:
        A flat L2 index holding every row of ``vector``. It is moved to GPU 0
        only when PyTorch sees a GPU AND the installed FAISS build has GPU
        support; otherwise it stays on the CPU.
    """
    index = faiss.IndexFlatL2(vector.shape[1])
    # torch seeing a GPU does not mean FAISS can use it: the common
    # faiss-cpu build has no StandardGpuResources, and calling it would raise
    # AttributeError. Check the FAISS build as well and fall back to CPU.
    if torch.cuda.is_available() and hasattr(faiss, 'StandardGpuResources'):
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(vector)
    return index
def get_embeddings(embedding_model, text_list):
    """Encode text with ``embedding_model`` and return whatever ``encode`` returns.

    Callers in this file pass either a single string (a question) or a list of
    strings (the context passages).
    """
    encoded = embedding_model.encode(text_list)
    return encoded
def prepare_sentences_vector(encoded_list):
    """Stack embeddings into a float32 matrix whose rows have unit L2 norm.

    Each item is reshaped to one row, so a list of 1-D embeddings becomes an
    ``(n, d)`` matrix. Because rows are L2-normalised, L2 distance between rows
    ranks them the same way as cosine similarity.

    Args:
        encoded_list: Iterable of array-like embeddings exposing ``reshape``.

    Returns:
        The normalised ``(n, d)`` matrix produced by sklearn's ``normalize``.
    """
    rows = np.vstack([item.reshape(1, -1) for item in encoded_list])
    return normalize(rows.astype('float32'))
def load_embeddings(file_path=EMBEDDINGS_PATH):
    """Load precomputed (question) embeddings from a pickle file.

    Args:
        file_path: Path to a pickle holding a dict with an ``'embeddings'`` entry.

    Returns:
        The object stored under ``'embeddings'``.

    Raises:
        KeyError: If the pickled dict has no ``'embeddings'`` key.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load pickles this project produced itself, never user uploads.
    with open(file_path, "rb") as fIn:
        stored_data = pickle.load(fIn)
    # Only the embeddings are used; the previously unused 'sentences' lookup
    # is gone, so pickles without that key now load too.
    stored_embeddings = stored_data['embeddings']
    print('Load (questions) embeddings done')
    return stored_embeddings
def faiss_search(index, question_vector, k=1):
    """Query ``index`` for the ``k`` nearest neighbours of ``question_vector``.

    Returns:
        A ``(distances, indices)`` pair, exactly as produced by ``index.search``.
    """
    result = index.search(question_vector, k)
    dists, ids = result
    return dists, ids
def model_pipeline(model, tokenizer, question, context):
    """Extract the answer span for ``question`` from ``context``.

    Runs an extractive question-answering model and decodes the tokens between
    the highest-scoring start position and the highest-scoring end position.

    Args:
        model: QA model whose output exposes ``start_logits`` and ``end_logits``.
        tokenizer: Tokenizer paired with ``model``.
        question: Question text (first sequence).
        context: Passage to search for the answer (second sequence).

    Returns:
        The decoded answer with ``'<unk>'`` rendered as ``'@'``. If the
        predicted end precedes the predicted start, an empty span is decoded.
    """
    # Truncate only the context (never the question), so an over-long passage
    # is cut to the tokenizer's maximum length instead of being fed whole.
    inputs = tokenizer(
        question, context, return_tensors="pt", truncation="only_second"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Greedy span choice: independent argmax over the start and end logits.
    start = int(outputs.start_logits.argmax())
    end = int(outputs.end_logits.argmax())
    answer_tokens = inputs.input_ids[0, start:end + 1]
    answer = tokenizer.decode(answer_tokens)
    # NOTE(review): the '<unk>' -> '@' mapping is kept as-is; presumably '@' is
    # missing from the tokenizer vocabulary. Confirm.
    return answer.replace('<unk>', '@')
def predict_test(embedding_model, context, question, index, k=3):
    """Return the ``k`` context passages most similar to ``question``.

    The question is embedded and normalised the same way as the passages (see
    ``__main__``), so the L2 search over unit vectors ranks by cosine similarity.

    Args:
        embedding_model: Sentence-embedding model exposing ``encode``.
        context: Sequence of passage strings; row ``i`` of ``index`` must
            correspond to ``context[i]``.
        question: Free-text question; surrounding whitespace is ignored.
        index: FAISS index built over the embeddings of ``context``.
        k: Number of passages to retrieve (default 3, the previous fixed value).

    Returns:
        One ``"<rank>: <passage>"`` entry per hit, best match first, each
        followed by a blank line. The same text is also printed.
    """
    question_vector = prepare_sentences_vector(
        [get_embeddings(embedding_model, question.strip())]
    )
    _, indices = faiss_search(index, question_vector, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors.
    # Without this filter, context[-1] (an unrelated passage) would be returned.
    hits = [int(i) for i in indices[0] if i >= 0]
    most_similar_contexts = ''.join(
        f'{rank}: {context[i].strip()}\n\n' for rank, i in enumerate(hits)
    )
    print(most_similar_contexts)
    return most_similar_contexts
if __name__ == "__main__":
    url = "https://www.dataxet.co/media-landscape/2024-th"
    elements = partition_html(url=url)
    # Keep only text elements longer than 60 characters as retrievable passages
    # (presumably to drop headings/navigation; threshold kept from the original).
    context = [text for text in map(str, elements) if len(text) > 60]
    if not context:
        # Otherwise np.vstack([]) inside prepare_sentences_vector fails cryptically.
        raise SystemExit(f"No passages longer than 60 characters found at {url}")
    embedding_model = load_embedding_model()
    index = set_index(prepare_sentences_vector(get_embeddings(embedding_model, context)))

    def chat_interface(question, history):
        # gr.ChatInterface also passes the chat history; retrieval ignores it.
        return predict_test(embedding_model, context, question, index)

    examples = ['ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
                'Fragmentation คือ',
                'ติ๊กต๊อก คือ',
                'รายงานจาก Reuters Institute',
                ]
    interface = gr.ChatInterface(fn=chat_interface,
                                 examples=examples)
    interface.launch()