suwonpabby's picture
For Check
329ec70
raw
history blame
1.81 kB
import os
import subprocess
import sys
import time
from http import HTTPStatus
from typing import List

import numpy as np
import spaces
import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
# Install runtime dependencies before importing them below.
# `sys.executable -m pip` guarantees the packages land in the interpreter
# running this file; a bare `pip` on PATH may belong to another Python.
# Failures are deliberately not fatal (same best-effort behaviour as before);
# a missing package will surface as an ImportError on the imports below.
for _pip_args in (
    ["transformers"],
    ["accelerate"],
    ["peft"],
    ["-U", "FlagEmbedding"],
):
    subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args], check=False)
#fmt: off
from transformers import AutoModelForCausalLM, AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# FastAPI application that serves the retrieval endpoint.
app = FastAPI()

# Allocate a tiny tensor on CUDA at import time and log its device.
# NOTE(review): this mirrors the ZeroGPU (`spaces`) examples, where the
# device printed here may be "cpu" until a @spaces.GPU function runs — confirm.
zero = torch.Tensor([0]).cuda()
print(zero.device)

# BGE-M3 embedding model loaded once at import and shared by all requests.
# use_fp16=True runs the model in half precision (speed over a little accuracy).
model_name = "BAAI/bge-m3"
model = BGEM3FlagModel(model_name, use_fp16=True)
@spaces.GPU
def get_rag_text(sentence: str, candidates: List[str], top_k: int):
    """Select the ``top_k`` candidates most similar to ``sentence`` with BGE-M3.

    Args:
        sentence: Query text to retrieve against.
        candidates: Texts to rank against the query.
        top_k: Number of candidates to keep. A value <= 0, or an empty
            ``candidates`` list, yields an empty result.

    Returns:
        ``{"rag_result": str}``: the selected candidates joined with newlines,
        listed from least to most similar among the top-k, with trailing
        whitespace stripped.
    """
    if not candidates or top_k <= 0:
        # Without this guard, top_k == 0 would return *every* candidate
        # (``arr[-0:]`` is the whole array), and an empty list would be sent
        # to the encoder for nothing.
        return {"rag_result": ""}

    start_time = time.perf_counter()
    query_embeddings = model.encode(
        [sentence],
        batch_size=1,
        max_length=8192,  # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
    )["dense_vecs"]
    key_embeddings = model.encode(candidates)["dense_vecs"]
    # Assumes dense_vecs is (num_texts, dim): (1, dim) @ (dim, n) -> (1, n),
    # then squeeze the single-query axis to get one score per candidate.
    # NOTE(review): dot product equals cosine similarity only if the vectors
    # are L2-normalised — confirm against the FlagEmbedding settings.
    similarity = (query_embeddings @ key_embeddings.T).squeeze(0)
    print(time.perf_counter() - start_time)

    # argsort is ascending, so the last top_k indices are the best matches,
    # ordered least -> most similar. Slicing past the start is safe when
    # top_k > len(candidates).
    # NOTE(review): append [::-1] if most-similar-first ordering is intended.
    top_k_indices = np.argsort(similarity)[-top_k:]
    # Join with a real newline (previously the literal two characters "/n").
    rag_result = "\n".join(candidates[idx] for idx in top_k_indices).rstrip()
    return {"rag_result": rag_result}
@app.get("/")
async def get_rag_result(prompt: str, candidates: List[str] = Query(...), top_k: int = Query(...)):
    """HTTP GET ``/``: rank ``candidates`` against ``prompt`` and return the best ``top_k``.

    Query parameters:
        prompt: Text to retrieve against.
        candidates: Required; given as a repeated ``candidates=...`` query parameter.
        top_k: Required; number of candidates to return.

    Returns:
        The dict produced by ``get_rag_text``, serialised as JSON by FastAPI.
    """
    # get_rag_text is synchronous and GPU-bound. Calling it directly inside an
    # ``async def`` would block the event loop for the whole encode; running it
    # in the threadpool keeps the server responsive to other requests.
    return await run_in_threadpool(get_rag_text, prompt, candidates, top_k)
if __name__ == "__main__":
    # Serve the app on every network interface. Port 7860 is presumably the
    # Hugging Face Spaces default app port — confirm if deployed elsewhere.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)