File size: 1,928 Bytes
a2b2aea
b878344
898d20c
a2b2aea
 
2b7f563
a2b2aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc069c0
68f4d46
af1fb7b
19487c7
af1fb7b
19487c7
 
a2b2aea
946197d
a2b2aea
 
b878344
 
 
 
a2b2aea
b878344
 
a2b2aea
b878344
a2b2aea
 
9af7842
 
 
b878344
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
import pandas as pd
import re

def summarize_function(notes):
    """Generate a summary continuation for a ``[Notes]: ... [Summary]:`` prompt.

    Runs the module-level ``pipe`` text-generation pipeline on *notes*, strips
    the echoed prompt from the output, and drops everything from the first
    ``[Notes]:`` marker onward (where the model starts another notes section).
    Also writes a "Summary: " label to the Streamlit page before returning.

    Args:
        notes: The complete prompt string passed to the pipeline.

    Returns:
        The newly generated text, cut off before the first ``[Notes]:``.
    """
    # Token budget heuristic: 2 * 0.215 = 0.43 new tokens per space-separated
    # word of the prompt. ``generate`` expects an integer count, so truncate
    # and keep at least one token for very short prompts.
    max_new_tokens = max(1, int(len(notes.split(' ')) * 2 * .215))
    output = pipe(
        notes,
        max_new_tokens=max_new_tokens,
        temperature=0.8,
        num_return_sequences=1,
        top_p=0.2,
    )
    # generated_text is assumed to start with the prompt itself (the
    # pipeline's default full-text output), so slice it off.
    gen_text = output[0]['generated_text'][len(notes):]
    # Cut at the FIRST '[Notes]:' in one pass. The old backward scan shrank
    # the string while indexing from its end and could skip earlier markers.
    gen_text = gen_text.partition('[Notes]:')[0]
    st.write('Summary: ')
    return gen_text

# Page header: centered title and byline rendered as raw HTML. Both strings
# are static literals, so unsafe_allow_html does not expose user input here.
st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer 0.1v</h1>", unsafe_allow_html=True)
# The byline is an <h6>, so it must be closed with </h6> (was </h1>).
st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h6>", unsafe_allow_html=True)

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# from accelerate import infer_auto_device_map
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_str = f"""Device being used: {device}"""
# st.write(device_str)
# device_map = infer_auto_device_map(model, dtype="float16")
# st.write(device_map)

@st.cache(allow_output_mutation=True)
def load_model():
    """Build a text-generation pipeline for ``bryanmildort/gpt_neo_notes``.

    The model and tokenizer are both loaded from the same checkpoint id.
    Wrapped in ``st.cache(allow_output_mutation=True)`` so Streamlit can reuse
    the returned pipeline across reruns instead of reloading it each time.

    Returns:
        A ``transformers`` "text-generation" pipeline.
    """
    checkpoint = "bryanmildort/gpt_neo_notes"
    # Keyword arguments evaluate left to right: the model loads before the tokenizer.
    return pipeline(
        "text-generation",
        model=AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
    

pipe = load_model()


# Example notes for the sidebar picker; rows are read from the PARSED column.
# NOTE(review): path is relative to the process working directory.
notes_df = pd.read_csv('notes_small.csv')
examples_tuple = tuple(f"Patient {i+1}" for i in range(len(notes_df)))

example = st.sidebar.selectbox('Example', examples_tuple, index=0)
st.write(example)

# Map the chosen label back to its row by position in the options. Parsing
# only the label's last character broke for "Patient 10" and beyond.
prompt = notes_df.iloc[examples_tuple.index(example)].PARSED
input_text = st.text_area("Notes:", prompt)
if st.button('Summarize'):
    # Normalize whitespace-only lines to plain breaks, then squeeze any run
    # of newlines down to one, so the prompt has no blank lines.
    parsed_input = re.sub(r'\n\s*\n', '\n\n', input_text)
    parsed_input = re.sub(r'\n+', '\n', parsed_input)
    # Prompt layout: the notes, followed by an empty summary slot to continue.
    final_input = f"""[Notes]:\n{parsed_input}\n[Summary]:\n"""
    st.write(summarize_function(final_input))