# NOTE: hosting-page residue, not part of the program -- "Spaces: Sleeping",
# file size 2,527 bytes, revision f8a1225. The line-number gutter (1-113)
# that followed here has been dropped.
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalize a caption before it is tokenized.

    The text is lower-cased, surrounding whitespace is removed, and then any
    leading/trailing periods are removed.

    Args:
        txt: Raw caption string.

    Returns:
        The normalized caption string.
    """
    # Order matters: whitespace is stripped *before* the periods, so
    # whitespace sitting between the text and an outer period survives
    # (e.g. "a ." -> "a "). Kept as-is so embeddings stay unchanged.
    return txt.lower().strip().strip('.')
# Fixed length (in tokens) that every caption is padded or truncated to.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenize ``txt`` to a fixed length of ``max_len`` tokens.

    Args:
        tokenizer: Callable tokenizer accepting the keyword arguments
            ``max_length``, ``padding``, ``truncation`` and
            ``return_offsets_mapping``.
        txt: Text to tokenize.

    Returns:
        Whatever ``tokenizer`` returns for the given text and settings.
    """
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',  # always pad up to max_len
        truncation=True,  # cut anything longer than max_len
        return_offsets_mapping=False,
    )
def encode(model_name, model, tokenizer, txt):
    """Embed a caption as a single vector taken from ``model``'s last layer.

    Args:
        model_name: Unused; kept so existing callers keep working.
        model: Model that is called with the tokenized fields as keyword
            arguments and whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer forwarded to ``tokenize``.
        txt: Raw caption; it is normalized with ``clean`` first.

    Returns:
        NumPy array with the last-layer hidden state at sequence position 0
        of the (single) batch item.
    """
    tokens = tokenize(tokenizer, clean(txt))

    # Build a fresh mapping instead of overwriting the tokenizer output while
    # iterating over it. ``[None]`` adds a leading batch dimension of size 1.
    batch = {
        key: torch.tensor(value, dtype=torch.long, device=device)[None]
        for key, value in tokens.items()
    }

    model.eval()
    with torch.no_grad():
        encoded = model(**batch)

    # Index batch item 0 / position 0 explicitly rather than relying on
    # squeeze(), which would also drop any other size-1 dimension.
    # NOTE(review): presumably this is the start-of-sequence token - confirm.
    return encoded.last_hidden_state[0, 0].cpu().numpy()
# Pretrained text encoder used to embed captions; loaded once at import time.
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
_encoder_config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
model = AutoModel.from_pretrained(model_name, config=_encoder_config)
model = model.to(device)
def generate_image(text, n):
    """Generate ``n`` images conditioned on a text caption.

    Args:
        text: Caption to condition the generator on.
        n: Number of images to generate; each uses a fresh noise vector.

    Returns:
        List of ``PIL.Image.Image`` objects, one per generated sample.
    """
    embed = encode(model_name, model, tokenizer, text)

    # NOTE: the generator is rebuilt and its checkpoint re-read on every call.
    # The DataParallel wrapper is kept because the checkpoint is loaded into
    # the wrapped module.
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
    generator.eval()

    # Add a batch dimension of size 1 to the caption embedding.
    right_embed = torch.FloatTensor(embed).unsqueeze(0).to(device)

    images = []
    # Pure inference: skip autograd bookkeeping instead of reaching through
    # ``.data`` afterwards.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Noise is drawn on the CPU and then moved, as before.
            noise = torch.randn(1, 100).to(device)
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # Assumes generator output is channel-first in [-1, 1]
                # - TODO confirm. Map to [0, 255]; the clamp keeps any
                # out-of-range value from wrapping in the uint8 cast.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
if __name__ == '__main__':
    # Demo: generate a handful of images for a fixed caption and show them
    # in a two-column grid.
    n = 10
    imgs = generate_image('Red images', n)

    # Derive the grid from n (ceil division) so the two cannot drift apart.
    ncols = 2
    nrows = -(-n // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    axes = axes.flatten()

    # A distinct loop name avoids shadowing the axes array being iterated.
    for idx, axis in enumerate(axes):
        if idx < len(imgs):
            axis.imshow(imgs[idx])
        axis.axis('off')  # also hides any unused trailing cell

    fig.tight_layout()
    plt.show()
    # while True:
    #     print('Type Caption: ')
    #     txt = input()
    #     print('Generating images...')
    #     generate_image(txt, n)
    #     print('Completed')