|
import argparse

from constants import *
|
from utils import image_to_tensor, tensor_to_image, tokenizer, vocab_size
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import ImageDraw, Image |
|
from dataset import create_test_dataloader |
|
from vision_language_model import VisionLanguageModel |
|
|
|
|
|
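# Instantiate the model with the hyperparameters defined in constants.py;
# these must match the configuration of the checkpoint loaded below.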
model = VisionLanguageModel( |
|
n_embd=HIDDEN_DIM, |
|
vocab_size=vocab_size, |
|
img_size=IMAGE_SIZE, |
|
patch_size=PATCH_SIZE, |
|
num_heads=NUM_HEADS, |
|
num_blks_vit=NUM_LAYERS, |
|
num_blks_dec=NUM_LAYERS, |
|
emb_dropout=DROPOUT, |
|
blk_dropout=DROPOUT, |
|
max_context=CONTEXT_LENGTH, |
|
shared_embed_dim=SHARED_EMBED_DIM, |
|
lambda_contrastive=LAMBDA_CONTRASTIVE, |
|
lambda_regression=LAMBDA_REGRESSION |
|
).to(DEVICE) |
|
|
|
MODEL_PATH = "model_regression_multi_first_100.pth" |
|
|
|
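# Load the trained checkpoint, mapping tensors to the CPU when CUDA is unavailable.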
if DEVICE == "cuda": |
|
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True)) |
|
else: |
|
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=torch.device('cpu'))) |
|
model.eval() |
|
|
|
def generate_sample_from_image_text( |
|
model, |
|
image_path, |
|
prompt_label, |
|
tokenizer, |
|
device, |
|
max_new_tokens=70, |
|
temperature=0.8, |
|
top_k=10, |
|
output_path="generated_output.png" |
|
): |
|
""" |
|
Generates a prediction for an image and prompt text and saves it to a file. |
|
    The generation loop is implemented *within* this function.
|
|
|
Args: |
|
model: The trained VisionLanguageModel. |
|
image_path: Path to the input image. |
|
prompt_label: Text prompt/label to use. |
|
tokenizer: The tokenizer used for training. |
|
device: The computation device ('cuda' or 'cpu'). |
|
max_new_tokens (int): Max tokens to generate after the prompt. |
|
temperature (float): Softmax temperature for sampling. |
|
top_k (int): K for top-k sampling (0 or None to disable). |
|
        output_path (str): Path where the output image is saved.
|
|
|
Returns: |
|
None. Saves the image with prompt and generated output to a file. |
|
""" |
|
model.eval() |
|
|
|
try: |
|
with torch.no_grad(): |
|
|
|
|
|
image = Image.open(image_path) |
|
image_tensor = image_to_tensor(image).unsqueeze(0).to(device) |
|
|
|
|
|
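            # Build the prompt by wrapping the label in the <point_start>/<point_end> markers and tokenize it.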
prompt_text = f"<point_start>{prompt_label}<point_end>" |
|
prompt_tokens = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=False) |
|
prompt_ids = prompt_tokens.input_ids.to(device) |
|
prompt_attention_mask = prompt_tokens.attention_mask.to(device) |
|
B = 1 |
|
|
|
            print("--- Generating Sample (Manual Loop) ---")
|
print(f"Original Label/Prompt Hint: {prompt_label}") |
|
print(f"Input Prompt Tokens Decoded: {prompt_text}") |
|
|
|
|
|
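            # Encode the image and project the patch embeddings into the decoder's embedding space.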
image_embeds_raw = model.vision_encoder(image_tensor) |
|
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) |
|
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) |
|
|
|
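            # Embed the <result_start> token, which marks where the model should begin emitting the result.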
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0] |
|
result_start_embed = model.decoder.token_embedding_table( |
|
torch.tensor([[result_start_token_id]], device=device) |
|
) |
|
|
|
|
|
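            # Decoder prefix: [image embeddings] + [prompt embeddings] + [<result_start> embedding].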
current_embeds = torch.cat([ |
|
image_embeds_decoder, |
|
prompt_embeds_decoder, |
|
result_start_embed |
|
], dim=1) |
|
generated_ids = [] |
|
|
|
|
|
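            # Autoregressive generation: run the decoder, sample one token, append its embedding, repeat.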
for _ in range(max_new_tokens): |
|
T_current = current_embeds.shape[1] |
|
|
|
|
|
if T_current > model.decoder.max_context: |
|
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}") |
|
current_embeds = current_embeds[:, -model.decoder.max_context:, :] |
|
T_current = model.decoder.max_context |
|
|
|
|
|
pos = torch.arange(0, T_current, dtype=torch.long, device=device) |
|
pos = pos.clamp(max=model.decoder.max_context - 1) |
|
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) |
|
x = current_embeds + pos_emb |
|
|
|
|
|
|
|
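                # No padding in this manually built sequence, so every position is attended to.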
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long) |
|
|
|
|
|
for block in model.decoder.blocks: |
|
|
|
x = block(x, attention_mask=attention_mask) |
|
|
|
|
|
x = model.decoder.ln_f(x[:, -1:, :]) |
|
logits = model.decoder.lm_head(x) |
|
logits = logits.squeeze(1) |
|
|
|
|
|
                # Apply temperature scaling and optional top-k filtering before sampling.
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)

                # Sample the next token id from the filtered distribution.
                idx_next = torch.multinomial(probs, num_samples=1)
|
|
|
|
|
generated_ids.append(idx_next) |
|
|
|
|
|
if idx_next.item() == tokenizer.eos_token_id: |
|
print("EOS token generated.") |
|
break |
|
|
|
|
|
next_token_embed = model.decoder.token_embedding_table(idx_next) |
|
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) |
|
|
|
|
|
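            # Reassemble prompt + <result_start> + generated ids so the full sequence can be decoded.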
if generated_ids: |
|
generated_ids_tensor = torch.cat(generated_ids, dim=1) |
|
initial_target_ids = torch.tensor([[result_start_token_id]], device=device) |
|
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1) |
|
else: |
|
full_generated_sequence_ids = prompt_ids |
|
|
|
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False) |
|
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}") |
|
|
|
|
|
save_coords_visualization( |
|
image_tensor=image_tensor[0], |
|
full_decoded_text=full_decoded_text, |
|
tokenizer=tokenizer, |
|
image_size=IMAGE_SIZE, |
|
num_bins=NUM_BINS, |
|
output_path=output_path |
|
) |
|
print(f"Visualization saved to: {output_path}") |
|
|
|
except Exception as e: |
|
print(f"An error occurred during sample generation: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
def generate_sample_from_test_loader( |
|
model, |
|
test_loader, |
|
tokenizer, |
|
device, |
|
max_new_tokens=70, |
|
temperature=0.8, |
|
top_k=10, |
|
output_path="generated_output.png", |
|
TEST_BATCH=8, |
|
TEST_IDX=1 |
|
): |
|
""" |
|
Generates a prediction for one sample from the test loader and saves it to a file. |
|
    The generation loop is implemented *within* this function.
|
|
|
Args: |
|
model: The trained VisionLanguageModel. |
|
test_loader: DataLoader for the test set. |
|
tokenizer: The tokenizer used for training. |
|
device: The computation device ('cuda' or 'cpu'). |
|
max_new_tokens (int): Max tokens to generate after the prompt. |
|
temperature (float): Softmax temperature for sampling. |
|
top_k (int): K for top-k sampling (0 or None to disable). |
|
        output_path (str): Path where the output image is saved.
|
|
|
Returns: |
|
None. Saves the image with prompt and generated output to a file. |
|
""" |
|
|
|
if not test_loader or len(test_loader.dataset) == 0: |
|
print("Test loader is empty or not available.") |
|
return |
|
|
|
model.eval() |
|
|
|
try: |
|
|
|
with torch.no_grad(): |
|
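            # Skip TEST_BATCH batches, then take the following one; sample TEST_IDX is selected from it below.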
my_iter = iter(test_loader) |
|
for i in range(TEST_BATCH): |
|
_ = next(my_iter) |
|
batch = next(my_iter) |
|
|
|
if batch is None: |
|
print("Test loader yielded an empty batch.") |
|
return |
|
if batch['image'].shape[0] == 0: |
|
print("Test loader yielded a batch with 0 items.") |
|
return |
|
|
|
|
|
image_tensor = batch['image'][TEST_IDX:TEST_IDX+1].to(device) |
|
prompt_ids = batch['prompt_ids'][TEST_IDX:TEST_IDX+1].to(device) |
|
prompt_attention_mask = batch['prompt_attention_mask'][TEST_IDX:TEST_IDX+1].to(device) |
|
label = batch['label'][TEST_IDX] |
|
B = 1 |
|
|
|
            print("--- Generating Sample (Manual Loop) ---")
|
print(f"Original Label/Prompt Hint: {label}") |
|
prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False) |
|
print(f"Input Prompt Tokens Decoded: {prompt_text}") |
|
|
|
|
|
image_embeds_raw = model.vision_encoder(image_tensor) |
|
image_embeds_decoder = model.multimodal_projector(image_embeds_raw) |
|
prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids) |
|
|
|
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0] |
|
result_start_embed = model.decoder.token_embedding_table( |
|
torch.tensor([[result_start_token_id]], device=device) |
|
) |
|
|
|
|
|
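            # Decoder prefix: [image embeddings] + [prompt embeddings] + [<result_start> embedding].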
current_embeds = torch.cat([ |
|
image_embeds_decoder, |
|
prompt_embeds_decoder, |
|
result_start_embed |
|
], dim=1) |
|
|
|
generated_ids = [] |
|
|
|
|
|
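            # Autoregressive generation: run the decoder, sample one token, append its embedding, repeat.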
for _ in range(max_new_tokens): |
|
T_current = current_embeds.shape[1] |
|
|
|
|
|
if T_current > model.decoder.max_context: |
|
print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}") |
|
current_embeds = current_embeds[:, -model.decoder.max_context:, :] |
|
T_current = model.decoder.max_context |
|
|
|
|
|
pos = torch.arange(0, T_current, dtype=torch.long, device=device) |
|
pos = pos.clamp(max=model.decoder.max_context - 1) |
|
pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0) |
|
x = current_embeds + pos_emb |
|
|
|
|
|
|
|
attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long) |
|
|
|
|
|
for block in model.decoder.blocks: |
|
|
|
x = block(x, attention_mask=attention_mask) |
|
|
|
|
|
x = model.decoder.ln_f(x[:, -1:, :]) |
|
logits = model.decoder.lm_head(x) |
|
logits = logits.squeeze(1) |
|
|
|
|
|
                # Apply temperature scaling and optional top-k filtering before sampling.
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)

                # Sample the next token id from the filtered distribution.
                idx_next = torch.multinomial(probs, num_samples=1)
|
|
|
|
|
generated_ids.append(idx_next) |
|
|
|
|
|
if idx_next.item() == tokenizer.eos_token_id: |
|
print("EOS token generated.") |
|
break |
|
|
|
|
|
next_token_embed = model.decoder.token_embedding_table(idx_next) |
|
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) |
|
|
|
|
|
if generated_ids: |
|
generated_ids_tensor = torch.cat(generated_ids, dim=1) |
|
initial_target_ids = torch.tensor([[result_start_token_id]], device=device) |
|
full_generated_sequence_ids = torch.cat([prompt_ids, initial_target_ids, generated_ids_tensor], dim=1) |
|
else: |
|
full_generated_sequence_ids = prompt_ids |
|
|
|
full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False) |
|
print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}") |
|
|
|
|
|
save_coords_visualization( |
|
image_tensor=image_tensor[0], |
|
full_decoded_text=full_decoded_text, |
|
tokenizer=tokenizer, |
|
image_size=IMAGE_SIZE, |
|
num_bins=NUM_BINS, |
|
output_path=output_path |
|
) |
|
print(f"Visualization saved to: {output_path}") |
|
|
|
except StopIteration: |
|
print("Test loader is exhausted.") |
|
except Exception as e: |
|
print(f"An error occurred during sample generation: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
def parse_coordinate_tokens(text, tokenizer, num_bins): |
|
""" |
|
Parses generated text to extract coordinate bin tokens. |
|
|
|
Args: |
|
text (str): The decoded output text from the model. |
|
tokenizer: The tokenizer. |
|
num_bins (int): The number of coordinate bins used. |
|
|
|
Returns: |
|
        list[tuple[int, int]]: A list of (x_bin, y_bin) tuples, or None if parsing fails.
|
""" |
|
coords = [] |
|
try: |
|
|
|
x_start_token = "<pointx_start>" |
|
x_end_token = "<pointx_end>" |
|
y_start_token = "<pointy_start>" |
|
y_end_token = "<pointy_end>" |
|
result_end_token = "<result_end>" |
|
|
|
|
|
try: |
|
start_index = text.index("<result_start>") + len("<result_start>") |
|
except ValueError: |
|
print("Warning: <result_start> not found in generated text.") |
|
return None |
|
|
|
|
|
try: |
|
end_index = text.index(result_end_token, start_index) |
|
except ValueError: |
|
end_index = len(text) |
|
print(f"Warning: {result_end_token} not found. Parsing until end of string.") |
|
|
|
|
|
current_pos = start_index |
|
while current_pos < end_index: |
|
|
|
x_start_idx = text.find(x_start_token, current_pos) |
|
if x_start_idx == -1 or x_start_idx >= end_index: break |
|
x_start_idx += len(x_start_token) |
|
|
|
x_end_idx = text.find(x_end_token, x_start_idx) |
|
if x_end_idx == -1 or x_end_idx >= end_index: break |
|
|
|
x_token_str = text[x_start_idx:x_end_idx].strip() |
|
|
|
|
|
y_start_idx = text.find(y_start_token, x_end_idx) |
|
if y_start_idx == -1 or y_start_idx >= end_index: break |
|
y_start_idx += len(y_start_token) |
|
|
|
y_end_idx = text.find(y_end_token, y_start_idx) |
|
if y_end_idx == -1 or y_end_idx >= end_index: break |
|
|
|
y_token_str = text[y_start_idx:y_end_idx].strip() |
|
|
|
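            # Strip the final character of each bin token (e.g. its closing '>') so the
            # integer bin index after the last '_' can be parsed below.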
x_token_str = x_token_str[:-1] |
|
y_token_str = y_token_str[:-1] |
|
|
|
|
|
try: |
|
x_bin = int(x_token_str.split("_")[-1]) |
|
y_bin = int(y_token_str.split("_")[-1]) |
|
if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins: |
|
coords.append((x_bin, y_bin)) |
|
else: |
|
print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.") |
|
except (ValueError, IndexError): |
|
print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. Skipping.") |
|
|
|
|
|
current_pos = y_end_idx + len(y_end_token) |
|
|
|
return coords if coords else None |
|
|
|
except Exception as e: |
|
print(f"Error during coordinate parsing: {e}") |
|
return None |
|
|
|
|
|
def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path): |
|
"""Parses coords, draws them on the image, and saves to a file.""" |
|
parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins) |
|
|
|
|
|
try: |
|
pil_image = tensor_to_image(image_tensor.cpu()) |
|
except Exception as e: |
|
print(f"Error converting tensor to image: {e}") |
|
|
|
pil_image = Image.new('RGB', (image_size, image_size), color='white') |
|
draw = ImageDraw.Draw(pil_image) |
|
draw.text((10, 10), "Image conversion failed", fill="black") |
|
pil_image.save(output_path) |
|
return |
|
|
|
draw = ImageDraw.Draw(pil_image) |
|
radius = 5 |
|
|
|
if parsed_bins: |
|
print(f"\nParsed Coordinate Bins: {parsed_bins}") |
|
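        # Each bin spans image_size / num_bins pixels; draw a marker at the center of every predicted bin.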
bin_size_pixels = image_size / num_bins |
|
for x_bin, y_bin in parsed_bins: |
|
|
|
center_x = (x_bin + 0.5) * bin_size_pixels |
|
center_y = (y_bin + 0.5) * bin_size_pixels |
|
|
|
|
|
bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius] |
|
draw.ellipse(bbox, outline="red", width=3) |
|
|
|
|
|
|
|
|
|
coord_text = f"Generated Point(s): {parsed_bins}" |
|
draw.text((10, 10), coord_text, fill="red") |
|
else: |
|
print("\nCould not parse valid coordinates from the generated text.") |
|
|
|
draw.text((10, 10), "No Coordinates Parsed", fill="red") |
|
|
|
|
|
pil_image.save(output_path) |
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--image', type=str, help='Path to input image') |
|
parser.add_argument('--prompt', type=str, help='Prompt label for generation') |
|
args = parser.parse_args() |
|
if args.image and args.prompt: |
|
|
|
if 'model' in locals() and 'tokenizer' in locals(): |
|
generate_sample_from_image_text( |
|
model=model, |
|
image_path=args.image, |
|
prompt_label=args.prompt, |
|
tokenizer=tokenizer, |
|
device=DEVICE, |
|
output_path="model_prediction.png" |
|
) |
|
else: |
|
print("Please ensure 'model' and 'tokenizer' are loaded before running generation.") |
|
else: |
|
|
|
        # No image/prompt supplied: fall back to a sample from the test dataloader.
        if 'model' in locals() and 'tokenizer' in locals():
|
test_loader = create_test_dataloader(batch_size=2, num_workers=0) |
|
generate_sample_from_test_loader( |
|
model=model, |
|
test_loader=test_loader, |
|
tokenizer=tokenizer, |
|
device=DEVICE, |
|
output_path="model_prediction.png" |
|
) |
|
else: |
|
            print("Please ensure 'model' and 'tokenizer' are loaded before running generation.")