Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
from pathlib import Path | |
from sklearn.metrics.pairwise import cosine_similarity | |
from InstructorEmbedding import INSTRUCTOR | |
# Instruction text that is paired with every case note before it is encoded.
case_note_instruction = (
    "Represent the case note for a possible DSM-5 mental health diagnosis:"
)

# JSON file with the stored diagnosis embeddings; it sits next to this script.
EMBED_FILE_PATH = Path(__file__).with_name("top15_disorders_emb.json")

# Embedding model, loaded once at import time and shared by every request.
model = INSTRUCTOR("hkunlp/instructor-large")
def read_json_file(file_path):
    """Load a UTF-8 encoded JSON file and return the parsed content.

    Args:
        file_path: Location of the JSON file, in any form ``open`` accepts.

    Returns:
        The Python object produced by ``json.load`` for the file's content.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)
def get_case_note_embedding(case_note):
    """Encode a case note with the module-level INSTRUCTOR model.

    The note is paired with ``case_note_instruction`` and handed to
    ``model.encode`` as a batch containing that single pair.

    Args:
        case_note: Text of the case note to embed.

    Returns:
        Whatever ``model.encode`` returns for the one-item batch, unchanged.
    """
    instruction_pair = [case_note_instruction, case_note]
    return model.encode([instruction_pair])
def get_top_n_diagnoses(case_note, top_n=3):
    """Rank the stored diagnoses by similarity to a case note.

    The case note is embedded, then compared by cosine similarity with every
    criterion embedding found under the ``"diagnoses"`` key of the JSON file
    at ``EMBED_FILE_PATH``. Each diagnosis is scored with the mean of its
    criterion similarities; that mean is reported under the key
    ``"sum_score"``, which is kept for output compatibility.

    Args:
        case_note: Text of the case note to compare against the diagnoses.
        top_n: Maximum number of diagnoses to return. Defaults to 3; values
            below zero are treated as zero.

    Returns:
        A JSON string (2-space indent) holding a list of at most ``top_n``
        objects ordered from highest to lowest score. Each object has
        ``name``, ``sum_score`` and ``criteria_scores``; every entry of
        ``criteria_scores`` has ``description``, ``score`` and, when the
        stored criterion provides one, ``symptom_list``.
    """
    case_note_emb = get_case_note_embedding(case_note)
    diagnoses_embed = read_json_file(EMBED_FILE_PATH)["diagnoses"]

    diagnosis_list = []
    for diagnosis in diagnoses_embed:
        # Criterion descriptions and scores collected for this diagnosis.
        criteria_scores = []
        for criterion in diagnosis["criteria"]:
            similarity = cosine_similarity(
                criterion["embedding"], case_note_emb
            )
            criterion_data = {
                "description": criterion["description"],
                # .item() extracts the single similarity value as a Python
                # float; calling float() on a non-0-d array is deprecated
                # in recent NumPy releases.
                "score": similarity.item(),
            }
            if "symptom_list" in criterion:
                criterion_data["symptom_list"] = criterion["symptom_list"]
            criteria_scores.append(criterion_data)

        if not criteria_scores:
            # A diagnosis without criteria has nothing to average; skipping
            # it avoids the ZeroDivisionError the mean would otherwise hit.
            continue

        mean_score = sum(
            entry["score"] for entry in criteria_scores
        ) / len(criteria_scores)
        diagnosis_list.append(
            {
                "name": diagnosis["name"],
                "sum_score": mean_score,
                "criteria_scores": criteria_scores,
            }
        )

    # Highest mean similarity first.
    sorted_diagnoses = sorted(
        diagnosis_list, key=lambda item: item["sum_score"], reverse=True
    )
    # Clamp so a negative top_n cannot slice from the end of the list.
    top_n_diagnoses = sorted_diagnoses[: max(top_n, 0)]
    return json.dumps(top_n_diagnoses, indent=2)
def greet(name):
    """Return the greeting ``"Hello <name>!!"`` for the given name string."""
    greeting = "Hello " + name
    return greeting + "!!"
# Expose the diagnosis ranking through a Gradio interface with one "text"
# input and one "text" output, then start serving it.
iface = gr.Interface(
    fn=get_top_n_diagnoses,
    inputs="text",
    outputs="text",
)
iface.launch()