# vojtam's picture
# Update app.py
# e0f946e verified
import gradio as gr
from datasets import load_dataset
import torch.nn.functional as F
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import pickle
# --- One-time resource loading at import time -------------------------------

# Precomputed embeddings of the book descriptions, compared against each query.
# NOTE(review): pickle.load executes arbitrary code embedded in the file; this
# is only acceptable as long as 'book_embeddings.pkl' is a trusted artifact
# shipped with the app.
with open('book_embeddings.pkl', 'rb') as file:
    book_embeddings = pickle.load(file)

# Tokenizer and encoder used to embed incoming queries.
model_checkpoint = 'intfloat/multilingual-e5-large'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

# Book metadata; the pandas format makes row selection return a DataFrame.
# NOTE(review): presumably rows are in the same order as book_embeddings --
# confirm against the script that produced the pickle.
books_data = load_dataset(
    'vojtam/czech_books_descriptions',
    split="train+test",
)
books_data.set_format('pandas')
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states along dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are replaced by 0.0 before
    summing, and the sum is divided by the number of non-masked positions
    of each row.

    Args:
        last_hidden_states: Tensor of hidden states; assumed to be laid out
            as (batch, sequence, hidden) -- TODO confirm with callers.
        attention_mask: Tensor with one entry per (batch, sequence) position,
            non-zero for positions that should contribute to the average.

    Returns:
        Tensor with dim 1 reduced away, holding the masked mean per row.
    """
    # Broadcast the mask over the trailing (hidden) axis and invert it so
    # that True marks the positions to blank out.
    ignored_positions = ~attention_mask[..., None].bool()
    zeroed_hidden = last_hidden_states.masked_fill(ignored_positions, 0.0)

    summed = zeroed_hidden.sum(dim=1)
    kept_counts = attention_mask.sum(dim=1)[..., None]
    return summed / kept_counts
def create_embeddings(tokenizer, model, input_texts, batch_size=32):
    """Embed texts in batches and return L2-normalized, mean-pooled vectors.

    Each batch is tokenized (padded, truncated to 512 tokens), run through
    ``model`` without gradient tracking, mean-pooled with ``average_pool``
    and normalized to unit L2 norm along dim 1.

    Args:
        tokenizer: Callable tokenizer returning a mapping that contains
            ``'attention_mask'`` and can be splatted into ``model``.
        model: Callable encoder whose output exposes ``last_hidden_state``.
        input_texts: A sequence of strings, or a single string (treated as a
            one-element sequence).
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A tensor with one row per input text, concatenated along dim 0.
    """
    # A bare string is a sequence too: slicing it below would silently embed
    # `batch_size`-character fragments instead of the text itself.
    if isinstance(input_texts, str):
        input_texts = [input_texts]

    total = len(input_texts)
    embeddings_list = []
    for start in range(0, total, batch_size):
        batch_texts = input_texts[start:start + batch_size]
        batch_dict = tokenizer(
            batch_texts,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**batch_dict)
            batch_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
        embeddings_list.append(batch_embeddings)

        # Report progress every 10 batches; cap the count so a partial final
        # batch never reports more texts than were supplied.
        processed = start + batch_size
        if processed % (batch_size * 10) == 0:
            print(f"Processed {min(processed, total)}/{total} texts")

    return torch.cat(embeddings_list, dim=0)
def find_similar_books(query: str, n = 5):
    """Return the ``n`` books whose embeddings score highest against ``query``.

    The query is prefixed with ``"query: "``, embedded with the module-level
    tokenizer/model, and scored against ``book_embeddings`` via a matrix
    product. Rows of ``books_data`` are returned best match first.

    Args:
        query: Free-text description of the desired book. ``None`` is treated
            as an empty string.
        n: Number of books to return. Floats are truncated to int, ``None``
            falls back to 5, and values below 1 are clamped to 1.

    Returns:
        The selection ``books_data[top_indices]`` (the dataset is set to the
        pandas format at module level).
    """
    # UI number inputs may deliver a float or None; a slice bound must be an
    # int, and -0 would select every book rather than none.
    n = 5 if n is None else max(int(n), 1)

    input_query = "query: " + (query or "")
    # Wrap in a list: passing the bare string makes create_embeddings slice it
    # into batch_size-character pieces, so only the first piece was ranked.
    query_embedding = create_embeddings(tokenizer, model, [input_query])

    scores = ((query_embedding @ book_embeddings.T) * 100).detach().numpy()[0]
    # argsort is ascending: take the last n entries and reverse them.
    top_indices = np.argsort(scores)[-n:][::-1]
    return books_data[top_indices]
# Custom stylesheet passed to gr.Blocks below: sizes the element carrying the
# "full-height-gallery" class to the viewport height minus 250px with vertical
# scrolling, and colors the element with id "submit-btn".
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""
# Gradio UI: a description box and a result count, search/clear buttons, and
# a table showing the matching books.
with gr.Blocks(css=css) as intf:
    with gr.Row():
        text_input = gr.Textbox(
            label="Popis knihy",
            info = "Zadejte popis knihy, kterou byste si chtěli přečíst a aplikace najde nejpodobněší knihy dle vašeho popisu",
            placeholder='Zadejte popis, například "drama z prostředí nemocnice"',
        )
        # precision=0 makes Gradio round the value and pass it to the handler
        # as an int; the handler uses it as a slice bound.
        n_books = gr.Number(
            value = 5,
            label = "Počet knih",
            info="Počet nejpodobnějších knih, které si přejete zobrazit",
            minimum = 1,
            step = 1,
            precision = 0,
        )
    with gr.Row():
        submit_btn = gr.Button("Vyhledat knihy", elem_id="submit-btn")
        clear_btn = gr.Button("Smazat")
    with gr.Row():
        dataframe = gr.Dataframe(label="Podobné knihy", show_label=False, elem_classes = ["full-height-gallery"])

    # Event listeners are registered inside the Blocks context.
    submit_btn.click(fn=find_similar_books, inputs=[text_input, n_books], outputs=dataframe)
    # Clearing resets the text box to None and the table to an empty list.
    clear_btn.click(fn=lambda: [None, []], inputs=None, outputs=[text_input, dataframe])

intf.launch(share=True)