efeperro commited on
Commit
7cb104b
·
verified ·
1 Parent(s): 87d205e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py CHANGED
@@ -4,7 +4,92 @@ from PIL import Image
4
  from torchvision import transforms
5
  from transformers import T5Tokenizer, ViTFeatureExtractor
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Model loading and setting up the device
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  model = torch.load("model_vit_ai.pt", map_location=device)
10
  model.to(device)
 
4
  from torchvision import transforms
5
  from transformers import T5Tokenizer, ViTFeatureExtractor
6
 
7
+ class Encoder(nn.Module):
8
+ def __init__(self, pretrained_model):
9
+ """
10
+ Implements the Encoder."
11
+
12
+ Args:
13
+ pretrained_model (str): name of the pretrained model
14
+
15
+ """
16
+
17
+ super(Encoder, self).__init__()
18
+
19
+ self.encoder = ViTModel.from_pretrained(pretrained_model)
20
+
21
+ def forward(self, input):
22
+ out = self.encoder(pixel_values = input)
23
+
24
+ return out
25
+
26
+ class Decoder(nn.Module):
27
+ def __init__(self, pretrained_model, encoder_modeldim):
28
+ """
29
+ Implements the Decoder."
30
+
31
+ Args:
32
+ pretrained_model (str): name of the pretrained model
33
+
34
+ """
35
+
36
+ super(Decoder, self).__init__()
37
+
38
+ self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
39
+ self.linear = nn.Linear(self.decoder.model_dim, encoder_modeldim, bias = False)
40
+ self.encoder_modeldim = encoder_modeldim
41
+
42
+ def forward(self, output_encoder, targets, decoder_ids=None):
43
+
44
+ if self.decoder.model_dim!=self.encoder_modeldim:
45
+ print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
46
+ output_encoder = self.linear(output_encoder)
47
+ print(output_encoder.shape)
48
+
49
+ # Validation/Testing
50
+ if decoder_ids is not None:
51
+ out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
52
+
53
+ # Training
54
+ else:
55
+ out = self.decoder(encoder_outputs=output_encoder, labels=targets)
56
+
57
+ return out
58
+
59
+ class EncoderDecoder(nn.Module):
60
+ def __init__(self, pretrained_model: Tuple[str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
61
+ """
62
+ Implements a model that combines MyEncoder and MyDecoder."
63
+
64
+ Args:
65
+ pretrained_model (tuple): name of the pretrained model
66
+ encoder_dmodel (int): hidden dimension of the encoder output
67
+ eos_token_id (torch.long): token used for end of sentence
68
+ pad_token_id (torch.long): token used for padding
69
+
70
+ """
71
+
72
+ super(EncoderDecoder, self).__init__()
73
+ self.eos_token_id = eos_token_id
74
+ self.pad_token_id = pad_token_id
75
+ self.encoder = Encoder(pretrained_model[0])
76
+ self.encoder_dmodel = encoder_dmodel
77
+
78
+ # Freeze parameters from encoder
79
+ #for p in self.encoder.parameters():
80
+ # p.requires_grad=False
81
+
82
+ self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
83
+ self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id
84
+
85
+ def forward(self, images = None, targets = None, decoder_ids = None):
86
+ output_encoder = self.encoder(images)
87
+ out = self.decoder(output_encoder, targets, decoder_ids)
88
+
89
+ return out
90
+
91
  # Model loading and setting up the device
92
+
93
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
  model = torch.load("model_vit_ai.pt", map_location=device)
95
  model.to(device)