"""FastAPI service that scores keyword/content similarity with sentence embeddings.

Endpoints:
    GET  /              -- static greeting payload.
    POST /predict       -- score one keyword against one content string.
    POST /predict_list  -- score keywords against contents row by row.

Scores are cosine similarities rescaled from [-1, 1] to [0, 1] via (x + 1) / 2.
"""

from typing import List, Union

from fastapi import FastAPI, HTTPException
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import numpy as np
from pydantic import BaseModel

app = FastAPI()


class InputListModel(BaseModel):
    """Request body for ``/predict_list``."""

    # Keywords and contents are scored against each other position by position.
    keywords: List[str]
    contents: List[str]


class InputModel(BaseModel):
    """Request body for ``/predict``: a single keyword/content pair."""

    keyword: str
    content: str


# Alternative cross-encoder reranker, kept for reference:
# model = CrossEncoder(
#     # "jinaai/jina-reranker-v2-base-multilingual",
#     "Alibaba-NLP/gte-multilingual-reranker-base",
#     trust_remote_code=True,
# )

# NOTE(review): trust_remote_code=True presumably allows code shipped with the
# model repository to run locally -- confirm this model actually requires it.
model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    trust_remote_code=True,
)


def _scaled_cosine(contents: Union[str, List[str]], keywords: Union[str, List[str]]):
    """Encode both inputs and return their cosine similarity mapped to [0, 1].

    Args:
        contents: One content string or a list of them.
        keywords: One keyword string or a list of them.

    Returns:
        ``tolist()`` of the rescaled similarity tensor: a float for a pair of
        single strings, a list of floats for list inputs.
    """
    text_emb = model.encode(contents, convert_to_tensor=True)
    keyword_emb = model.encode(keywords, convert_to_tensor=True)
    similarity = torch.nn.functional.cosine_similarity(text_emb, keyword_emb, dim=-1)
    # Alternative using sentence_transformers.util:
    # out = (cos_sim(text_emb, summarize) + 1)/2
    # Shift the [-1, 1] cosine range into [0, 1].
    return ((similarity + 1) / 2).tolist()


@app.get("/")
def greet_json():
    """Return a static greeting payload."""
    return {"Hello": "World!"}


# NOTE(review): the handlers below are ``async`` but call ``model.encode``
# synchronously, so a request occupies the event loop while encoding. Moving
# the work to a threadpool would allow concurrent ``encode`` calls; confirm
# that is safe for this model/tokenizer before changing it.
@app.post("/predict")
async def predict(inp: InputModel):
    """Score a single keyword against a single content string.

    Returns:
        ``{"results": score}`` with the rescaled cosine similarity.
    """
    return {"results": _scaled_cosine(inp.content, inp.keyword)}


@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score keywords against contents position by position.

    The two lists must have the same length, or one of them must hold exactly
    one item (it is then broadcast against every item of the other list).

    Returns:
        ``{"results": [...]}`` with one rescaled cosine similarity per row;
        an empty list when both inputs are empty.

    Raises:
        HTTPException: 422 when the list lengths cannot be paired.
    """
    n_keywords = len(inp.keywords)
    n_contents = len(inp.contents)

    # Nothing to score; skip encoding so the response is still a list.
    if n_keywords == 0 and n_contents == 0:
        return {"results": []}

    broadcastable = n_keywords == n_contents or 1 in (n_keywords, n_contents)
    if n_keywords == 0 or n_contents == 0 or not broadcastable:
        # Report a client error instead of letting the tensor shape mismatch
        # surface as an unhandled server error.
        raise HTTPException(
            status_code=422,
            detail=(
                "keywords and contents must have the same length "
                f"(or one of them a single item); got {n_keywords} and {n_contents}"
            ),
        )

    return {"results": _scaled_cosine(inp.contents, inp.keywords)}


# Cross-encoder variants of the endpoints, kept for reference:
# @app.post("/predict_list")
# async def predict_list(inp : InputListModel):
#     sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)]
#     scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
#     # (-scores).argsort().tolist()
#     return {"results":scores.tolist()}

# @app.post("/predict")
# async def predict(inp : InputModel):
#     sentence_pairs = [[inp.keyword, inp.content]]
#     scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
#     # (-scores).argsort().tolist()
#     return {"results":scores.tolist()[0]}

# Euclidean-distance variant, kept for reference:
# keywords = model.encode(inp.keywords)
# contents = model.encode(inp.contents)
# return {"results":np.linalg.norm(contents-keywords).tolist()}