# Image features and image tokens do not match
# url: https://colab.research.google.com/drive/1WQEL58JYdqVd3Ruytcm4uQImCN8vFxn6?usp=sharing
import gc

import torch
from PIL import Image
def clear_memory():
    """Release unreferenced Python objects, then free cached GPU memory.

    ``gc.collect()`` runs first so tensors freed by the collector can
    actually be returned to the CUDA allocator by ``empty_cache()``.
    The CUDA call is guarded so this is safe on CPU-only machines.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def extract_question_from_image(image_input, device="cuda"):
    """Extract the question (and answer options) from an exam-style image.

    Relies on the globally defined ``processor``, ``model`` and
    ``tokenizer`` (a multimodal vision-language model).

    Args:
        image_input: A filesystem path (str) or an already-loaded
            ``PIL.Image`` instance.
        device: Device the encoded inputs are moved to (default ``"cuda"``).

    Returns:
        The generated text, or an ``"Error: ..."`` string on failure.
    """
    # Accept both a path and a PIL image; always normalize to RGB.
    if isinstance(image_input, str):
        try:
            image = Image.open(image_input).convert("RGB")
        except Exception as e:
            print(f"Error opening image: {e}")
            return "Error: Could not open image."
    else:
        image = image_input.convert("RGB")

    # Downscale in place to bound memory use; thumbnail() keeps aspect ratio.
    max_size = (512, 512)
    image.thumbnail(max_size, Image.LANCZOS)

    prompt = "Extract the question from image and diagram and all options from this image."  # Simple, direct prompt

    try:
        encoding = processor(
            images=image,
            text=prompt,  # Pass the prompt text here
            return_tensors="pt",
            padding=True,  # Add padding
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **encoding,
                max_new_tokens=1024,  # Adjust as needed
                do_sample=False,      # Greedy decoding for deterministic output
                num_beams=1,
                # NOTE: ``temperature`` removed — it is ignored when
                # do_sample=False and recent transformers releases emit a
                # warning (or raise) for the inconsistent combination.
            )

        # Skip the prompt tokens so only newly generated text is returned.
        # NOTE(review): this slicing assumes a decoder-only model whose
        # output echoes the input ids — confirm for the model in use.
        input_length = encoding.input_ids.shape[1]
        generated_text = tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()
        return generated_text
    except Exception as e:
        print(f"Error during generation: {e}")
        clear_memory()
        return "Error: Could not extract question from image."
    finally:
        # Drop GPU references even on success so the next call starts clean.
        if 'encoding' in locals():
            for k in encoding:
                if isinstance(encoding[k], torch.Tensor):
                    encoding[k] = encoding[k].cpu()
        if 'outputs' in locals():
            outputs = outputs.cpu()
        clear_memory()
def process_image_sample(sample, device="cuda"):
    """Run question extraction on a single dataset sample.

    Args:
        sample: Mapping expected to hold an ``"image"`` key (path or
            PIL image).
        device: Forwarded to ``extract_question_from_image``.

    Returns:
        The extracted text, or ``None`` when the sample has no image or
        an error occurred.
    """
    if "image" not in sample or not sample["image"]:
        print("No image found in the sample.")
        return None

    # Show the image when running inside a notebook; outside IPython the
    # global display() does not exist, so skip rather than crash.
    try:
        display(sample["image"])
    except NameError:
        pass

    try:
        extracted_text = extract_question_from_image(sample["image"], device)
        print(f"Extracted Text: {extracted_text}")
        return extracted_text
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
    finally:
        clear_memory()
# --- Example Usage ---
if __name__ == "__main__":
    # Process up to 5 samples, or fewer if the dataset is smaller.
    # (Original used ``name == "main"``, which never matches — the
    # correct guard is the dunder form.)
    for i in range(min(5, len(dataset))):
        print(f"Processing sample {i+1}:")
        process_image_sample(dataset[i])
        print("-" * 20)
    clear_memory()