|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
import pandas as pd |
|
from sentence_transformers import SentenceTransformer |
|
import chromadb |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import uvicorn |
|
import requests |
|
|
|
# Application instance; the route decorators further down register on it.
app = FastAPI()

# Origins allowed to make cross-origin requests to this API.
# NOTE(review): "localhost:5173" has no scheme, so it presumably never equals
# the Origin header a browser sends — confirm whether it can be dropped.
origins = [
    "http://localhost:5173",
    "localhost:5173",
]

# CORS policy: only the origins above, with credentials, any method and any
# request header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Disease table, loaded once at import time from a remote CSV (needs network
# access and a handler for the hf:// scheme).
df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")

# Turn the comma-separated 'Symptoms' text into a list of stripped strings.
# Cells that are not strings (e.g. NaN for a missing value) become an empty
# list instead of raising when iterated, so every row holds a list and the
# endpoints below can iterate it safely.
df['Symptoms'] = df['Symptoms'].apply(
    lambda raw: [part.strip() for part in raw.split(',')] if isinstance(raw, str) else []
)

# Sentence-embedding model used to encode free-text symptoms.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# On-disk Chroma store. The code in this section only queries the collection;
# it is presumably populated elsewhere — TODO confirm.
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")
|
|
|
class SymptomQuery(BaseModel):
    # Request body for the symptom-based endpoints.
    # `symptom` is free text; /find_matching_symptoms splits it on commas,
    # the other endpoints embed it as a single string.
    symptom: str
|
|
|
|
|
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return stored symptom texts closest to each comma-separated input symptom.

    Args:
        query: Request body whose ``symptom`` field holds one or more
            symptoms separated by commas.

    Returns:
        ``{"matching_symptoms": [...]}`` — up to three nearest documents per
        input symptom, de-duplicated while keeping first-seen order. The list
        is empty when the input contains no non-blank symptom.
    """
    # Drop blank fragments (empty input, trailing or doubled commas) so an
    # empty string is never embedded and queried. dict.fromkeys also collapses
    # repeated inputs, which would only reproduce results already collected.
    stripped = (part.strip() for part in query.symptom.split(','))
    symptoms = list(dict.fromkeys(s for s in stripped if s))
    if not symptoms:
        return {"matching_symptoms": []}

    all_results = []
    for symptom in symptoms:
        query_embedding = model.encode([symptom])
        results = collection.query(
            query_embeddings=query_embedding.tolist(),
            n_results=3,
        )
        # One embedding was sent, so the matches are in the first result list.
        all_results.extend(results['documents'][0])

    # Order-preserving de-duplication across all queried symptoms.
    matching_symptoms = list(dict.fromkeys(all_results))
    return {"matching_symptoms": matching_symptoms}
|
|
|
|
|
@app.post("/find_matching_diseases")
def find_matching_diseases(query: SymptomQuery):
    # Embeds the whole `symptom` text as one string, fetches the five nearest
    # stored documents, and returns the names of every disease in `df` that
    # lists at least one of those documents among its symptoms.
    embedding = model.encode([query.symptom])
    nearest = collection.query(
        query_embeddings=embedding.tolist(),
        n_results=5,
    )

    # One embedding was sent, so the matches are in the first result list.
    matching_symptoms = nearest['documents'][0]

    def shares_symptom(disease_symptoms):
        # True when any symptom of the disease is among the nearest documents.
        for candidate in disease_symptoms:
            if candidate in matching_symptoms:
                return True
        return False

    row_mask = df['Symptoms'].apply(shares_symptom)
    matching_diseases = df[row_mask]

    disease_names = matching_diseases['Name'].tolist()
    return {"matching_diseases": disease_names}
|
|
|
|
|
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """Return diseases related to the query text plus the union of their symptoms.

    The ``symptom`` text is embedded as a single string and the five nearest
    stored documents are fetched; a disease matches when it lists at least one
    of those documents among its symptoms.

    Args:
        query: Request body whose ``symptom`` field is the text to embed.

    Returns:
        A dict with ``disease_list`` (one ``{'Disease', 'Symptoms',
        'Treatments'}`` entry per matching row of ``df``, in table order) and
        ``unique_symptoms_list`` (every symptom of those diseases, each once,
        in first-seen order).
    """
    query_embedding = model.encode([query.symptom])
    results = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=5,
    )

    # One embedding was sent, so the matches are in the first result list.
    matching_symptoms = results['documents'][0]

    matching_diseases = df[df['Symptoms'].apply(
        lambda row_symptoms: any(s in matching_symptoms for s in row_symptoms)
    )]

    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
        }
        for _, row in matching_diseases.iterrows()
    ]

    # dict.fromkeys de-duplicates while preserving first-seen order, replacing
    # the index-based nested loop and its repeated list-membership scans.
    unique_symptoms_list = list(dict.fromkeys(
        symptom
        for entry in disease_list
        for symptom in entry['Symptoms']
    ))

    return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
|
|
|
class SelectedSymptomsQuery(BaseModel):
    # Request body for /find_disease.
    # `selected_symptoms` is compared element-by-element against each
    # disease's symptom list; the element type is not constrained here.
    selected_symptoms: list
|
|
|
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Rank diseases by how many of their symptoms were selected.

    Args:
        query: Request body whose ``selected_symptoms`` field lists the
            symptoms chosen by the caller.

    Returns:
        A dict with ``disease_list`` (one ``{'Disease', 'Symptoms',
        'Treatments', 'MatchCount'}`` entry per disease sharing at least one
        symptom with the selection, sorted by ``MatchCount`` descending) and
        ``max_match_count_disease`` (the first, i.e. best-matching, entry, or
        ``None`` when nothing matches).
    """
    selected_symptoms = query.selected_symptoms

    # Count, per disease, how many of its symptoms are in the selection.
    # A disease shares a symptom with the selection exactly when this count
    # is positive, so a single pass gives both the filter and the score.
    match_counts = df['Symptoms'].apply(
        lambda row_symptoms: sum(s in selected_symptoms for s in row_symptoms)
    )

    # assign() builds a new frame, so the shared module-level `df` is never
    # written to and no column is set on a filtered slice (which triggers
    # pandas' SettingWithCopy warning).
    scored = df.assign(match_count=match_counts)
    matching_diseases = scored[scored['match_count'] > 0].sort_values(
        by='match_count', ascending=False
    )

    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
            # Plain int keeps the value JSON-serialisable whatever the dtype.
            'MatchCount': int(row['match_count']),
        }
        for _, row in matching_diseases.iterrows()
    ]

    # Rows are sorted by match count descending, so the first entry is the
    # first one holding the maximum count.
    max_match_count_disease = disease_list[0] if disease_list else None

    return {"disease_list": disease_list, "max_match_count_disease": max_match_count_disease}
|
class DiseaseListQuery(BaseModel):
    # Request body carrying a list of disease entries.
    # NOTE(review): not referenced by any endpoint visible in this section —
    # confirm whether it is still needed.
    disease_list: list
|
|
|
class DiseaseDetail(BaseModel):
    # Request body for /pass2llm. The field names mirror the keys of the
    # entries that /find_disease returns in `disease_list`.
    Disease: str      # disease name
    Symptoms: list    # symptoms listed for the disease
    Treatments: str   # treatment text for the disease
    MatchCount: int   # number of selected symptoms the disease matched
|
|
|
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the LLM behind the current ngrok tunnel to summarise a disease entry.

    The public URL of the first endpoint reported by the ngrok API is looked
    up, then a non-streaming ``/api/generate`` request for model ``llama3`` is
    posted to it with a prompt that embeds ``query``.

    Args:
        query: The disease details to summarise.

    Returns:
        On success ``{"message": "Successfully passed to LLM!",
        "llm_response": <text>}``. On any failure a dict with a ``message``
        naming the failed step (ngrok lookup or LLM call) and an ``error``
        string; network errors, timeouts and malformed responses are reported
        this way instead of propagating as exceptions.
    """
    import os  # local import: only this endpoint needs it

    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"

    disease_list_details = query

    # SECURITY: an API credential is hard-coded in source control. It is kept
    # only as a fallback so existing deployments keep working; set
    # NGROK_API_KEY instead, then rotate this key and delete the literal.
    ngrok_api_key = os.environ.get(
        "NGROK_API_KEY",
        "2npJaJjnLBj1RGPcGf0QiyAAJHJ_5qqtw2divkpoAipqN9WLG",
    )
    headers = {
        "Authorization": f"Bearer {ngrok_api_key}",
        "Ngrok-Version": "2"
    }

    # Without a timeout, requests waits indefinitely on an unresponsive host.
    try:
        response = requests.get(
            "https://api.ngrok.com/endpoints", headers=headers, timeout=10
        )
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}

    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}

    # A 200 response may still lack a usable endpoint (no tunnel running,
    # unexpected payload shape, invalid JSON).
    try:
        public_url = response.json()['endpoints'][0]['public_url']
    except (ValueError, LookupError, TypeError) as exc:
        return {"message": ngrok_failure, "error": f"No usable endpoint in Ngrok response: {exc!r}"}

    prompt = f"Here is a list of diseases and their details: {disease_list_details}. Please generate a summary."

    llm_headers = {
        "Content-Type": "application/json"
    }
    llm_payload = {
        "model": "llama3",
        "prompt": prompt,
        "stream": False
    }

    # Generation can be slow, so allow a long read timeout but still fail
    # fast when the tunnel cannot be reached at all: (connect, read) seconds.
    try:
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}

    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}

    try:
        llm_response_json = llm_response.json()
    except ValueError as exc:
        return {"message": llm_failure, "error": f"Invalid JSON from LLM: {exc}"}

    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
|
|
|
|
|
|
|
|