import numpy as np
import gradio as gr
import torch
import faiss
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
from unstructured.partition.html import partition_html

DEFAULT_MODEL = 'wangchanberta'
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'

MODEL_DICT = {
    'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
    'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
}


def load_model(model_name=DEFAULT_MODEL):
    """Load the extractive QA model and its tokenizer."""
    model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
    print('Load model done')
    return model, tokenizer


def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
    """Load the sentence-embedding model, on GPU when available."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_model = SentenceTransformer(model_name, device=device)
    print('Load sentence embedding model done')
    return embedding_model


def set_index(vector):
    """Build a flat L2 FAISS index over the vectors, on GPU when available."""
    index = faiss.IndexFlatL2(vector.shape[1])
    if torch.cuda.is_available():
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(vector)
    return index


def get_embeddings(embedding_model, text_list):
    return embedding_model.encode(text_list)


def prepare_sentences_vector(encoded_list):
    """Stack embeddings into a float32 matrix and L2-normalize each row,
    so L2 distance in the index ranks results like cosine similarity."""
    encoded_list = [i.reshape(1, -1) for i in encoded_list]
    encoded_list = np.vstack(encoded_list).astype('float32')
    return normalize(encoded_list)


def faiss_search(index, question_vector, k=1):
    distances, indices = index.search(question_vector, k)
    return distances, indices


def model_pipeline(model, tokenizer, question, context):
    """Run extractive QA: decode the highest-scoring answer span in context."""
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    predict_answer_tokens = inputs.input_ids[0, answer_start_index:answer_end_index + 1]
    answer = tokenizer.decode(predict_answer_tokens)
    return answer.replace('<unk>', '@')  # mask decoder '<unk>' tokens with a placeholder


def predict_test(embedding_model, context, question, index, url, k=3):
    """Embed the question, retrieve the k nearest passages, and format them."""
    question_vector = get_embeddings(embedding_model, question.strip())
    question_vector = prepare_sentences_vector([question_vector])
    distances, indices = faiss_search(index, question_vector, k)
    most_similar_contexts = ''
    for i in range(k):
        most_sim_context = context[indices[0][i]].strip()
        # Deep link to the matched passage via a URL text fragment (currently unused)
        answer_url = f"{url}#:~:text={most_sim_context}"
        most_similar_contexts += f'[ {i+1} ]: {most_sim_context}\n\n'
    print(most_similar_contexts)
    return most_similar_contexts


if __name__ == "__main__":
    # Scrape the source page and keep only substantial text blocks as passages
    url = "https://www.dataxet.co/media-landscape/2024-th"
    elements = partition_html(url=url)
    context = [str(element) for element in elements if len(str(element)) > 60]

    # Embed the passages and build the search index
    embedding_model = load_embedding_model()
    index = set_index(prepare_sentences_vector(get_embeddings(embedding_model, context)))

    def chat_interface(question, history):
        return predict_test(embedding_model, context, question, index, url)

    examples = [
        'ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
        'Fragmentation คือ',
        'ติ๊กต๊อก คือ',
        'รายงานจาก Reuters Institute',
    ]
    interface = gr.ChatInterface(fn=chat_interface, examples=examples)
    interface.launch()
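
# Note: the QA model loaded by load_model() is defined above but never wired
# into the chat loop. A minimal sketch of doing so, answering with an extracted
# span from the best-matching passage instead of the raw passages (assumes the
# MODEL_DICT checkpoints are reachable; qa_chat is a hypothetical name):
#
#     model, tokenizer = load_model()
#
#     def qa_chat(question, history):
#         question_vector = prepare_sentences_vector(
#             [get_embeddings(embedding_model, question.strip())])
#         _, indices = faiss_search(index, question_vector, k=1)
#         return model_pipeline(model, tokenizer, question, context[indices[0][0]])
#
#     gr.ChatInterface(fn=qa_chat, examples=examples).launch()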