# MiniMed_EHR_Analyst / MiniMed_EHR_Analyst_Spaces.py
# Committed by Solshine: "Changed model to squarelike/llama2-ko-medical-7b" (commit d603b0b)
import streamlit as st
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Note: this app must always be used in compliance with applicable laws and
# regulations when it is used with real patient data.

# Model: squarelike/llama2-ko-medical-7b (formerly pseudolab/K23_MiniMed by
# Tonic). This is a large model and will take a while to download.
# Config issues have persisted with this model, unfortunately; it may not be
# ready for use.
MODEL_NAME = "squarelike/llama2-ko-medical-7b"


@st.cache_resource
def _load_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the returned objects across reruns, so the
    large model is loaded once per server process instead of on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Module-level names kept so the rest of the script can use them directly.
tokenizer, model = _load_model(MODEL_NAME)

# Upload patient data.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
# Prepare the context
def prepare_context(data, max_length=None):
    """Serialize patient data to text and tokenize it for the model.

    Args:
        data: DataFrame read from the uploaded CSV. It is rendered with
            ``DataFrame.to_string(index=False, header=False)``, so neither the
            row index nor the column names are included in the text.
        max_length: Maximum number of tokens to keep. When omitted, defaults to
            ``tokenizer.model_max_length``, further capped by the model's
            ``max_position_embeddings`` when its config defines one.

    Returns:
        The token ids from ``tokenizer.encode(..., return_tensors="pt")`` (a
        tensor of shape (1, n)), cut to at most ``max_length`` tokens while
        keeping the beginning of the data.
    """
    if max_length is None:
        max_length = tokenizer.model_max_length
        # Tokenizers whose config omits model_max_length report a huge
        # placeholder value, which would silently disable truncation; the
        # model's positional-embedding limit is the real upper bound.
        model_limit = getattr(model.config, "max_position_embeddings", None)
        if model_limit is not None:
            max_length = min(max_length, model_limit)

    data_str = data.to_string(index=False, header=False)
    input_ids = tokenizer.encode(data_str, return_tensors="pt")
    # Slicing is a no-op when the sequence already fits.
    return input_ids[:, :max_length]
# Instruction given to the model ahead of the uploaded patient data.
prompt = (
    "You are an Electronic Health Records analyst with nursing school training. "
    "Please analyze patient data that you are provided here. Give an organized, "
    "step-by-step, formatted health records analysis. You will always be truthful "
    "and if you do not know the answer say you do not know."
)
# Upper bound on the length of the generated analysis, in tokens.
MAX_NEW_TOKENS = 512


def _build_analysis_prompt(data):
    """Return the instruction prompt followed by the patient data as text.

    The data is tokenized by ``prepare_context`` and then cut further so that
    instruction + data + ``MAX_NEW_TOKENS`` generated tokens fit
    (approximately) within the model's context window.
    """
    header = f"{prompt}\n\nPatient data:\n"
    footer = "\n\nAnalysis:\n"
    # tokenizer.model_max_length can be a huge placeholder when unset in the
    # tokenizer config, so prefer the model's positional-embedding limit.
    limit = (
        getattr(model.config, "max_position_embeddings", None)
        or tokenizer.model_max_length
    )
    overhead = len(tokenizer.encode(header + footer))
    budget = max(limit - overhead - MAX_NEW_TOKENS, 0)
    data_ids = prepare_context(data)[0, :budget]
    # skip_special_tokens drops the BOS token that encode() prepended.
    data_text = tokenizer.decode(data_ids, skip_special_tokens=True)
    return header + data_text + footer


if uploaded_file is None:
    st.write("No file uploaded")
else:
    try:
        data = pd.read_csv(uploaded_file)
    except pd.errors.EmptyDataError:
        # A completely empty upload has no columns to parse.
        data = pd.DataFrame()
    st.write(data)

    if data.empty:
        st.write("Please enter patient data")
    else:
        # The text-generation pipeline takes text, not token ids; pass the
        # tokenizer explicitly rather than relying on it being inferred.
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        with st.spinner("Analyzing patient data..."):
            outputs = generator(
                _build_analysis_prompt(data),
                max_new_tokens=MAX_NEW_TOKENS,
                return_full_text=False,  # show only the model's analysis
            )
        st.write(outputs[0]["generated_text"])