bryanmildort's picture
Update app.py
19487c7
raw history blame
No virus
1.93 kB
import streamlit as st
import pandas as pd
import re
def summarize_function(notes):
    """Generate a summary for clinical notes using the global `pipe` pipeline.

    Args:
        notes: Prompt string formatted as "[Notes]:\n...\n[Summary]:\n"; the
            model's continuation after the prompt is taken as the summary.

    Returns:
        The generated summary text, truncated before any spurious "[Notes]:"
        marker the model may hallucinate as a new section.
    """
    # max_new_tokens must be an int; the original passed a float
    # (len * 2 * .215), which newer transformers versions reject.
    # The budget still scales with the prompt's word count.
    max_new = int(len(notes.split(' ')) * 2 * 0.215)
    gen_text = pipe(
        notes,
        max_new_tokens=max_new,
        temperature=0.8,
        num_return_sequences=1,
        top_p=0.2,
    )[0]['generated_text'][len(notes):]
    # The model sometimes starts a new "[Notes]:" section after the summary;
    # cut the output at the first such marker (replaces the original
    # reverse-scan loop that re-sliced the string with stale indices).
    marker = gen_text.find('[Notes]:')
    if marker != -1:
        gen_text = gen_text[:marker]
    st.write('Summary: ')
    return gen_text
# Page title and byline (HTML rendered by Streamlit).
st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer 0.1v</h1>", unsafe_allow_html=True)
# Fix: the byline opened <h6> but closed </h1>, producing invalid HTML.
st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h6>", unsafe_allow_html=True)
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# from accelerate import infer_auto_device_map
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_str = f"""Device being used: {device}"""
# st.write(device_str)
# device_map = infer_auto_device_map(model, dtype="float16")
# st.write(device_map)
@st.cache(allow_output_mutation=True)
def load_model():
    """Build and return a text-generation pipeline for the notes model.

    Cached by Streamlit so the model weights are fetched and loaded only
    once per app session.
    """
    repo_id = "bryanmildort/gpt_neo_notes"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, low_cpu_mem_usage=True)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe = load_model()

# Load the example notes and offer one sidebar entry per row
# ("Patient 1", "Patient 2", ...).
notes_df = pd.read_csv('notes_small.csv')
examples_tuple = tuple(f"Patient {i+1}" for i in range(len(notes_df)))
example = st.sidebar.selectbox('Example', examples_tuple, index=0)
st.write(example)

# Parse the patient number from the selected label. The original
# `int(example[-1:])` read only the final digit, so "Patient 10" and
# beyond resolved to the wrong row.
prompt = notes_df.iloc[int(example.split(' ')[-1]) - 1].PARSED
input_text = st.text_area("Notes:", prompt)

if st.button('Summarize'):
    # Normalize whitespace-only blank lines, then collapse any run of
    # newlines to a single newline before prompting the model.
    parsed_input = re.sub(r'\n\s*\n', '\n\n', input_text)
    parsed_input = re.sub(r'\n+', '\n', parsed_input)
    final_input = f"""[Notes]:\n{parsed_input}\n[Summary]:\n"""
    st.write(summarize_function(final_input))