|
import os |
|
import torch |
|
from datasets import load_dataset |
|
from transformers import AutoTokenizer, AutoModel |
|
import chromadb |
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the "thankrandomness/mimic-iii-sample" dataset via `datasets`.
dataset = load_dataset("thankrandomness/mimic-iii-sample")

# The tokenizer and the encoder must come from the same checkpoint, so the
# identifier is spelled once and shared by both loaders.
PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(PUBMEDBERT_CHECKPOINT)
model = AutoModel.from_pretrained(PUBMEDBERT_CHECKPOINT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed_text(text, max_length=512):
    """Embed text with the module-level ``tokenizer`` and ``model``.

    The model's last hidden state is averaged over the token axis to give one
    vector per input. Only real tokens take part in the average: padding
    positions are excluded through the tokenizer's attention mask, so the
    result does not depend on how much padding a batch needed.

    Args:
        text: A single string, or a list of strings to embed as one batch.
        max_length: Maximum number of tokens kept per input; longer inputs
            are truncated.

    Returns:
        A NumPy array. For a single string it is 1-D ``(hidden_dim,)``. For a
        list it is 2-D ``(len(text), hidden_dim)`` -- the batch axis is kept
        even for a one-element list, so ``embed_text([s])[0]`` is the full
        vector for ``s`` rather than its first component.
    """
    is_single = isinstance(text, str)
    batch = [text] if is_single else text

    inputs = tokenizer(
        batch,
        return_tensors="pt",
        truncation=True,
        padding=True,  # pad only to the longest item; the mask handles the rest
        max_length=max_length,
    )

    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Masked mean pooling: zero out padded positions, then divide by the
    # number of real tokens (clamped so an all-padding row cannot divide by 0).
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)
    pooled = (hidden * mask).sum(dim=1) / token_counts

    embeddings = pooled.numpy()
    return embeddings[0] if is_single else embeddings
|
|
|
|
|
# Vector store for the note embeddings.
client = chromadb.Client()

# get_or_create_collection (rather than create_collection) so that running
# this script again against a client that already holds the collection reuses
# it instead of raising; the indexing below uses upsert, so it is idempotent.
collection = client.get_or_create_collection(name="pubmedbert_embeddings")
|
|
|
|
|
# Index the training split: every usable annotation of a note becomes one
# Chroma record whose vector is the embedding of that note's text and whose
# metadata is the annotation's code / code_system / description.
for row in dataset['train']:
    for note in row['notes']:
        text = note.get('text', '')
        annotations_list = []

        # `or []` also covers an explicit None stored under 'annotations'.
        for annotation in note.get('annotations') or []:
            try:
                annotations_list.append({
                    "code": annotation['code'],
                    "code_system": annotation['code_system'],
                    "description": annotation['description'],
                })
            except KeyError as e:
                print(f"Skipping annotation due to missing key: {e}")

        print(f"Processed annotations for note {note['note_id']}: {annotations_list}")

        if not (text and annotations_list):
            print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
            continue

        # Embed the string itself (not a one-element list indexed with [0]) so
        # the result is the whole vector, and hand Chroma a plain list of floats.
        note_embedding = embed_text(text).tolist()

        # One batched upsert per note: the same vector is stored once per
        # annotation, under ids note_<note_id>_<annotation index>.
        record_count = len(annotations_list)
        collection.upsert(
            ids=[f"note_{note['note_id']}_{j}" for j in range(record_count)],
            embeddings=[note_embedding] * record_count,
            metadatas=annotations_list,
        )
|
|
|
|
|
def retrieve_relevant_text(input_text):
    """Return the stored annotations closest to ``input_text``.

    The text is embedded with ``embed_text`` and used to query ``collection``
    for its 5 nearest records.

    Args:
        input_text: Free-text query string.

    Returns:
        A list of dicts, one per hit, in the order Chroma returned them:
        ``{"similarity_score": <distance>, "annotation": <metadata dict>}``.
        ``similarity_score`` carries the distance reported by Chroma. The
        list is empty when ``input_text`` is empty or blank.
    """
    if not input_text or not input_text.strip():
        return []

    # Embed the string itself so the result is the whole vector, and hand
    # Chroma a plain list of floats.
    query_embedding = embed_text(input_text).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=5)
    print(results)

    # Chroma returns parallel lists with one inner list per query embedding;
    # exactly one query was sent, so take index 0 of each.
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    return [
        {"similarity_score": distance, "annotation": annotation}
        for annotation, distance in zip(metadatas, distances)
    ]
|
|
|
|
|
def gradio_interface(input_text):
    """Format the hits of ``retrieve_relevant_text`` for display.

    Args:
        input_text: Query text entered by the user.

    Returns:
        A list with one formatted line per hit, giving its score (two
        decimals), code and description.
    """
    matches = retrieve_relevant_text(input_text)

    # Code and description live inside each hit's "annotation" dict, not at
    # the top level of the hit.
    return [
        f"Similarity Score: {match['similarity_score']:.2f}, "
        f"Code: {match['annotation']['code']}, "
        f"Description: {match['annotation']['description']}"
        for match in matches
    ]
|
|
|
# Wire the formatter into a text-in / text-out Gradio app and launch it.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
)
interface.launch()