Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,811 Bytes
d8e895c 60d35af d8e895c 60d35af 6f3c591 d8e895c 7cbe2d5 fce8bf1 d8e895c 60d35af d8e895c 60d35af d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c 329ec70 d8e895c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
from http import HTTPStatus
from fastapi.responses import StreamingResponse
from fastapi import FastAPI, Query
from typing import List
import spaces
import torch
import uvicorn
import time
import numpy as np
def _install_runtime_dependencies():
    """Install the runtime packages this script needs into the running interpreter.

    This is best-effort, matching the original ``os.system`` calls: a failed
    install is reported but does not abort startup. The imports that follow
    will still raise if a package is genuinely missing.
    """
    import subprocess
    import sys

    for pip_args in (["transformers"], ["accelerate"], ["peft"], ["-U", "FlagEmbedding"]):
        # ``sys.executable -m pip`` installs into *this* interpreter's
        # environment; a bare ``pip`` on PATH may belong to another Python.
        completed = subprocess.run([sys.executable, "-m", "pip", "install", *pip_args])
        if completed.returncode != 0:
            print(f"pip install {' '.join(pip_args)} failed with exit code {completed.returncode}")


_install_runtime_dependencies()
#fmt: off
from transformers import AutoModelForCausalLM, AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# FastAPI application; the route below is registered on it with @app.get.
app = FastAPI()
# Allocate a one-element tensor on CUDA at import time and print its device.
# NOTE(review): this looks like the ZeroGPU startup pattern (a module-level
# .cuda() outside any @spaces.GPU function). Confirm before moving or removing it.
zero = torch.Tensor([0]).cuda()
print(zero.device)
# Embedding model used by get_rag_text; it is loaded once, at import time.
model_name = "BAAI/bge-m3"
# use_fp16=True presumably loads half-precision weights for speed and memory.
# Verify this against the FlagEmbedding docs.
model = BGEM3FlagModel(model_name,
use_fp16=True)
@spaces.GPU
def get_rag_text(sentence: str, candidates: List[str], top_k: int):
    """Return the ``top_k`` candidates most similar to ``sentence`` as one string.

    The query and the candidates are embedded with the module-level BGE-M3
    ``model``. Only the dense vectors are used, and candidates are scored by
    dot product with the query vector.

    Args:
        sentence: Query text.
        candidates: Texts to rank against the query.
        top_k: Number of candidates to keep. Values <= 0 select nothing;
            values larger than ``len(candidates)`` select all of them.

    Returns:
        ``{"rag_result": text}``. ``text`` is the selected candidates joined
        by newlines, ordered from least to most similar, so the best match
        comes last. It is an empty string when nothing is selected.
    """
    # Nothing to rank, or nothing requested: skip the embedding work entirely.
    if not candidates or top_k <= 0:
        return {"rag_result": ""}

    start_time = time.time()
    query_embeddings = model.encode(
        [sentence],
        batch_size=1,
        max_length=8192,  # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
    )["dense_vecs"]
    key_embeddings = model.encode(candidates)["dense_vecs"]
    # The query batch holds a single sentence, so squeeze(0) drops that
    # leading axis and leaves one score per candidate.
    similarity = (query_embeddings @ key_embeddings.T).squeeze(0)
    elapsed_time = time.time() - start_time
    print(elapsed_time)

    # argsort is ascending, so the last top_k indices are the best matches.
    # Clamp the start index at 0. A bare ``[-top_k:]`` would select *every*
    # candidate when top_k == 0, because ``[-0:]`` is ``[0:]``.
    order = np.argsort(similarity)
    top_k_indices = order[max(len(order) - top_k, 0):]
    # One candidate per line, separated by real newline characters.
    rag_result = "\n".join(candidates[idx] for idx in top_k_indices)
    return {"rag_result": rag_result}
@app.get("/")
async def get_rag_result(prompt: str, candidates: List[str] = Query(...), top_k: int = Query(...)):
    """Rank ``candidates`` against ``prompt`` and return the best ``top_k``.

    ``candidates`` is a repeated query parameter
    (``?candidates=a&candidates=b``). The response is the
    ``{"rag_result": ...}`` dict produced by ``get_rag_text``.
    """
    # get_rag_text is a synchronous, GPU-bound call. Running it directly in
    # this coroutine would block the event loop for every other request, so
    # it is offloaded to a worker thread instead.
    from fastapi.concurrency import run_in_threadpool

    rag_text = await run_in_threadpool(get_rag_text, prompt, candidates, top_k)
    return rag_text
if __name__ == "__main__":
    # Serve on all interfaces. Port 7860 presumably matches what the hosting
    # Space expects; confirm this if the deployment changes.
    uvicorn.run(app, host="0.0.0.0", port=7860)