adymaharana commited on
Commit
1cac669
1 Parent(s): 77e955b
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, torch
2
  import gradio as gr
3
  import torchvision.utils as vutils
4
  import torchvision.transforms as transforms
@@ -68,6 +68,7 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
68
  def main(args):
69
  #device = 'cuda:0'
70
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
71
 
72
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
73
 
@@ -77,7 +78,7 @@ def main(args):
77
  #if not os.path.exists("./ckpt/25.pth"):
78
  # gdown.download(model_url, quiet=False, use_cookies=False, output="./ckpt/25.pth")
79
  # print("Downloaded checkpoint")
80
- assert os.path.exists("./ckpt/25.pth")
81
  gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
82
 
83
  if args.debug:
@@ -102,6 +103,9 @@ def main(args):
102
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
103
  )
104
 
 
 
 
105
  def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
106
  supercondition=False):
107
 
 
1
+ import os, sys, torch
2
  import gradio as gr
3
  import torchvision.utils as vutils
4
  import torchvision.transforms as transforms
 
68
  def main(args):
69
  #device = 'cuda:0'
70
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
71
+ #device = torch.device('cpu')
72
 
73
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
74
 
 
78
  #if not os.path.exists("./ckpt/25.pth"):
79
  # gdown.download(model_url, quiet=False, use_cookies=False, output="./ckpt/25.pth")
80
  # print("Downloaded checkpoint")
81
+ #assert os.path.exists("./ckpt/25.pth")
82
  gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
83
 
84
  if args.debug:
 
103
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
104
  )
105
 
106
+ #torch.save(model, './ckpt/checkpoint.pt')
107
+ #sys.exit()
108
+
109
  def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
110
  supercondition=False):
111
 
dalle/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/dalle/__pycache__/__init__.cpython-38.pyc and b/dalle/__pycache__/__init__.cpython-38.pyc differ
 
dalle/models/__init__.py CHANGED
@@ -23,6 +23,7 @@ from ..utils.utils import save_image
23
  from .tokenizer import build_tokenizer
24
  import numpy as np
25
  from .stage2.layers import CrossAttentionLayer
 
26
 
27
  _MODELS = {
28
  'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
@@ -1191,7 +1192,9 @@ class StoryDalle(Dalle):
1191
  print("Loaded tokenizer from finetuned checkpoint")
1192
  print(model.cross_attention_idxs)
1193
  print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
 
1194
  # model.from_ckpt(args.model_name_or_path)
 
1195
  try:
1196
  model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
1197
  except KeyError:
@@ -1248,9 +1251,9 @@ class StoryDalle(Dalle):
1248
  #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
1249
 
1250
  with torch.no_grad():
1251
- with autocast(enabled=False):
1252
- codes = self.stage1.get_codes(images).detach()
1253
- src_codes = self.stage1.get_codes(src_images).detach()
1254
 
1255
  B, C, H, W = images.shape
1256
 
@@ -1310,8 +1313,8 @@ class StoryDalle(Dalle):
1310
  # Check if the encoding works as intended
1311
  # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1312
 
1313
- tokens = tokens.to(device)
1314
- source = source.to(device)
1315
 
1316
  # print(tokens.shape, sent_embeds.shape, prompt.shape)
1317
  B, L, _ = sent_embeds.shape
@@ -1322,8 +1325,8 @@ class StoryDalle(Dalle):
1322
  prompt = sent_embeds
1323
  pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
1324
 
1325
- with autocast(enabled=False):
1326
- src_codes = self.stage1.get_codes(source).detach()
1327
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
1328
  print(tokens.shape, src_codes.shape, prompt.shape)
1329
  if self.config.story.condition:
@@ -1378,8 +1381,8 @@ class StoryDalle(Dalle):
1378
  # Check if the encoding works as intended
1379
  # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1380
 
1381
- tokens = tokens.to(device)
1382
- source = source.to(device)
1383
 
1384
  # print(tokens.shape, sent_embeds.shape, prompt.shape)
1385
  B, L, _ = sent_embeds.shape
@@ -1389,10 +1392,10 @@ class StoryDalle(Dalle):
1389
  else:
1390
  prompt = sent_embeds
1391
  pos_enc_prompt = get_positional_encoding(
1392
- torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(self.device), mode='1d')
1393
 
1394
- with autocast(enabled=False):
1395
- src_codes = self.stage1.get_codes(source).detach()
1396
 
1397
  # repeat inputs to adjust to n_candidates and story length
1398
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
 
23
  from .tokenizer import build_tokenizer
24
  import numpy as np
25
  from .stage2.layers import CrossAttentionLayer
26
+ from huggingface_hub import hf_hub_download
27
 
28
  _MODELS = {
29
  'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
 
1192
  print("Loaded tokenizer from finetuned checkpoint")
1193
  print(model.cross_attention_idxs)
1194
  print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
1195
+
1196
  # model.from_ckpt(args.model_name_or_path)
1197
+
1198
  try:
1199
  model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
1200
  except KeyError:
 
1251
  #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
1252
 
1253
  with torch.no_grad():
1254
+ #with autocast(enabled=False):
1255
+ codes = self.stage1.get_codes(images).detach()
1256
+ src_codes = self.stage1.get_codes(src_images).detach()
1257
 
1258
  B, C, H, W = images.shape
1259
 
 
1313
  # Check if the encoding works as intended
1314
  # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1315
 
1316
+ #tokens = tokens.to(device)
1317
+ #source = source.to(device)
1318
 
1319
  # print(tokens.shape, sent_embeds.shape, prompt.shape)
1320
  B, L, _ = sent_embeds.shape
 
1325
  prompt = sent_embeds
1326
  pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
1327
 
1328
+ #with autocast(enabled=False):
1329
+ src_codes = self.stage1.get_codes(source).detach()
1330
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
1331
  print(tokens.shape, src_codes.shape, prompt.shape)
1332
  if self.config.story.condition:
 
1381
  # Check if the encoding works as intended
1382
  # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1383
 
1384
+ #tokens = tokens.to(device)
1385
+ #source = source.to(device)
1386
 
1387
  # print(tokens.shape, sent_embeds.shape, prompt.shape)
1388
  B, L, _ = sent_embeds.shape
 
1392
  else:
1393
  prompt = sent_embeds
1394
  pos_enc_prompt = get_positional_encoding(
1395
+ torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(tokens.device), mode='1d')
1396
 
1397
+ #with autocast(enabled=False):
1398
+ src_codes = self.stage1.get_codes(source).detach()
1399
 
1400
  # repeat inputs to adjust to n_candidates and story length
1401
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
dalle/models/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ