bhavitvyamalik committed on
Commit
270ef28
1 Parent(s): 2538d98
Files changed (1) hide show
  1. utils.py +52 -0
utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.io import read_image, ImageReadMode
2
+ import torch
3
+ import numpy as np
4
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
5
+ from torchvision.transforms.functional import InterpolationMode
6
+ from transformers import MBart50TokenizerFast
7
+ import json
8
+ from PIL import Image
9
+
10
+
11
class Transform(torch.nn.Module):
    """Image preprocessing pipeline packaged as a ``torch.nn.Module``.

    The pipeline applies, in order: a bicubic ``Resize`` with size
    ``[image_size]``, a ``CenterCrop`` of ``image_size``, conversion to
    ``torch.float``, and per-channel normalisation with fixed constants.
    """

    def __init__(self, image_size):
        """Assemble the pipeline for the given target ``image_size``."""
        super().__init__()
        # Three mean/std entries, i.e. one per channel of a 3-channel image.
        # NOTE(review): these look like the published CLIP statistics -- confirm.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        stages = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the full pipeline, computed with autograd disabled."""
        with torch.no_grad():
            return self.transforms(x)
28
+
29
+
30
# Shared preprocessing module, built once at import time for image size 224.
transform = Transform(image_size=224)
31
+
32
+
33
def get_transformed_image(image):
    """Preprocess one image and return it as a batched, channels-last array.

    Args:
        image: The image to preprocess. A ``numpy.ndarray`` whose last axis
            has length 3 is treated as channels-last and transposed to
            channels-first. Any other input is passed to ``torch.tensor``
            as-is (presumably already channels-first -- confirm with callers).

    Returns:
        The output of the module-level ``transform`` with a leading batch
        axis added and the channel axis moved to the end, converted to a
        ``numpy.ndarray``.
    """
    # Test the type before reading ``.shape`` so that inputs lacking that
    # attribute (e.g. nested lists) reach ``torch.tensor`` instead of failing
    # with AttributeError on the guard itself.
    if isinstance(image, np.ndarray) and image.shape[-1] == 3:
        image = image.transpose(2, 0, 1)
    image = torch.tensor(image)
    batched = transform(image).unsqueeze(0)
    # Move channels last: (batch, C, H, W) -> (batch, H, W, C).
    return batched.permute(0, 2, 3, 1).numpy()
38
+
39
# Tokenizer shared by the whole module; loaded once at import time from the
# "facebook/mbart-large-50" checkpoint.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50"
)
40
+
41
# Maps the plain language names accepted by this module to the locale-style
# codes that are looked up in ``tokenizer.lang_code_to_id``.
language_mapping = dict(
    english="en_XX",
    german="de_DE",
    french="fr_XX",
    spanish="es_XX",
)
47
+
48
def generate_sequence(model, pixel_values, lang_code, max_length=64, num_beams=4):
    """Generate text for ``pixel_values`` in the requested language and decode it.

    Args:
        model: Object exposing ``generate(input_ids=...,
            decoder_start_token_id=..., max_length=..., num_beams=...)``.
        pixel_values: Preprocessed image batch. It is forwarded to
            ``model.generate`` under the ``input_ids`` keyword.
        lang_code: A key of ``language_mapping`` (for example ``"english"``).
        max_length: Maximum generated sequence length passed to
            ``model.generate``. Defaults to 64.
        num_beams: Beam-search width passed to ``model.generate``.
            Defaults to 4.

    Returns:
        The result of ``tokenizer.batch_decode`` applied to the first element
        of the ``model.generate`` output, with special tokens skipped.

    Raises:
        KeyError: If ``lang_code`` is not a key of ``language_mapping``.
    """
    try:
        mbart_code = language_mapping[lang_code]
    except KeyError:
        supported = ", ".join(sorted(language_mapping))
        raise KeyError(
            f"Unsupported language {lang_code!r}; expected one of: {supported}"
        ) from None

    output_ids = model.generate(
        input_ids=pixel_values,
        decoder_start_token_id=tokenizer.lang_code_to_id[mbart_code],
        max_length=max_length,
        num_beams=num_beams,
    )
    # ``max_length`` is a generation setting, not a decoding option, so it is
    # deliberately not forwarded to ``batch_decode``.
    return tokenizer.batch_decode(output_ids[0], skip_special_tokens=True)