Update project_model.py

#1
Files changed (1)
  1. project_model.py +123 -48
project_model.py CHANGED
@@ -9,66 +9,121 @@ Original file is located at
 
  # project_module.py
 
- import torch, cv2, time, os
+ # Import libraries for ML, CV, NLP, audio, and TTS
+ import torch, cv2, os
  import numpy as np
  from PIL import Image
  from ultralytics import YOLO
  from transformers import pipeline, DPTFeatureExtractor, DPTForDepthEstimation
  from TTS.api import TTS
-
-
  from huggingface_hub import login
- import os
 
- # Login using token stored in environment variable
+ # Authenticate to Hugging Face using environment token
  login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
 
- # Load models
+ # Set device for computation (GPU if available)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
- device = "cuda" if torch.cuda.is_available() else "cpu" # Enable GPU
+ # Load all models
+ yolo_model = YOLO("yolov9c.pt") # YOLOv9 for object detection
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # MiDaS for depth
+ depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Feature extractor for depth model
 
- yolo_model = YOLO("yolov9c.pt") # Load YOLOv9
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval() # Load MiDaS
- depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
- whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1) # Load Whisper
- # Load Gemma-3-4B
+ # Whisper for audio transcription
+ whisper_pipe = pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-small",
+     device=0 if torch.cuda.is_available() else -1
+ )
+
+ # GEMMA for image+text to text QA
  gemma_pipe = pipeline(
      "image-text-to-text",
      model="google/gemma-3-4b-it",
      device=0 if torch.cuda.is_available() else -1,
      torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
  )
- tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC") # Load Text-to-Speech (TTS)
 
- # Function to process image and audio
- def process_inputs(image: Image.Image, audio_path: str):
-     # Convert PIL image to OpenCV format
-     rgb_image = np.array(image)
+ # Text-to-speech
+ tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
+
+ # -------------------------------
+ # Session Management Class
+ # -------------------------------
+
+ class VisualQAState:
+     """
+     Stores the current image context and chat history for follow-up questions.
+     """
+     def __init__(self):
+         self.current_image: Image.Image = None
+         self.visual_context: str = ""
+         self.message_history = []
+
+     def reset(self, image: Image.Image, visual_context: str):
+         """
+         Called when a new image is uploaded.
+         Resets context and starts new message history.
+         """
+         self.current_image = image
+         self.visual_context = visual_context
+         self.message_history = [{
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": self.current_image},
+                 {"type": "text", "text": self.visual_context}
+             ]
+         }]
+
+     def add_question(self, question: str):
+         """
+         Adds a follow-up text message to the chat.
+         """
+         self.message_history.append({
+             "role": "user",
+             "content": [{"type": "text", "text": question}]
+         })
+
+ # -------------------------------
+ # Generate Context from Image
+ # -------------------------------
+
+ def generate_visual_context(pil_image: Image.Image) -> str:
+     """
+     Processes the image to extract object labels, depth info, and locations.
+     Builds a natural language context description for use in prompting.
+     """
+     # Convert to OpenCV and RGB formats
+     rgb_image = np.array(pil_image)
      cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-     pil_image = image
 
-     # YOLO Detection
+     # Object detection using YOLO
      yolo_results = yolo_model.predict(cv2_image)[0]
      boxes = yolo_results.boxes
      class_names = yolo_model.names
 
-     # MiDaS Depth
+     # Depth estimation using MiDaS
      depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
      with torch.no_grad():
          depth_output = depth_model(**depth_inputs)
      depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
      depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))
 
-     # Visual Context
+     # Extract contextual information for each object
      shared_visual_context = []
      for box in boxes:
          x1, y1, x2, y2 = map(int, box.xyxy[0])
          label = class_names[int(box.cls[0])]
          conf = float(box.conf[0])
+
+         # Compute average depth of object
          depth_crop = depth_map_resized[y1:y2, x1:x2]
          avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
+
+         # Determine object horizontal position
          x_center = (x1 + x2) / 2
          pos = "left" if x_center < rgb_image.shape[1] / 3 else "right" if x_center > 2 * rgb_image.shape[1] / 3 else "center"
+
          shared_visual_context.append({
              "label": label,
              "confidence": conf,
@@ -76,35 +131,55 @@ def process_inputs(image: Image.Image, audio_path: str):
              "position": pos
          })
 
-     # Build Context Text
-     def build_context_description(context):
-         descriptions = []
-         for obj in context:
-             d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
-             s = obj.get("position", "unknown")
-             c = obj.get("confidence", 0.0)
-             descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
-         return "In the image, " + ", ".join(descriptions) + "."
-
-     context_text = build_context_description(shared_visual_context)
-
-     # Transcribe audio
-     transcription = whisper_pipe(audio_path)["text"]
-     vqa_prompt = context_text + " " + transcription
-
-     # GEMMA answer
-     messages = [{
-         "role": "user",
-         "content": [
-             {"type": "image", "image": pil_image},
-             {"type": "text", "text": vqa_prompt}
-         ]
-     }]
-     gemma_output = gemma_pipe(text=messages, max_new_tokens=200)
+     # Convert context to a readable sentence
+     descriptions = []
+     for obj in shared_visual_context:
+         d = f"{obj['avg_depth']:.1f} units" if obj["avg_depth"] else "unknown"
+         s = obj.get("position", "unknown")
+         c = obj.get("confidence", 0.0)
+         descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
+
+     return "In the image, " + ", ".join(descriptions) + "."
+
+ # -------------------------------
+ # Main Multimodal Processing Function
+ # -------------------------------
+
+ def process_inputs(
+     session: VisualQAState,
+     image: Image.Image = None,
+     question: str = "",
+     audio_path: str = None,
+     enable_tts: bool = True
+ ):
+     """
+     Handles a new image upload or a follow-up question.
+     Combines image context, audio transcription, and text input to generate a GEMMA-based answer.
+     Optionally outputs audio using TTS.
+     """
+
+     # If new image is provided, reset session and build new context
+     if image:
+         visual_context = generate_visual_context(image)
+         session.reset(image, visual_context)
+
+     # If user gave an audio clip, transcribe it and append to question
+     if audio_path:
+         audio_text = whisper_pipe(audio_path)["text"]
+         question += " " + audio_text
+
+     # Append question to conversation history
+     session.add_question(question)
+
+     # Generate response using GEMMA with full conversation history
+     gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
      answer = gemma_output[0]["generated_text"][-1]["content"]
 
-     # Generate speech
+     # If TTS is enabled, synthesize answer as speech
      output_audio_path = "response.wav"
-     tts.tts_to_file(text=answer, file_path=output_audio_path)
+     if enable_tts:
+         tts.tts_to_file(text=answer, file_path=output_audio_path)
+     else:
+         output_audio_path = None
 
      return answer, output_audio_path
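
For reference, a minimal usage sketch of the new session-based API introduced by this change. It is not part of the diff: the import name project_model, the path example.jpg, and the example questions are illustrative assumptions (the in-file comment still reads project_module.py), and it presumes HUGGING_FACE_HUB_TOKEN is set so the models can load.

# Usage sketch (not part of the PR) exercising the new VisualQAState + process_inputs API.
# Assumes the changed file is importable as project_model and that "example.jpg" exists.
from PIL import Image
import project_model as pm

session = pm.VisualQAState()

# First turn: upload an image and ask a question; TTS writes response.wav
img = Image.open("example.jpg")  # hypothetical path
answer, audio_out = pm.process_inputs(
    session,
    image=img,
    question="What objects are in front of me?",
    enable_tts=True,
)
print(answer)      # text answer from Gemma
print(audio_out)   # "response.wav" produced by the TTS model

# Follow-up turn: same session, no new image, TTS disabled
answer2, audio_out2 = pm.process_inputs(
    session,
    question="Which of them is closest?",
    enable_tts=False,
)
print(answer2)     # audio_out2 is None because enable_tts=False

Because the session keeps the full message_history, the image is sent to Gemma once on upload and follow-up questions reuse that context rather than re-running detection and depth estimation.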