File size: 2,269 Bytes
fcce6a8
b1a4b26
 
 
4cdd282
 
 
699e6ab
ff88e84
fcce6a8
 
 
7c49535
699e6ab
 
 
7c49535
 
 
 
 
b1a4b26
 
 
 
 
 
 
 
 
699e6ab
 
fcce6a8
 
 
699e6ab
b1a4b26
 
 
9833ff2
b1a4b26
9833ff2
b1a4b26
 
 
 
 
 
7c49535
d728235
b1a4b26
 
7c49535
b1a4b26
 
 
7c49535
b1a4b26
 
 
 
 
 
 
 
 
 
 
 
 
 
699e6ab
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.util import cos_sim

# FastAPI application instance; the @app.get / @app.post decorators further
# down in this module register their handlers on it.
app = FastAPI()

class InputListModel(BaseModel):
    """Request body of the /predict_list endpoint: two lists of strings."""
    # Keyword strings; each is embedded and compared against the contents.
    keywords: List[str]
    # Content strings; each is embedded and compared against the keywords.
    contents: List[str]

class InputModel(BaseModel):
    """Request body of the /predict endpoint: one keyword and one content."""
    # Keyword text whose embedding is compared against the content.
    keyword: str
    # Content text whose embedding is compared against the keyword.
    content: str


# model = CrossEncoder(
#     # "jinaai/jina-reranker-v2-base-multilingual",    
#     "Alibaba-NLP/gte-multilingual-reranker-base",
#     trust_remote_code=True,
# )

# Embedding model shared by the /predict and /predict_list handlers. It is
# constructed once, at module import, rather than per request.
# NOTE(review): trust_remote_code=True also appears in the commented-out
# reranker configuration above. Confirm this model actually needs it — the
# flag presumably lets code shipped in the model repository run locally.
model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    trust_remote_code=True
)

@app.get("/")
def greet_json():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/predict")
async def predict(inp: InputModel):
    """Score how similar one keyword is to one piece of content.

    Both strings are embedded with the module-level ``model``. Their cosine
    similarity is then rescaled from [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Args:
        inp: Request body holding ``keyword`` and ``content``.

    Returns:
        A dict whose ``"results"`` entry is the rescaled similarity,
        converted from a tensor with ``tolist()``.
    """
    content_vec = model.encode(inp.content, convert_to_tensor=True)
    keyword_vec = model.encode(inp.keyword, convert_to_tensor=True)

    cosine = torch.nn.functional.cosine_similarity(
        content_vec, keyword_vec, dim=-1
    )
    score = (cosine + 1) / 2
    return {"results": score.tolist()}


@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score keyword/content pairs by embedding similarity.

    ``inp.keywords[i]`` is compared with ``inp.contents[i]``. Each cosine
    similarity is rescaled from [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Args:
        inp: Request body holding the ``keywords`` and ``contents`` lists.
            They must have the same length, except that a single-item list
            is accepted and compared against every item of the other list.

    Returns:
        A dict whose ``"results"`` entry is a list of rescaled similarities,
        or an empty list when there is nothing to compare.

    Raises:
        HTTPException: 422 when the two lists have different lengths and
            neither of them has exactly one item.
    """
    n_contents = len(inp.contents)
    n_keywords = len(inp.keywords)

    # The similarity below pairs rows positionally. A length-1 list is still
    # allowed because cosine_similarity broadcasts it, which keeps requests
    # that already worked before this check was added working.
    if n_contents != n_keywords and 1 not in (n_contents, n_keywords):
        raise HTTPException(
            status_code=422,
            detail=(
                "keywords and contents must have the same length "
                f"(got {n_keywords} keywords and {n_contents} contents)"
            ),
        )

    # No pairs to score: answer with an empty list instead of asking the
    # model to encode an empty batch.
    if n_contents == 0 or n_keywords == 0:
        return {"results": []}

    content_vecs = model.encode(inp.contents, convert_to_tensor=True)
    keyword_vecs = model.encode(inp.keywords, convert_to_tensor=True)

    cosine = torch.nn.functional.cosine_similarity(
        content_vecs, keyword_vecs, dim=-1
    )
    scores = (cosine + 1) / 2
    return {"results": scores.tolist()}

# @app.post("/predict_list")
# async def predict_list(inp : InputListModel):
#     sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)]
#     scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
#     # (-scores).argsort().tolist()
#     return {"results":scores.tolist()}


# @app.post("/predict")
# async def predict(inp : InputModel):
#     sentence_pairs = [[inp.keyword, inp.content]]
#     scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
#     # (-scores).argsort().tolist()
#     return {"results":scores.tolist()[0]}
    
    # keywords = model.encode(inp.keywords)
    # contents = model.encode(inp.contents)
    # return {"results":np.linalg.norm(contents-keywords).tolist()}