import os
import time
from typing import List

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, Query

import spaces

# Install runtime dependencies at startup (a common pattern on Hugging Face
# Spaces when requirements aren't pinned in advance).
os.system("pip install transformers")
os.system("pip install accelerate")
os.system("pip install peft")
os.system("pip install -U FlagEmbedding")

# fmt: off  (keep this import below the pip installs above)
from FlagEmbedding import BGEM3FlagModel

app = FastAPI()

# ZeroGPU sanity check: outside a @spaces.GPU function this tensor reports
# "cpu"; inside one it reports "cuda:0".
zero = torch.Tensor([0]).cuda()
print(zero.device)

model_name = "BAAI/bge-m3"
model = BGEM3FlagModel(model_name, use_fp16=True)


@spaces.GPU
def get_rag_text(sentence: str, candidates: List[str], top_k: int):
    start_time = time.time()
    query_embeddings = model.encode(
        [sentence],
        batch_size=1,
        max_length=8192,  # If you don't need such a long length, set a smaller value to speed up encoding.
    )["dense_vecs"]
    key_embeddings = model.encode(candidates)["dense_vecs"]

    # Dot product of the dense vectors gives the query-candidate similarity;
    # squeeze the singleton query dimension to get a 1-D score array.
    similarity = (query_embeddings @ key_embeddings.T).squeeze(0)

    elapsed_time = time.time() - start_time
    print(f"encoding took {elapsed_time:.3f}s")

    # Keep the top_k most similar candidates, most similar first,
    # joined with newlines ("/n" in the original was a typo for "\n").
    top_k_indices = np.argsort(similarity)[-top_k:][::-1]
    rag_result = "\n".join(candidates[idx] for idx in top_k_indices)
    return {"rag_result": rag_result}


@app.get("/")
async def get_rag_result(
    prompt: str,
    candidates: List[str] = Query(...),
    top_k: int = Query(...),
):
    return get_rag_text(prompt, candidates, top_k)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
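
# Usage sketch (illustrative, not part of the app): querying the endpoint with
# the `requests` library. The host/port assume the uvicorn settings above, and
# repeated "candidates" query parameters map onto the List[str] Query parameter.
#
#   import requests
#   resp = requests.get(
#       "http://localhost:7860/",
#       params=[
#           ("prompt", "What is the capital of France?"),
#           ("candidates", "Paris is the capital of France."),
#           ("candidates", "Berlin is the capital of Germany."),
#           ("top_k", "1"),
#       ],
#   )
#   print(resp.json())  # e.g. {"rag_result": "Paris is the capital of France."}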