# KaLM-Embedding / app.py
# Uploaded by YanshekWoo using huggingface_hub (commit c9d8253, verified).
import json
import argparse
from pathlib import Path
from typing import List
import gradio as gr
import faiss
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
# Markdown help text rendered above the upload box (via `gr.Markdown` in main).
# The embedded example must be valid JSON (no trailing commas) so users can copy
# it verbatim into a file that `json.load` accepts.
file_example = """Please upload a JSON file containing a list of objects, each with a "text" field (and an optional "title" field). For example:
```JSON
[
    {"title": "", "text": "This is an example text without the title"},
    {"title": "Title A", "text": "This is an example text with the title"},
    {"title": "Title B", "text": "This is an example text with the title"}
]
```
Due to limited computation resources, please test with small-scale data (<1000 documents).
"""
def create_index(embeddings, use_gpu):
    """Build a flat (brute-force) inner-product FAISS index over ``embeddings``.

    Args:
        embeddings: Sequence of equal-length embedding vectors (for example the
            list of per-document rows collected by ``upload_file_fn``). They
            are converted to a float32 array before being added.
        use_gpu: When true, the CPU index is cloned onto all visible GPUs,
            sharded across them, with vectors stored in float16.

    Returns:
        The populated FAISS index (a CPU ``IndexFlatIP`` or its GPU clone).
    """
    vectors = np.asarray(embeddings, dtype=np.float32)
    dim = len(embeddings[0])
    index = faiss.IndexFlatIP(dim)

    if use_gpu:
        cloner_opts = faiss.GpuMultipleClonerOptions()
        cloner_opts.shard = True  # split vectors across GPUs rather than replicating
        cloner_opts.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=cloner_opts)

    index.add(vectors)
    return index
def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """Load an uploaded JSON document file, embed it, and build a search index.

    Args:
        file_path: Path of the uploaded ``.json`` file. The upload box uses
            ``file_count="single"``, so a single path is expected. The file
            must hold a list of objects with a ``"text"`` key and an optional
            ``"title"`` key.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets the
            ``tqdm`` loop below drive the progress bar in the UI.

    Returns:
        A pair ``(document_state, update)``:

        - ``document_state`` is a dict with the (possibly truncated) raw
          records under ``"document_data"`` and the FAISS index under
          ``"document_index"``, or ``None`` on failure.
        - ``update`` enables the query box on success and disables it on
          failure.
    """
    try:
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        if not isinstance(document_data, list):
            raise ValueError("The JSON file must contain a list of objects.")
        gr.Info(f"Upload {len(document_data)} documents.")
        if len(document_data) > 1000:
            gr.Info("Cut uploaded documents to 1000 due to the computation resource.")
            document_data = document_data[:1000]

        documents = []
        for obj in document_data:
            text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            # Exactly one entry is appended per record, so `documents` stays
            # aligned by position with `document_data`; retrieval relies on
            # that when mapping index hits back to records. Blank texts are
            # replaced by the tokenizer's EOS token (presumably so the encoder
            # never receives an empty string).
            if len(str(text).strip()):
                documents.append(text)
            else:
                documents.append(model.tokenizer.eos_token)
    except Exception as e:
        print(e)
        # gr.Error is an exception class and only shows a message when raised.
        # gr.Warning displays a toast and still lets us return outputs that
        # disable the query box.
        gr.Warning("Read the file failed. Please check the data format.")
        gr.Warning(str(e))
        return None, gr.update(interactive=False)

    if len(documents) < 5:
        # Retrieval always displays five results, so at least five are needed.
        gr.Warning("Please upload at least 5 documents.")
        return None, gr.update(interactive=False)

    # Encode in small batches so the tqdm loop (mirrored into the UI by
    # track_tqdm) advances while the model runs.
    documents_embeddings = []
    batch_size = 16
    for i in tqdm(range(0, len(documents), batch_size)):
        batch_embeddings = model.encode(documents[i: i + batch_size], show_progress_bar=True)
        documents_embeddings.extend(batch_embeddings)
    document_index = create_index(documents_embeddings, use_gpu=False)

    if torch.cuda.is_available():
        # Release cached GPU memory left over from encoding.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    document_state = {"document_data": document_data, "document_index": document_index}
    return document_state, gr.update(interactive=True)
def clear_file_fn():
    """Reset the app after the upload box is cleared.

    Returns:
        ``None`` as the new document state (dropping any uploaded documents
        and their index), plus an update that makes the query box
        interactive again.
    """
    cleared_state = None
    query_box_update = gr.update(interactive=True)
    return cleared_state, query_box_update
def retrieve_document_fn(question, document_states, instruct):
    """Return the five best-matching documents for a query, plus the state.

    Args:
        question: Query text from the "Query" textbox.
        document_states: State produced by ``upload_file_fn``, a dict with
            ``"document_data"`` (raw records) and ``"document_index"`` (FAISS
            index). It is ``None`` when nothing has been uploaded yet.
        instruct: Instruction prefix prepended to the query before encoding.

    Returns:
        A 6-tuple: the ``"text"`` of the top five documents, ranked by the
        index's inner-product score (``None`` for any slot without a hit),
        followed by ``document_states`` unchanged.
    """
    num_retrieval_doc = 5
    empty_answers = (None,) * num_retrieval_doc

    if document_states is None:
        gr.Warning("Please upload documents first!")
        return (*empty_answers, None)

    # Validate the query itself. The instruction prefix is non-empty by
    # default, so checking the concatenation would let an empty query through.
    if len(str(question).strip()) == 0:
        gr.Warning("Please enter a non-empty query.")
        return (*empty_answers, document_states)

    document_data = document_states["document_data"]
    document_index = document_states["document_index"]
    k = min(len(document_data), num_retrieval_doc)
    if k == 0:
        return (*empty_answers, document_states)

    question_with_inst = str(instruct) + str(question)
    question_embedding = np.asarray(model.encode([question_with_inst]), dtype=np.float32)
    # Only `num_retrieval_doc` results are displayed, so ask FAISS for no more.
    _, batch_inxs = document_index.search(question_embedding, k=k)

    # FAISS pads with -1 when it finds fewer than k hits. Skip those so they
    # are not misread as document_data[-1]; missing slots become None.
    answers = [document_data[i]["text"] for i in batch_inxs[0] if i >= 0]
    answers += [None] * (num_retrieval_doc - len(answers))
    return (*answers, document_states)
def main(args):
    """Load the embedding model, then build and launch the Gradio demo.

    Args:
        args: Parsed command-line namespace providing ``model_name_or_path``
            and ``revision``.

    Side effects:
        - Sets the module-level ``model`` used by ``upload_file_fn`` and
          ``retrieve_document_fn``.
        - Reads ``resources/head.html`` and ``resources/styles.css`` located
          next to this file.
        - Calls ``demo.launch()``.
    """
    global model
    model = SentenceTransformer(
        args.model_name_or_path,
        revision=args.revision,
    )

    # Holds the dict returned by upload_file_fn (documents + FAISS index).
    document_state = gr.State()

    resources_dir = Path(__file__).parent / "resources"
    with open(resources_dir / "head.html", encoding="utf-8") as html_file:
        head = html_file.read().strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )

    with gr.Blocks(
        theme=theme,
        head=head,
        css=resources_dir / "styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")

        # Display-only: show the model actually loaded from --model_name_or_path.
        gr.Radio(
            [args.model_name_or_path],
            value=args.model_name_or_path,
            label="Model Selection",
            interactive=False,
        )

        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=[gr.Textbox(label="Query"), document_state],
            outputs=[
                gr.Text(label="Recall-1"),
                gr.Text(label="Recall-2"),
                gr.Text(label="Recall-3"),
                gr.Text(label="Recall-4"),
                gr.Text(label="Recall-5"),
                gr.State(),
            ],
            additional_inputs=[
                gr.Textbox(
                    "Instruct: Given a query, retrieve documents that answer the query. \n Query: ",
                    label="Instruct of Query",
                    lines=2,
                )
            ],
            concurrency_limit=1,
            allow_flagging="never",
        )
        query_box = retrieval_interface.input_components[0]

        # Uploading builds the index and toggles the query box; clearing resets both.
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_state, query_box],
            queue=True,
            trigger_mode="once",
        )
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_state, query_box],
            queue=True,
            trigger_mode="once",
        )

    demo.launch()
if __name__ == "__main__":
    # Command-line entry point: pick the embedding model (and optional hub
    # revision) to serve, then build and launch the demo.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    cli.add_argument("--revision", type=str, default=None)
    main(cli.parse_args())