thankrandomness's picture
change data ingestion logic
b8e33aa
import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import chromadb
#from chromadb.utils import PersistenceManager
import gradio as gr
# Load the Hugging Face token from the environment variable
# hf_token = os.getenv("HF_API_TOKEN")
# Load the private dataset using the token
#dataset = load_dataset("thankrandomness/mimic-iii", token=hf_token)
# Public sample dataset pulled from the Hugging Face Hub (no token needed).
dataset = load_dataset("thankrandomness/mimic-iii-sample")

# One checkpoint id feeds both the tokenizer and the encoder so the two can
# never drift apart.
_PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(_PUBMEDBERT_CHECKPOINT)
model = AutoModel.from_pretrained(_PUBMEDBERT_CHECKPOINT)
# Initialize ChromaDB client with persistence
#persistence_manager = PersistenceManager("/mnt/data/chromadb")
#client = chromadb.Client(persistence_manager=persistence_manager)
#client = chromadb.Client()
#collection = client.get_or_create_collection(name="pubmedbert_matryoshka_embeddings")
#collection = client.get_or_create_collection(name="pubmedbert_embeddings")
# Function to embed text
def embed_text(text, max_length=512):
    """Embed one text, or a batch of texts, with the module-level encoder.

    The token embeddings of the last hidden layer are mean-pooled into one
    vector per input text.

    Args:
        text: A single string, or a list of strings embedded as one batch.
        max_length: Maximum number of tokens kept per text; longer inputs
            are truncated by the tokenizer.

    Returns:
        A NumPy array. For a string input the shape is ``(hidden_size,)``.
        For a list input the shape is ``(len(text), hidden_size)`` -- also
        for a list of length one, so ``embed_text([t])[0]`` is always a full
        vector rather than its first component.
    """
    is_single = isinstance(text, str)
    inputs = tokenizer(
        [text] if is_single else text,
        return_tensors="pt",
        truncation=True,
        # Pad only up to the longest text in the batch. Padding positions are
        # masked out of the mean below, so padding further would only add
        # compute without changing the result.
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
        # Average over real tokens only: weight every position by its
        # attention-mask bit so padding vectors do not dilute the embedding.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (hidden * mask).sum(dim=1) / token_counts

    vectors = pooled.numpy()
    # Keep the batch axis for list input (no blanket squeeze); drop it only
    # when the caller passed a bare string.
    return vectors[0] if is_single else vectors
# Initialize ChromaDB client
# Default ChromaDB client; no persistence path is configured here.
client = chromadb.Client()
# get_or_create (rather than create) so that executing this module again in a
# process where the collection already exists does not raise. The ingestion
# below uses upsert, so reusing an existing collection is safe.
collection = client.get_or_create_collection(name="pubmedbert_embeddings")
# Process the dataset and upsert into ChromaDB
# Walk every note of every training row, embed the note text once, and store
# one ChromaDB record per annotation (all records of a note share the same
# embedding; the annotation itself is the record's metadata).
for row in dataset['train']:
    for note in row['notes']:
        note_id = note['note_id']
        text = note.get('text', '')

        # Keep only annotations that carry all three required fields.
        annotations_list = []
        for annotation in note.get('annotations') or []:
            try:
                annotations_list.append({
                    "code": annotation['code'],
                    "code_system": annotation['code_system'],
                    "description": annotation['description'],
                })
            except KeyError as e:
                print(f"Skipping annotation due to missing key: {e}")
        print(f"Processed annotations for note {note_id}: {annotations_list}")

        if not (text and annotations_list):
            print(f"Skipping note {note_id} due to missing 'text' or 'annotations'")
            continue

        # embed_text is called with a one-item batch. Flatten the result so we
        # always end up with a single 1-D vector, whether or not the helper
        # keeps the batch axis (indexing [0] on an already-squeezed vector
        # would yield one scalar), and convert it to plain Python floats.
        embedding = embed_text([text]).reshape(-1).tolist()

        # One upsert per note instead of one per annotation.
        collection.upsert(
            ids=[f"note_{note_id}_{j}" for j in range(len(annotations_list))],
            embeddings=[embedding] * len(annotations_list),
            metadatas=annotations_list,
        )
# Define retrieval function
def retrieve_relevant_text(input_text):
    """Return the annotations stored closest to ``input_text``.

    The text is embedded with ``embed_text`` and used to query the
    module-level ChromaDB ``collection`` for its 5 nearest records.

    Args:
        input_text: Free text to look up.

    Returns:
        A list with one dict per hit, in the order ChromaDB returned them.
        Each dict has:

        * ``"similarity_score"``: the distance ChromaDB reports for the hit
          (NOTE(review): this is a distance, so smaller presumably means
          closer -- confirm against the collection's distance metric).
        * ``"annotation"``: the stored metadata dict (``code``,
          ``code_system``, ``description``).
        * the metadata fields again at top level, because the formatting
          code in this file reads ``result['code']`` and
          ``result['description']`` directly.
    """
    # One-item batch in, single 1-D vector out: flatten so this works whether
    # or not embed_text keeps the batch axis, then hand ChromaDB plain floats.
    query_vector = embed_text([input_text]).reshape(-1).tolist()
    results = collection.query(query_embeddings=[query_vector], n_results=5)
    print(results)

    # A query result holds parallel lists keyed by field ("metadatas",
    # "distances", ...), each with one inner list per query embedding. We
    # sent exactly one embedding, so take the first inner list of each.
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    output = []
    for metadata, distance in zip(metadatas, distances):
        annotation = dict(metadata or {})
        output.append({
            **annotation,
            # Listed after the spread so these two keys always win.
            "similarity_score": distance,
            "annotation": annotation,
        })
    return output
# Gradio interface
def gradio_interface(input_text):
    """Format the nearest annotations for ``input_text`` as display strings.

    Args:
        input_text: Text entered by the user.

    Returns:
        A list of strings, one per hit returned by
        ``retrieve_relevant_text``; empty when the input is blank.
        NOTE(review): the interface is declared with ``outputs="text"``, which
        presumably renders this list via ``str()`` -- consider joining the
        lines if a cleaner display is wanted.
    """
    # Nothing to search for: skip the embedding and the query entirely.
    if not input_text or not input_text.strip():
        return []

    formatted_results = []
    for result in retrieve_relevant_text(input_text):
        # The code/description fields live in the nested "annotation" dict;
        # fall back to the result itself if they are provided at top level.
        annotation = result.get("annotation") or result
        formatted_results.append(
            f"Similarity Score: {result['similarity_score']:.2f}, "
            f"Code: {annotation['code']}, "
            f"Description: {annotation['description']}"
        )
    return formatted_results
# Text-in / text-out Gradio app wrapped around gradio_interface; launching it
# starts serving the UI.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
)
interface.launch()