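"""GPT Clinical Notes Summarizer (Streamlit demo).

Loads a fine-tuned GPT-Neo model (bryanmildort/gpt_neo_notes) and summarizes
clinical notes selected from notes_small.csv or pasted into the text area.
"""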
import streamlit as st
import pandas as pd
from PIL import Image
import re
def summarize_function(notes):
    # max_new_tokens must be an integer; scale the summary budget to the input size.
    max_tokens = int(len(notes.split(' ')) * 2 * 0.215)
    gen_text = pipe(notes, max_new_tokens=max_tokens, do_sample=True,  # temperature/top_p only apply when sampling
                    temperature=0.8, num_return_sequences=1, top_p=0.2)[0]['generated_text'][len(notes):]
    # The model may begin a new "[Notes]:" block after the summary; keep only the text before it.
    gen_text = gen_text.split('[Notes]:')[0]
    st.write('Summary: ')
    return gen_text
notes_df = pd.read_csv('notes_small.csv')
examples_tuple = tuple(f"Patient {i+1}" for i in range(len(notes_df)))
example = st.sidebar.selectbox('Example', examples_tuple, index=0)
st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer</h1>", unsafe_allow_html=True)
st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
st.sidebar.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer 0.1v</h1>", unsafe_allow_html=True)
st.sidebar.markdown("<h6 style='text-align: center; color: #489DDB;'>The model for this application was created with the generous support of the Google TPU Research Cloud (TRC). This demo is for investigative research purposes only. The model is assumed to have several limiations and biases, so please oversee responses with human moderation. It is not intended for production ready enterprises and is displayed to illustrate the capabilities of Large Language Models for health care research.</h1>", unsafe_allow_html=True)
tower = Image.open('howard_social.png')
# seal = Image.open('Howard_University_seal.svg.png')  # unused while the image below is disabled
st.sidebar.image(tower)
# st.sidebar.image(seal)
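# Deferred import: keeping the heavy transformers import below the static UI
# lets the page render before model loading starts.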
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# from accelerate import infer_auto_device_map
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_str = f"""Device being used: {device}"""
# st.write(device_str)
# device_map = infer_auto_device_map(model, dtype="float16")
# st.write(device_map)
@st.cache(allow_output_mutation=True)  # on Streamlit >= 1.18, st.cache_resource is the replacement
def load_model():
    # Cache the pipeline so the model is downloaded and loaded only once.
    model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes", low_cpu_mem_usage=True)
    # model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes")
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe = load_model()
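# --- Summarization UI ---
# The fine-tuned model expects prompts of the form "[Notes]:\n<notes>\n[Summary]:\n"
# and continues the text with a summary (see summarize_function above).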
# Parse the patient number from the selection (robust beyond "Patient 9").
prompt = notes_df.iloc[int(example.split(' ')[1]) - 1].PARSED
input_text = st.text_area("Notes:", prompt)
if st.button('Summarize'):
    # Collapse runs of blank lines, then squeeze any remaining newlines to single ones.
    parsed_input = re.sub(r'\n\s*\n', '\n\n', input_text)
    parsed_input = re.sub(r'\n+', '\n', parsed_input)
    final_input = f"""[Notes]:\n{parsed_input}\n[Summary]:\n"""
    st.write(summarize_function(final_input))