Spaces:
Sleeping
Sleeping
File size: 1,872 Bytes
b2f9c3c 0fc39f5 b2f9c3c ade21fe c72cbc0 b2f9c3c cf0fcac 70eb6a4 0fc39f5 ffa5f5d b2f9c3c 70eb6a4 b2f9c3c 0f19f2b aaf24e8 8f2a805 aaf24e8 b2f9c3c 1060e26 0f19f2b 94ba614 59eb01b 70eb6a4 0bfc10b c64e7a9 70eb6a4 0bfc10b 70eb6a4 094537b 8f2a805 a29a79d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import asyncio
import pickle

import pandas as pd
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code, so *path* must point to
    a trusted file. Here the paths are fixed local files bundled with the app.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Pre-computed embeddings of the disease corpus, cast to float64 so that
# ``predict`` can compare them with query embeddings cast the same way.
# Presumably a 2-D array of shape (n_diseases, embedding_dim); verify.
corpus = _load_pickle("./corpus/all_embeddings_disease.pickle").astype("float")

# Encoder used for incoming queries. NOTE(review): the corpus embeddings were
# presumably produced with this same checkpoint. Re-embed the corpus if the
# model is changed.
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

# Disease metadata table. ``predict`` reads its ``name`` and ``url`` columns and
# uses its row labels as result ids. Rows are indexed positionally in parallel
# with ``corpus``.
df = pd.DataFrame(_load_pickle("./corpus/y_all_disease.pickle"))
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive. Restrict ``allow_origins`` if the API ever relies on cookies or
# auth headers.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class Disease(BaseModel):
    """A single ranked search hit in the response of ``POST /``."""

    # Row label of the matching entry in the disease table.
    id: int
    # Disease name, taken from the table's ``name`` column.
    name: str
    # URL associated with the disease, taken from the table's ``url`` column.
    url: str
    # Similarity between the query and this disease; higher ranks earlier.
    score: float
class Symptoms(BaseModel):
    """Request body for ``POST /``."""

    # Free-text description of the symptoms to search for.
    query: str
@app.get("/")
def home():
    """Root GET endpoint: print the first disease-table row and return a greeting.

    NOTE(review): ``df.iloc[0]`` raises IndexError if the table is empty, which
    would make this endpoint fail.
    """
    first_row = df.iloc[0]
    print(first_row)
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    """Rank every disease in the corpus by similarity to a free-text symptom query.

    The query is embedded with the module-level ``model`` and scored against the
    pre-computed ``corpus`` embeddings via ``model.similarity``. All rows are
    returned, best match first.

    Args:
        symptoms: Request body; ``symptoms.query`` is the text to embed.

    Returns:
        A list of dicts with keys ``id`` (row label in ``df``), ``name``,
        ``url`` and ``score`` (similarity), ordered by descending score.
    """
    # Encoding is a blocking, compute-heavy call. Run it in a worker thread so
    # this async endpoint does not stall the event loop for other requests.
    query_embedding = await asyncio.to_thread(model.encode, symptoms.query)
    # Match the float64 dtype that ``corpus`` was cast to at load time.
    query_embedding = query_embedding.astype("float")

    similarity_vector = model.similarity(query_embedding, corpus)[0]
    # topk with k equal to the corpus size is a full descending sort.
    scores, indices = torch.topk(similarity_vector, k=len(corpus))

    # Hand pandas a NumPy index array instead of a raw torch tensor.
    ranked = df.iloc[indices.cpu().numpy()]
    return [
        {"id": int(row_id), "name": name, "url": url, "score": float(score)}
        for row_id, name, url, score in zip(
            ranked.index, ranked["name"], ranked["url"], scores.tolist()
        )
    ]