from constants import *
from utils import image_to_tensor, tensor_to_image, tokenizer, vocab_size
import torch
import torch.nn.functional as F
from PIL import ImageDraw, Image
from dataset import create_test_dataloader
from vision_language_model import VisionLanguageModel
model = VisionLanguageModel(
n_embd=HIDDEN_DIM,
vocab_size=vocab_size,
img_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
num_heads=NUM_HEADS,
num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers
num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers
emb_dropout=DROPOUT,
blk_dropout=DROPOUT,
max_context=CONTEXT_LENGTH,
shared_embed_dim=SHARED_EMBED_DIM,
lambda_contrastive=LAMBDA_CONTRASTIVE,
lambda_regression=LAMBDA_REGRESSION # Pass the regression weight
).to(DEVICE)
MODEL_PATH = "model_regression_multi_first_100.pth" # "model_regression_multi_16.pth"
if DEVICE == "cuda":
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
else:
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=torch.device('cpu')))
model.eval()
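# Optional sanity check (safe to remove): report the size of the restored model so a
# mismatch between the constants above and the checkpoint is easy to spot.
num_params = sum(p.numel() for p in model.parameters())
print(f"Loaded '{MODEL_PATH}' on {DEVICE} ({num_params / 1e6:.2f}M parameters)")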
def generate_sample_from_image_text(
model,
image_path,
prompt_label,
tokenizer,
device,
max_new_tokens=70,
temperature=0.8,
top_k=10,
output_path="generated_output.png"
):
"""
Generates a prediction for an image and prompt text and saves it to a file.
Generation loop is implemented *within* this function.
Args:
model: The trained VisionLanguageModel.
image_path: Path to the input image.
prompt_label: Text prompt/label to use.
tokenizer: The tokenizer used for training.
device: The computation device ('cuda' or 'cpu').
max_new_tokens (int): Max tokens to generate after the prompt.
temperature (float): Softmax temperature for sampling.
top_k (int): K for top-k sampling (0 or None to disable).
        output_path (str): File path where the output image is saved.
Returns:
None. Saves the image with prompt and generated output to a file.
"""
model.eval() # Set the model to evaluation mode
try:
with torch.no_grad(): # No need to track gradients during inference
# --- 1. Prepare Initial Inputs ---
# Load and process image
            image = Image.open(image_path).convert("RGB")  # ensure 3 channels (handles RGBA/grayscale inputs)
image_tensor = image_to_tensor(image).unsqueeze(0).to(device) # Add batch dim
# Tokenize prompt
prompt_text = f"<point_start>{prompt_label}<point_end>"
prompt_tokens = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=False)
prompt_ids = prompt_tokens.input_ids.to(device)
prompt_attention_mask = prompt_tokens.attention_mask.to(device)
B = 1 # We are processing one sample at a time
print(f"--- Generating Sample (Manual Loop) ---")
print(f"Original Label/Prompt Hint: {prompt_label}")
print(f"Input Prompt Tokens Decoded: {prompt_text}")
# --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
image_embeds_raw = model.vision_encoder(image_tensor) # (1, N_img, C)
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) # (1, N_img, C)
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) # (1, T_prompt, C)
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
            result_start_embed = model.decoder.token_embedding_table(
                torch.tensor([[result_start_token_id]], device=device)  # token ids of shape (1, 1)
            )  # -> embedding of shape (1, 1, C)
            # The initial decoder input is the concatenation of image embeds, prompt embeds,
            # and the <result_start> embed that cues the model to emit coordinates
current_embeds = torch.cat([
image_embeds_decoder,
prompt_embeds_decoder,
result_start_embed # Add the embedding for the first expected output token
], dim=1)
generated_ids = [] # Store newly generated IDs
# --- 3. Autoregressive Generation Loop ---
for _ in range(max_new_tokens):
T_current = current_embeds.shape[1]
# Truncate if necessary (keep recent context)
if T_current > model.decoder.max_context: # Access max_context from decoder
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
current_embeds = current_embeds[:, -model.decoder.max_context:, :]
T_current = model.decoder.max_context
# Prepare positional embeddings for current length
pos = torch.arange(0, T_current, dtype=torch.long, device=device)
pos = pos.clamp(max=model.decoder.max_context - 1) # Clamp indices
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) # (1, T_current, C)
x = current_embeds + pos_emb
# Create attention mask (all ones, causal handles future)
# Note: We don't need padding mask here as we handle one sequence without padding
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
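                # Note: no KV cache is used here; each step re-runs the decoder blocks over the
                # full image + prompt + generated-so-far sequence, which is simple but grows
                # more expensive as the sequence lengthens.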
# Pass through Decoder Blocks
for block in model.decoder.blocks:
# We assume the block forward takes (x, attention_mask)
x = block(x, attention_mask=attention_mask)
# Final Layer Norm and LM Head for the *last* token prediction
x = model.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) -> (1, 1, C)
logits = model.decoder.lm_head(x) # (B, 1, V) -> (1, 1, V)
logits = logits.squeeze(1) # (B, V) -> (1, V)
# Sampling
logits = logits / temperature
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
                # idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1) stochastic sampling from the top-k distribution
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding; temperature/top_k above do not change the argmax
# Store generated ID
generated_ids.append(idx_next)
# Stop if EOS token is generated
if idx_next.item() == tokenizer.eos_token_id:
print("EOS token generated.")
break
# Prepare for next iteration: Append embedding of new token
next_token_embed = model.decoder.token_embedding_table(idx_next) # (1, 1, C)
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # Append along sequence dim
# --- 4. Combine and Decode Results ---
if generated_ids:
generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
else:
full_generated_sequence_ids = prompt_ids # Nothing was generated
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
# --- 5. Save visualization to file ---
save_coords_visualization(
image_tensor=image_tensor[0], # Remove batch dim for visualization
full_decoded_text=full_decoded_text,
tokenizer=tokenizer,
image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
output_path=output_path
)
print(f"Visualization saved to: {output_path}")
except Exception as e:
print(f"An error occurred during sample generation: {e}")
import traceback
traceback.print_exc()
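# Example call (illustrative; assumes the sample image below exists in your data layout):
#   generate_sample_from_image_text(model, "./data/test_images/image_1.png", "a red apple",
#                                   tokenizer, DEVICE, output_path="model_prediction.png")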
def generate_sample_from_test_loader(
model,
test_loader,
tokenizer,
device,
max_new_tokens=70,
temperature=0.8,
top_k=10,
output_path="generated_output.png",
TEST_BATCH=8,
TEST_IDX=1
):
"""
Generates a prediction for one sample from the test loader and saves it to a file.
Generation loop is implemented *within* this function.
Args:
model: The trained VisionLanguageModel.
test_loader: DataLoader for the test set.
tokenizer: The tokenizer used for training.
device: The computation device ('cuda' or 'cpu').
max_new_tokens (int): Max tokens to generate after the prompt.
temperature (float): Softmax temperature for sampling.
top_k (int): K for top-k sampling (0 or None to disable).
        output_path (str): File path where the output image is saved.
        TEST_BATCH (int): Number of leading batches to skip before drawing a sample.
        TEST_IDX (int): Index of the sample to use within the selected batch.
Returns:
None. Saves the image with prompt and generated output to a file.
"""
if not test_loader or len(test_loader.dataset) == 0:
print("Test loader is empty or not available.")
return
model.eval() # Set the model to evaluation mode
try:
# Get a single batch from the test loader
with torch.no_grad(): # No need to track gradients during inference
my_iter = iter(test_loader)
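            # Skip the first TEST_BATCH batches and sample from the following one, so the
            # inspected example comes from further into the test set; if the loader has fewer
            # batches, the StopIteration handler below reports it.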
for i in range(TEST_BATCH):
_ = next(my_iter)
batch = next(my_iter)
if batch is None:
print("Test loader yielded an empty batch.")
return
if batch['image'].shape[0] == 0:
print("Test loader yielded a batch with 0 items.")
return
# --- 1. Prepare Initial Inputs ---
image_tensor = batch['image'][TEST_IDX:TEST_IDX+1].to(device) # (1, 3, H, W)
prompt_ids = batch['prompt_ids'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
prompt_attention_mask = batch['prompt_attention_mask'][TEST_IDX:TEST_IDX+1].to(device) # (1, T_prompt)
label = batch['label'][TEST_IDX]
B = 1 # We are processing one sample at a time
print(f"--- Generating Sample (Manual Loop) ---")
print(f"Original Label/Prompt Hint: {label}")
prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
print(f"Input Prompt Tokens Decoded: {prompt_text}")
# --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
image_embeds_raw = model.vision_encoder(image_tensor) # (1, N_img, C)
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) # (1, N_img, C)
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) # (1, T_prompt, C)
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
            result_start_embed = model.decoder.token_embedding_table(
                torch.tensor([[result_start_token_id]], device=device)  # token ids of shape (1, 1)
            )  # -> embedding of shape (1, 1, C)
            # The initial decoder input is the concatenation of image embeds, prompt embeds,
            # and the <result_start> embed that cues the model to emit coordinates
current_embeds = torch.cat([
image_embeds_decoder,
prompt_embeds_decoder,
result_start_embed # Add the embedding for the first expected output token
], dim=1)
generated_ids = [] # Store newly generated IDs
# --- 3. Autoregressive Generation Loop ---
for _ in range(max_new_tokens):
T_current = current_embeds.shape[1]
# Truncate if necessary (keep recent context)
if T_current > model.decoder.max_context: # Access max_context from decoder
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
current_embeds = current_embeds[:, -model.decoder.max_context:, :]
T_current = model.decoder.max_context
# Prepare positional embeddings for current length
pos = torch.arange(0, T_current, dtype=torch.long, device=device)
pos = pos.clamp(max=model.decoder.max_context - 1) # Clamp indices
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) # (1, T_current, C)
x = current_embeds + pos_emb
# Create attention mask (all ones, causal handles future)
# Note: We don't need padding mask here as we handle one sequence without padding
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)
# Pass through Decoder Blocks
for block in model.decoder.blocks:
# We assume the block forward takes (x, attention_mask)
x = block(x, attention_mask=attention_mask)
# Final Layer Norm and LM Head for the *last* token prediction
x = model.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) -> (1, 1, C)
logits = model.decoder.lm_head(x) # (B, 1, V) -> (1, 1, V)
logits = logits.squeeze(1) # (B, V) -> (1, V)
# Sampling
logits = logits / temperature
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
                # idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1) stochastic sampling from the top-k distribution
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding; temperature/top_k above do not change the argmax
# Store generated ID
generated_ids.append(idx_next)
# Stop if EOS token is generated
if idx_next.item() == tokenizer.eos_token_id:
print("EOS token generated.")
break
# Prepare for next iteration: Append embedding of new token
next_token_embed = model.decoder.token_embedding_table(idx_next) # (1, 1, C)
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # Append along sequence dim
# --- 4. Combine and Decode Results ---
if generated_ids:
generated_ids_tensor = torch.cat(generated_ids, dim=1) # (1, T_generated)
initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1)
else:
full_generated_sequence_ids = prompt_ids # Nothing was generated
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")
# --- 5. Save visualization to file ---
save_coords_visualization(
image_tensor=image_tensor[0], # Remove batch dim for visualization
full_decoded_text=full_decoded_text,
tokenizer=tokenizer,
image_size=IMAGE_SIZE, # Assumes IMAGE_SIZE is globally defined
num_bins=NUM_BINS, # Assumes NUM_BINS is globally defined
output_path=output_path
)
print(f"Visualization saved to: {output_path}")
except StopIteration:
print("Test loader is exhausted.")
except Exception as e:
print(f"An error occurred during sample generation: {e}")
import traceback
traceback.print_exc()
def parse_coordinate_tokens(text, tokenizer, num_bins):
"""
Parses generated text to extract coordinate bin tokens.
Args:
text (str): The decoded output text from the model.
tokenizer: The tokenizer.
num_bins (int): The number of coordinate bins used.
Returns:
list[tuple(int, int)]: A list of (x_bin, y_bin) tuples, or None if parsing fails.
"""
coords = []
try:
# Basic parsing - look for the pattern
x_start_token = "<pointx_start>"
x_end_token = "<pointx_end>"
y_start_token = "<pointy_start>"
y_end_token = "<pointy_end>"
result_end_token = "<result_end>"
# Find where the actual results start
try:
start_index = text.index("<result_start>") + len("<result_start>")
except ValueError:
print("Warning: <result_start> not found in generated text.")
return None
# Find where results end
try:
end_index = text.index(result_end_token, start_index)
except ValueError:
end_index = len(text) # Use end of string if <result_end> is missing
print(f"Warning: {result_end_token} not found. Parsing until end of string.")
current_pos = start_index
while current_pos < end_index:
# Find next X coordinate
x_start_idx = text.find(x_start_token, current_pos)
if x_start_idx == -1 or x_start_idx >= end_index: break # No more x points found
x_start_idx += len(x_start_token)
x_end_idx = text.find(x_end_token, x_start_idx)
if x_end_idx == -1 or x_end_idx >= end_index: break # Malformed
x_token_str = text[x_start_idx:x_end_idx].strip()
# Find next Y coordinate (must follow X)
y_start_idx = text.find(y_start_token, x_end_idx)
if y_start_idx == -1 or y_start_idx >= end_index: break # No corresponding y point
y_start_idx += len(y_start_token)
y_end_idx = text.find(y_end_token, y_start_idx)
if y_end_idx == -1 or y_end_idx >= end_index: break # Malformed
y_token_str = text[y_start_idx:y_end_idx].strip()
            # The extracted span appears to be a special bin token whose last character is a
            # closing marker (e.g. ">"); drop it so the numeric bin index follows the final "_".
            x_token_str = x_token_str[:-1]
            y_token_str = y_token_str[:-1]
# Convert token strings to bin numbers
try:
x_bin = int(x_token_str.split("_")[-1])
y_bin = int(y_token_str.split("_")[-1])
if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins:
coords.append((x_bin, y_bin))
else:
print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.")
except (ValueError, IndexError):
print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. Skipping.")
# Move search position past the found Y token
current_pos = y_end_idx + len(y_end_token)
return coords if coords else None
except Exception as e:
print(f"Error during coordinate parsing: {e}")
return None
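# Illustrative input for parse_coordinate_tokens (the exact spelling of the bin tokens is
# set by the tokenizer's special tokens added during training, so treat this as a sketch):
#   "<point_start>cat<point_end><result_start><pointx_start><x_bin_12><pointx_end>"
#   "<pointy_start><y_bin_34><pointy_end><result_end>"
# would parse to [(12, 34)], provided both bins are smaller than num_bins.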
def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path):
"""Parses coords, draws them on the image, and saves to a file."""
parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins)
# Convert tensor to PIL image for drawing
try:
pil_image = tensor_to_image(image_tensor.cpu()) # Ensure tensor is on CPU
except Exception as e:
print(f"Error converting tensor to image: {e}")
# Create a placeholder image if conversion fails
pil_image = Image.new('RGB', (image_size, image_size), color='white')
draw = ImageDraw.Draw(pil_image)
draw.text((10, 10), "Image conversion failed", fill="black")
pil_image.save(output_path)
return
draw = ImageDraw.Draw(pil_image)
radius = 5 # Radius of the drawn point
if parsed_bins:
print(f"\nParsed Coordinate Bins: {parsed_bins}")
bin_size_pixels = image_size / num_bins
for x_bin, y_bin in parsed_bins:
# Calculate center of the bin in pixels
center_x = (x_bin + 0.5) * bin_size_pixels
center_y = (y_bin + 0.5) * bin_size_pixels
# Draw a circle
bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius]
draw.ellipse(bbox, outline="red", width=3)
# Optional: Draw bin boundaries for debugging
# draw.rectangle([x_bin*bin_size_pixels, y_bin*bin_size_pixels, (x_bin+1)*bin_size_pixels, (y_bin+1)*bin_size_pixels], outline="blue", width=1)
# Add a text label with the coordinates at the top of the image
coord_text = f"Generated Point(s): {parsed_bins}"
draw.text((10, 10), coord_text, fill="red")
else:
print("\nCould not parse valid coordinates from the generated text.")
# Add a text label indicating no coordinates were found
draw.text((10, 10), "No Coordinates Parsed", fill="red")
# Save the image to file
pil_image.save(output_path)
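# Worked example of the bin-to-pixel mapping in save_coords_visualization (illustrative
# numbers, not the project's constants): with image_size=224 and num_bins=32, each bin
# spans 224/32 = 7 px, so bin (3, 5) is drawn centered at ((3+0.5)*7, (5+0.5)*7) = (24.5, 38.5).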
import argparse
# --- Example Usage ---
# python infer.py --image ./data/test_images/image_1.png --prompt "a red apple"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--image', type=str, help='Path to input image')
parser.add_argument('--prompt', type=str, help='Prompt label for generation')
args = parser.parse_args()
if args.image and args.prompt:
# Use image and prompt based generation
if 'model' in locals() and 'tokenizer' in locals():
generate_sample_from_image_text(
model=model,
image_path=args.image,
prompt_label=args.prompt,
tokenizer=tokenizer,
device=DEVICE,
output_path="model_prediction.png"
)
else:
print("Please ensure 'model' and 'tokenizer' are loaded before running generation.")
else:
# Use test loader based generation
        if 'model' in locals() and 'tokenizer' in locals():
test_loader = create_test_dataloader(batch_size=2, num_workers=0)
generate_sample_from_test_loader(
model=model,
test_loader=test_loader,
tokenizer=tokenizer,
device=DEVICE,
output_path="model_prediction.png"
)
else:
print("Please ensure 'model', 'test_loader', and 'tokenizer' are loaded before running generation.")