import os
import random

import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import Callback
from torchvision import transforms

from constants import TOKENIZER_PATH


class LogImageTexCallback(Callback):
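    """Logs one random validation image together with its ground-truth and beam-search-predicted TeX."""
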
    def __init__(self, logger, top_k, max_length):
        super().__init__()
        self.logger = logger
        self.top_k = top_k
        self.max_length = max_length
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        # Log a single random sample from the first batch of the first validation dataloader only.
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        texs_predicted = beam_search_decode(transformer, image, top_k=self.top_k, max_length=self.max_length)
        image = self.tensor_to_PIL(image)
        # .tolist() yields plain Python ints, which the tokenizer's decode() expects.
        tex_true = self.tex_tokenizer.decode(batch['tex_ids'][sample_id].to('cpu', torch.int).tolist())
        self.logger.log_image(key="samples", images=[image],
                              caption=[f"True: {tex_true}\nPredicted: " + "\n".join(texs_predicted)])


@torch.inference_mode()
def beam_search_decode(transformer, image, image_transform=None, top_k=10, max_length=100):
    """Performs beam-search decoding of an image into TeX, maintaining the top_k best candidates."""

    def get_tgt_padding_mask(tgt):
        # Positions at and after each row's first [SEP] count as padding (cumsum >= 1 there).
        mask = tgt == tex_tokenizer.token_to_id("[SEP]")
        mask = torch.cumsum(mask, dim=1)
        mask = mask.to(transformer.device, torch.bool)
        return mask

    if image_transform:
        image = image_transform(image)
    assert torch.is_tensor(image) and len(image.shape) == 3, "Image must be a 3-dimensional tensor (c h w)"
    # Add a batch dimension and encode the image once; the encoder memory is reused by every candidate.
    src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
    memory = transformer.encode(src)
    tex_tokenizer = torch.load(TOKENIZER_PATH)
    # Start with a single candidate: the [CLS] start token with log-probability 0.
    candidates_tex_ids = [[tex_tokenizer.token_to_id("[CLS]")]]
    candidates_log_prob = torch.tensor([0], dtype=torch.float, device=transformer.device)
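    # Each iteration scores every (candidate, next-token) extension and keeps the top_k
    # extensions by cumulative log-probability. torch.topk returns results in descending
    # order, so candidates_tex_ids[0] is always the current best candidate; the loop
    # stops once that best candidate has emitted [SEP] or reached max_length.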
    while candidates_tex_ids[0][-1] != tex_tokenizer.token_to_id("[SEP]") and len(candidates_tex_ids[0]) < max_length:
        # Token ids must be integral for the decoder's embedding lookup.
        candidates_tex_ids = torch.tensor(candidates_tex_ids, dtype=torch.long, device=transformer.device)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(candidates_tex_ids.shape[1]).to(
            transformer.device, torch.bool)
        # Decode all candidates in one batch against copies of the same encoder memory.
        shared_memories = einops.repeat(memory, "one n d_model -> (b one) n d_model",
                                        b=candidates_tex_ids.shape[0])
        outs = transformer.decode(tgt=candidates_tex_ids,
                                  memory=shared_memories,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  tgt_padding_mask=get_tgt_padding_mask(candidates_tex_ids))
        # Keep only the logits for the last position of each candidate.
        outs = einops.rearrange(outs, 'b n vocab -> b vocab n')[:, :, -1]
        vocab_size = outs.shape[1]
        outs = F.log_softmax(outs, dim=1)
        # Cumulative log-probability of every (candidate, next-token) pair.
        outs += einops.rearrange(candidates_log_prob, "b -> b ()")
        outs = einops.rearrange(outs, 'b vocab -> (b vocab)')
        candidates_log_prob, indices = torch.topk(outs, k=top_k)
        new_candidates = []
        for index in indices:
            # Recover which candidate the flat index belongs to and which token it appends.
            candidate_id, token_id = divmod(index.item(), vocab_size)
            new_candidates.append(candidates_tex_ids[candidate_id].tolist() + [token_id])
        candidates_tex_ids = new_candidates
    candidates_tex_ids = torch.tensor(candidates_tex_ids)
    # Replace everything after each candidate's first [SEP] with [PAD], keeping [SEP] itself.
    padding_mask = get_tgt_padding_mask(candidates_tex_ids).cpu()
    candidates_tex_ids = candidates_tex_ids.masked_fill(
        padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
        tex_tokenizer.token_to_id("[PAD]")).tolist()
    texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
    # The tokenizer joins tokens with spaces, producing "\ frac"-style commands;
    # remove the stray space after each backslash.
    texs = [tex.replace("\\ ", "\\") for tex in texs]
    return texs
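
# Minimal usage sketch (hedged: "model" and "pil_image" are illustrative placeholders):
#   predictions = beam_search_decode(model, transforms.ToTensor()(pil_image),
#                                    top_k=5, max_length=100)
#   best_tex = predictions[0]  # candidates come back sorted, best first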


def average_checkpoints(model_type, checkpoints_dir):
    """Returns a model whose weights are the element-wise average of all checkpoints.

    Args:
        model_type: pytorch_lightning.LightningModule subclass that corresponds to the checkpoints
        checkpoints_dir: path to the directory containing the checkpoints
    """
    checkpoints = [checkpoint.path for checkpoint in os.scandir(checkpoints_dir)]
    n_models = len(checkpoints)
    assert n_models > 0, f"No checkpoints found in {checkpoints_dir}"
    # Accumulate every checkpoint's weights into the first model, then divide once.
    average_model = model_type.load_from_checkpoint(checkpoints[0])
    for checkpoint in checkpoints[1:]:
        model = model_type.load_from_checkpoint(checkpoint)
        for weight, weight_to_add in zip(average_model.parameters(), model.parameters()):
            weight.data.add_(weight_to_add.data)
    for weight in average_model.parameters():
        weight.data.divide_(n_models)
    return average_model
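
# Minimal usage sketch (hedged: "Image2TexModel" and the directory name are illustrative):
#   averaged = average_checkpoints(Image2TexModel, "checkpoints/")
#   averaged.eval()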