will.k committed
Commit dc822cd
1 Parent(s): 556f311
Files changed (1)
  1. app.py +12 -7
app.py CHANGED
@@ -16,18 +16,21 @@ peft_config = PeftConfig.from_pretrained("pseudolab/K23_MiniMed")
 peft_model = MistralForCausalLM.from_pretrained("pseudolab/K23_MiniMed", trust_remote_code=True)
 peft_model = PeftModel.from_pretrained(peft_model, "pseudolab/K23_MiniMed")
 
+text_generator = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)
+
 # Prepare the context
 def prepare_context(data):
     # Format the data as a string
     data_str = data.to_string(index=False, header=False)
 
     # Tokenize the data
-    input_ids = tokenizer.encode(data_str, return_tensors="pt")
+    # input_ids = tokenizer.encode(data_str, return_tensors="pt")
 
     # Truncate the input if it's too long for the model
-    max_length = tokenizer.model_max_length
-    if input_ids.shape[1] > max_length:
-        input_ids = input_ids[:, :max_length]
+    # max_length = tokenizer.model_max_length
+    # if input_ids.shape[1] > max_length:
+    #     input_ids = input_ids[:, :max_length]
+    input_ids = data_str
 
     return input_ids
 
@@ -37,7 +40,8 @@ def fn(uploaded_file) -> str:
 
     # Generate text based on the context
     context = prepare_context(data)
-    generated_text = pipeline('text-generation', model=peft_model)(context)[0]['generated_text']
+    # generated_text = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)(context)[0]['generated_text']
+    generated_text = text_generator(context)[0]['generated_text']
     ret += generated_text
 
     # Internally prompt the model to data analyze the EHR patient data
@@ -48,13 +52,14 @@ def fn(uploaded_file) -> str:
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
     # Generate text based on the prompt
-    generated_text = pipeline('text-generation', model=peft_model)(input_ids=input_ids)[0]['generated_text']
+    # generated_text = pipeline('text-generation', model=peft_model, tokenizer=tokenizer)(input_ids=input_ids)[0]['generated_text']
+    generated_text = text_generator(prompt)[0]['generated_text']
     ret += generated_text
 
     return ret
 
 
-demo = gr.Interface(fn=fn, inputs="file", outputs="text")
+demo = gr.Interface(fn=fn, inputs="file", outputs="text", theme="pseudolab/huggingface-korea-theme")
 
 
 if __name__ == "__main__":
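
Note on the change above: a transformers text-generation pipeline tokenizes its input internally and expects plain strings, which is why prepare_context now returns data_str as-is and why the old call style pipeline(...)(input_ids=input_ids) is dropped in favor of text_generator(prompt). A minimal sketch of the pattern, with "gpt2" standing in for the PEFT-wrapped model and max_new_tokens as an illustrative parameter not taken from this commit:

from transformers import pipeline

# Build the pipeline once at startup and reuse it for every request,
# as this commit does with text_generator.
generator = pipeline("text-generation", model="gpt2")

# The call takes a raw string; tokenization, generation, and decoding
# all happen inside the pipeline.
result = generator("Summarize this patient record:", max_new_tokens=32)
print(result[0]["generated_text"])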
 
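The hunks shown do not include the part of fn that turns uploaded_file into the data frame that prepare_context flattens with to_string(index=False, header=False). For orientation only, a hedged sketch of one plausible loading step: it assumes the upload is a CSV and that the Gradio "file" input passes a temp-file object whose .name attribute is the path on disk; load_ehr is a hypothetical helper, not part of app.py.

import pandas as pd

def load_ehr(uploaded_file):
    # Hypothetical helper: assumes a CSV upload and that the Gradio file
    # component exposes the temp-file path as .name.
    return pd.read_csv(uploaded_file.name)

def prepare_context(data):
    # As in the commit: return the flattened frame as a plain string and
    # let the text-generation pipeline handle tokenization.
    return data.to_string(index=False, header=False)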