Kalbe-x-Bangkit committed on
Commit 9bc7520
1 Parent(s): c2c9e04

Update app.py

change model to IDEFICS2 MedVQA

Files changed (1):
  1. app.py +111 -13
app.py CHANGED
@@ -1,5 +1,13 @@
+import os
+import subprocess
+from PIL import Image
+import io
 import gradio as gr
-from transformers import pipeline
+from transformers import AutoProcessor, TextIteratorStreamer
+from transformers import Idefics2ForConditionalGeneration
+import torch
+from peft import LoraConfig
+from transformers import BitsAndBytesConfig

 # Project description
 description = """
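Note: `os`, `subprocess`, `io`, and `TextIteratorStreamer` are imported here but not referenced by any hunk in this diff, and the PIL import is likewise unused in the code shown (Gradio hands `format_answer` a PIL image directly). Assuming nothing else in app.py needs them, the block could shrink to:

```python
# Minimal import set for the code shown in this diff (assumption: no other
# part of app.py uses os, subprocess, io, PIL, or TextIteratorStreamer).
import gradio as gr
import torch
from peft import LoraConfig
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Idefics2ForConditionalGeneration,
)
```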
@@ -14,9 +22,8 @@ The model is trained using the [Hugging face](https://huggingface.co/datasets/fl
 Reference: [ScienceDirect](https://www.sciencedirect.com/science/article/abs/pii/S0933365723001252)

 ## Model Architecture
-The model uses a Parameterized Hypercomplex Shared Encoder network (PHYSEnet).

-![Model Architecture](path/to/your/image.png)
+![Model Architecture](img/Model-Architecture.png)

 Reference: [ScienceDirect](https://www.sciencedirect.com/science/article/abs/pii/S0933365723001252)

@@ -24,19 +31,87 @@ Reference: [ScienceDirect](https://www.sciencedirect.com/science/article/abs/pii/S0933365723001252)
 Please select the example below or upload 4 pairs of mammography exam results.
 """

-# Load the Visual QA model
-generator = pipeline("visual-question-answering", model="jihadzakki/blip1-medvqa")
+DEVICE = torch.device("cuda")
+
+USE_LORA = False
+USE_QLORA = True
+
+if USE_QLORA or USE_LORA:
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=8,
+        lora_dropout=0.1,
+        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
+        use_dora=not USE_QLORA,
+        init_lora_weights="gaussian"
+    )
+    if USE_QLORA:
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16
+        )
+
+model = Idefics2ForConditionalGeneration.from_pretrained(
+    "jihadzakki/idefics2-8b-vqarad-delta",
+    torch_dtype=torch.float16,
+    quantization_config=bnb_config
+)
+
+
+processor = AutoProcessor.from_pretrained(
+    "HuggingFaceM4/idefics2-8b",
+)

 def format_answer(image, question, history):
     try:
-        result = generator(image, question, max_new_tokens=50)
-        predicted_answer = result[0].get('answer', 'No answer found')
-        history.append((image, f"Question: {question} | Answer: {predicted_answer}"))
-
-        return f"Predicted Answer: {predicted_answer}", history
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": question}
+                ]
+            }
+        ]
+
+        text = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
+        inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
+        generated_ids = model.generate(**inputs, max_new_tokens=64)
+        generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)[0]
+
+        history.append((image, f"Question: {question} | Answer: {generated_texts}"))
+
+        # Keep the final answer string before freeing intermediate tensors
+        predicted_answer = f"Predicted Answer: {generated_texts}"
+
+        # Release GPU memory held by this request
+        del inputs
+        del generated_ids
+        del generated_texts
+        torch.cuda.empty_cache()
+
+        return predicted_answer, history
     except Exception as e:
+        # Clear the cache in case of an error
+        torch.cuda.empty_cache()
        return f"Error: {str(e)}", history

+def clear_history():
+    return "", []
+
+def undo_last(history):
+    if history:
+        history.pop()
+    return "", history
+
+def retry_last(image, question, history):
+    if history:
+        last_image, last_entry = history[-1]
+        return format_answer(last_image, question, history[:-1])
+    return "No previous analysis to retry.", history
+
 def switch_theme(mode):
     if mode == "Light Mode":
         return gr.themes.Default()
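Two things worth flagging in this hunk: `lora_config` is built but never attached to the model, and `bnb_config` is only bound when `USE_QLORA` is true, so setting both flags to `False` would make the `from_pretrained` call fail with a `NameError`. If the adapter were meant to be applied in-process (an assumption; the `jihadzakki/idefics2-8b-vqarad-delta` checkpoint may already carry the fine-tuned weights), the standard PEFT wiring would be:

```python
from peft import get_peft_model

# Hypothetical: attach the LoRA adapter defined above. Only needed if the
# loaded checkpoint does not already include the fine-tuned weights.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # sanity check: only adapter params are trainable
```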
@@ -60,9 +135,9 @@ with gr.Blocks(
         secondary_hue=gr.themes.colors.red,
     )
 ) as VisualQAApp:
-    gr.Markdown(description, elem_classes="description")
+    gr.Markdown(description, elem_classes="title")  # Display the project description

-    gr.Markdown("# Visual Question Answering using BLIP Model", elem_classes="title")
+    gr.Markdown("## Demo")

     with gr.Row():
         with gr.Column():
@@ -82,6 +157,29 @@ with gr.Blocks(
         show_progress=True
     )

+    with gr.Row():
+        retry_button = gr.Button("Retry")
+        undo_button = gr.Button("Undo")
+        clear_button = gr.Button("Clear")
+
+    retry_button.click(
+        retry_last,
+        inputs=[image_input, question_input, history_state],
+        outputs=[answer_output, history_state]
+    )
+
+    undo_button.click(
+        undo_last,
+        inputs=[history_state],
+        outputs=[answer_output, history_state]
+    )
+
+    clear_button.click(
+        clear_history,
+        inputs=[],
+        outputs=[answer_output, history_state]
+    )
+
     with gr.Row():
         history_gallery = gr.Gallery(label="History Log", elem_id="history_log")
     submit_button.click(
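The three new handlers are plain functions over the history list, so they can be checked without launching the UI. A minimal sketch (history entries are `(image, caption)` tuples, as built in `format_answer`):

```python
# clear_history and undo_last are pure with respect to the UI, so a quick
# REPL check is enough to confirm the wiring semantics.
history = [("img1.png", "Question: q1 | Answer: a1"),
           ("img2.png", "Question: q2 | Answer: a2")]

_, history = undo_last(history)
assert len(history) == 1                 # last entry dropped

answer, history = clear_history()
assert answer == "" and history == []    # everything reset
```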
@@ -117,4 +215,4 @@ with gr.Blocks(
         outputs=[feedback_input]
     )

-VisualQAApp.launch(share=True)
+VisualQAApp.launch(share=True, debug=True)
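A quick smoke test for the new IDEFICS2 inference path, without launching the app (a sketch; it assumes a CUDA GPU is present, since `DEVICE` is hard-coded to `"cuda"` above, and enough VRAM for the 4-bit model):

```python
from PIL import Image

# Any RGB image exercises the pipeline; real inputs are medical images.
test_image = Image.new("RGB", (224, 224), color="white")
answer, history = format_answer(test_image, "Is there any abnormality?", history=[])
print(answer)        # "Predicted Answer: ..." on success, "Error: ..." otherwise
print(len(history))  # 1 on success, 0 if the error path ran
```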