Files changed (1)
  1. app.py +17 -24
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
 import pandas as pd
 from transformers import pipeline, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
 from peft import PeftModel, PeftConfig
-import gradio as gr
 
 #Note this should always be used in compliance with applicable laws and regulations if used with real patient data.
 
@@ -16,7 +15,8 @@ peft_config = PeftConfig.from_pretrained("pseudolab/K23_MiniMed")
 peft_model = MistralForCausalLM.from_pretrained("pseudolab/K23_MiniMed", trust_remote_code=True)
 peft_model = PeftModel.from_pretrained(peft_model, "pseudolab/K23_MiniMed")
 
-text_generator = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)
+#Upload Patient Data
+uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
 # Prepare the context
 def prepare_context(data):
@@ -24,25 +24,23 @@ def prepare_context(data):
     data_str = data.to_string(index=False, header=False)
 
     # Tokenize the data
-    # input_ids = tokenizer.encode(data_str, return_tensors="pt")
+    input_ids = tokenizer.encode(data_str, return_tensors="pt")
 
     # Truncate the input if it's too long for the model
-    # max_length = tokenizer.model_max_length
-    # if input_ids.shape[1] > max_length:
-    #     input_ids = input_ids[:, :max_length]
-    input_ids = data_str
+    max_length = tokenizer.model_max_length
+    if input_ids.shape[1] > max_length:
+        input_ids = input_ids[:, :max_length]
 
     return input_ids
 
-def fn(uploaded_file) -> str:
+if uploaded_file is not None:
     data = pd.read_csv(uploaded_file)
-    ret = ""
+    st.write(data)
 
     # Generate text based on the context
     context = prepare_context(data)
-    # generated_text = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)(context)[0]['generated_text']
-    generated_text = text_generator(context)[0]['generated_text']
-    ret += generated_text
+    generated_text = pipeline('text-generation', model=model)(context)[0]['generated_text']
+    st.write(generated_text)
 
     # Internally prompt the model to data analyze the EHR patient data
     prompt = "You are an Electronic Health Records analyst with nursing school training. Please analyze the patient data that you are provided here. Give an organized, step-by-step, formatted health records analysis. You will always be truthful and if you do not know the answer say you do not know."
@@ -52,15 +50,10 @@ def fn(uploaded_file) -> str:
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
     # Generate text based on the prompt
-    # generated_text = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)(input_ids=input_ids)[0]['generated_text']
-    generated_text = text_generator(prompt)[0]['generated_text']
-    ret += generated_text
-
-    return ret
-
-
-demo = gr.Interface(fn=fn, inputs="file", outputs="text", theme="pseudolab/huggingface-korea-theme")
-
-
-if __name__ == "__main__":
-    demo.launch(show_api=False)
+    generated_text = pipeline('text-generation', model=model)(input_ids=input_ids)[0]['generated_text']
+    st.write(generated_text)
+else:
+    st.write("Please enter patient data")
+
+else:
+    st.write("No file uploaded")