Rohith1112 committed
Commit e7bcdf0 · verified · 1 Parent(s): aff94f6
Files changed (1):
  1. app.py +30 -77
app.py CHANGED

@@ -4,11 +4,9 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 import matplotlib.pyplot as plt
-from datasets import load_dataset
-from evaluate import load  # For evaluation metrics

 # Model setup
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available
+device = torch.device('cpu')  # Use 'cuda' if GPU is available
 dtype = torch.float32
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
 proj_out_num = 256
@@ -16,8 +14,8 @@ proj_out_num = 256
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
-    torch_dtype=dtype,
-    device_map=device.type,
+    torch_dtype=torch.float32,
+    device_map='cpu',
     trust_remote_code=True
 )

@@ -29,50 +27,43 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True
 )

-# Load the M3D-Cap dataset
-dataset = load_dataset("GoodBaiBai88/M3D-Cap")
-
 # Chat history storage
 chat_history = []
 current_image = None

 def extract_and_display_images(image_path):
-    try:
-        npy_data = np.load(image_path)
-        if npy_data.ndim == 4 and npy_data.shape[1] == 32:
-            npy_data = npy_data[0]
-        elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-            return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
-
-        fig, axes = plt.subplots(4, 8, figsize=(12, 6))
-        for i, ax in enumerate(axes.flat):
-            ax.imshow(npy_data[i], cmap='gray')
-            ax.axis('off')
-
-        image_output = "extracted_images.png"
-        plt.savefig(image_output, bbox_inches='tight')
-        plt.close()
-        return image_output
-    except Exception as e:
-        return f"Error processing image: {str(e)}"
+    npy_data = np.load(image_path)
+    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+        npy_data = npy_data[0]
+    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+        return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
+    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
+    for i, ax in enumerate(axes.flat):
+        ax.imshow(npy_data[i], cmap='gray')
+        ax.axis('off')
+
+    image_output = "extracted_images.png"
+    plt.savefig(image_output, bbox_inches='tight')
+    plt.close()
+    return image_output
+

 def process_image(question):
     global current_image
     if current_image is None:
         return "Please upload an image first."

-    try:
-        image_np = np.load(current_image)
-        image_tokens = "<im_patch>" * proj_out_num
-        input_txt = image_tokens + question
-        input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
-
-        image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-        generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
-        generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-        return generated_texts[0]
-    except Exception as e:
-        return f"Error generating response: {str(e)}"
+    image_np = np.load(current_image)
+    image_tokens = "<im_patch>" * proj_out_num
+    input_txt = image_tokens + question
+    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
+    generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
+    return generated_texts[0]
+

 def chat_interface(question):
     global chat_history
@@ -80,48 +71,13 @@ def chat_interface(question):
     chat_history.append((question, response))
     return chat_history

+
 def upload_image(image):
     global current_image
     current_image = image.name
     extracted_image_path = extract_and_display_images(current_image)
     return "Image uploaded and processed successfully!", extracted_image_path

-def test_model_with_dataset():
-    # Load evaluation metrics
-    bleu = load("bleu")
-    rouge = load("rouge")
-
-    # Initialize lists to store predictions and references
-    predictions = []
-    references = []
-
-    # Iterate over the dataset
-    for example in dataset['train']:  # Use 'train', 'validation', or 'test' split
-        image_path = example['image']  # Assuming 'image' contains the path to the .npy file
-        question = example['caption']  # Assuming 'caption' contains the question or caption
-
-        # Upload the image
-        upload_image({"name": image_path})
-
-        # Get the model's response
-        response = process_image(question)
-
-        # Store predictions and references
-        predictions.append(response)
-        references.append(question)
-
-        # Print results for debugging
-        print(f"Question: {question}")
-        print(f"Model Response: {response}")
-        print("---")
-
-    # Compute evaluation metrics
-    bleu_score = bleu.compute(predictions=predictions, references=references)
-    rouge_score = rouge.compute(predictions=predictions, references=references)
-
-    print(f"BLEU Score: {bleu_score}")
-    print(f"ROUGE Score: {rouge_score}")
-
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
     gr.Markdown("ICliniq AI-Powered Medical Image Analysis Workspace")
@@ -139,7 +95,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
     submit_button.click(chat_interface, question_input, chat_list)
     question_input.submit(chat_interface, question_input, chat_list)

-    # Uncomment to test the model with the dataset
-    # test_model_with_dataset()
-
     chat_ui.launch()
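
The headline change pins the app to CPU. To run the same setup on GPU hardware, the dynamic device selection from the pre-commit version can be restored in one place; a sketch follows, where the float16-on-GPU choice is an assumption of this sketch, not something the commit uses.

# Device-agnostic variant of the model setup above -- a sketch recombining the
# old and new lines, not part of the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if device.type == 'cuda' else torch.float32  # half precision only helps on GPU

model = AutoModelForCausalLM.from_pretrained(
    'GoodBaiBai88/M3D-LaMed-Phi-3-4B',
    torch_dtype=dtype,
    device_map=device.type,  # 'cuda' or 'cpu', as in the pre-commit version
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    'GoodBaiBai88/M3D-LaMed-Phi-3-4B',
    trust_remote_code=True,
)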
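The resulting app.py accepts .npy volumes shaped (1, 32, 256, 256) or (32, 256, 256) and funnels every question through upload_image and chat_interface. Below is a minimal local smoke test of that flow, assuming app.py is importable with its trailing chat_ui.launch() commented out or guarded by if __name__ == '__main__':. The DummyUpload class and the dummy_ct.npy filename are illustrative stand-ins, not part of the commit.

# Minimal smoke test for the app.py flow above -- an illustrative sketch,
# not part of the commit. Assumes chat_ui.launch() in app.py is guarded or
# commented out so the import does not block on the Gradio server.
import numpy as np

from app import upload_image, chat_interface  # downloads the model on import

class DummyUpload:
    """Stand-in for the upload object Gradio passes in (only .name is read)."""
    def __init__(self, name):
        self.name = name

# Fabricate a volume in the shape extract_and_display_images accepts: (32, 256, 256).
np.save("dummy_ct.npy", np.random.rand(32, 256, 256).astype(np.float32))

status, preview = upload_image(DummyUpload("dummy_ct.npy"))
print(status)   # "Image uploaded and processed successfully!"
print(preview)  # "extracted_images.png" -- the 4x8 grid of slices

history = chat_interface("Describe any abnormality in this scan.")
print(history[-1][1])  # the model's generated answer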
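The removed test_model_with_dataset helper scored generations with Hugging Face's evaluate package. If that check is ever needed again without the dataset wiring, a standalone version looks roughly like this; the prediction and reference strings are made up for illustration, and ROUGE additionally needs the rouge_score package installed.

# Standalone BLEU/ROUGE scoring in the style of the removed helper -- a sketch
# with made-up strings, not output from the model.
from evaluate import load

bleu = load("bleu")    # requires the `evaluate` package
rouge = load("rouge")  # additionally requires `rouge_score`

predictions = ["the lesion appears in the left lung"]
references = ["a lesion is visible in the left lung"]

# Both metrics accept parallel lists of prediction/reference strings.
print(bleu.compute(predictions=predictions, references=references))
print(rouge.compute(predictions=predictions, references=references))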