File size: 1,811 Bytes
d8e895c
60d35af
d8e895c
 
 
 
 
 
 
 
60d35af
6f3c591
d8e895c
 
7cbe2d5
fce8bf1
 
d8e895c
60d35af
d8e895c
60d35af
d8e895c
 
 
329ec70
d8e895c
329ec70
 
 
d8e895c
 
329ec70
 
d8e895c
 
329ec70
 
 
 
 
d8e895c
329ec70
 
d8e895c
 
 
 
329ec70
 
 
d8e895c
 
 
 
 
 
329ec70
d8e895c
 
 
 
329ec70
 
 
 
d8e895c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from http import HTTPStatus
from fastapi.responses import StreamingResponse
from fastapi import FastAPI, Query
from typing import List

import spaces
import torch
import uvicorn
import time
import numpy as np

# Install runtime dependencies when the module is imported. Each call shells
# out to pip; the exit status returned by os.system is not checked.
os.system("pip install transformers")
os.system("pip install accelerate")
os.system("pip install peft")
os.system("pip install -U FlagEmbedding")

#fmt: off
# These imports sit below the pip calls, presumably so that the packages are
# installed before they are imported.
from transformers import AutoModelForCausalLM, AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Web application object; the route handler is attached to it further down.
app = FastAPI()

# Move a one-element tensor to the CUDA device and print which device it is on.
zero = torch.Tensor([0]).cuda()
print(zero.device)

# Embedding model, loaded once at import time and used by get_rag_text.
model_name = "BAAI/bge-m3"
model = BGEM3FlagModel(model_name,  
                       use_fp16=True)


@spaces.GPU
def get_rag_text(sentence: str, candidates: List[str], top_k: int):
    """Pick the ``top_k`` candidates whose embeddings best match ``sentence``.

    The query and every candidate are encoded with the module-level ``model``
    and scored by the dot product of their ``dense_vecs`` embeddings.

    Args:
        sentence: Query text.
        candidates: Texts to rank against the query.
        top_k: Maximum number of candidates to return.

    Returns:
        A dict ``{"rag_result": text}`` where ``text`` is the selected
        candidates joined by newlines, ordered from lowest to highest score
        (best match last), with trailing whitespace removed. ``text`` is empty
        when ``top_k`` is not positive or ``candidates`` is empty.
    """
    # A slice of ``[-0:]`` would select *every* candidate and a negative
    # ``top_k`` would slice from the wrong end, so handle both up front.
    # Skipping the encoder also avoids running it on an empty batch.
    if top_k <= 0 or not candidates:
        return {"rag_result": ""}

    start_time = time.time()

    query_embeddings = model.encode(
        [sentence],
        batch_size=1,
        max_length=8192,  # Lower this to speed up encoding if inputs are short.
    )['dense_vecs']
    key_embeddings = model.encode(candidates)['dense_vecs']

    # One row of scores for the single query; drop that leading axis.
    similarity = query_embeddings @ key_embeddings.T
    similarity = similarity.squeeze(0)

    elapsed_time = time.time() - start_time
    print(elapsed_time)

    # argsort is ascending, so the last ``top_k`` entries are the best matches.
    top_k_indices = np.argsort(similarity)[-top_k:]

    # Join with real newlines (the separator used to be the literal "/n").
    rag_result = "\n".join(candidates[idx] for idx in top_k_indices).rstrip()

    return {"rag_result": rag_result}



@app.get("/")
async def get_rag_result(prompt: str, candidates: List[str] = Query(...), top_k: int = Query(...)):
    """Rank ``candidates`` against ``prompt`` and return the ``get_rag_text`` result."""
    return get_rag_text(prompt, candidates, top_k)
    


if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )