# Spaces: Sleeping
import functools

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

import gan_cls_768
# Target device for the text encoder, the generator and all input tensors:
# the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before tokenization.

    Lower-cases the text, trims surrounding whitespace, then trims any
    leading/trailing periods (in that order, so whitespace exposed by the
    period trim is kept).
    """
    return txt.lower().strip().strip('.')
# Fixed token length every caption is padded or truncated to.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenize ``txt`` to exactly ``max_len`` tokens.

    Pads shorter inputs up to ``max_len`` and truncates longer ones; offset
    mappings are not requested. Returns whatever ``tokenizer`` returns.
    """
    options = dict(
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False,
    )
    return tokenizer(txt, **options)
def encode(model_name, model, tokenizer, txt):
    """Embed a caption with a transformer encoder.

    The caption is normalised with ``clean``, tokenized to a fixed length with
    ``tokenize``, turned into batch-of-one ``long`` tensors on ``device``, and
    run through ``model`` in eval mode with gradients disabled.

    Returns the final-layer hidden state at the first token position as a
    NumPy array. For RoBERTa tokenizers that position holds the ``<s>`` start
    token.

    ``model_name`` is accepted for call-site compatibility but is not used.
    """
    batch = {
        key: torch.tensor(ids, dtype=torch.long, device=device)[None]
        for key, ids in tokenize(tokenizer, clean(txt)).items()
    }
    model.eval()
    with torch.no_grad():
        outputs = model(**batch)
    # squeeze() drops the batch-of-one axis; [0] selects the first position.
    first_token = outputs.last_hidden_state.squeeze()[0]
    return first_token.cpu().numpy()
# Pretrained text encoder used to embed captions for the generator.
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): output_hidden_states=True makes every forward pass also return
# all per-layer hidden states; encode() only reads last_hidden_state.
model = (
    AutoModel
    .from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name, output_hidden_states=True),
    )
    .to(device)
)
@functools.lru_cache(maxsize=None)
def _load_generator(checkpoint_path):
    """Build the GAN-CLS generator, load ``checkpoint_path`` into it, and cache it.

    The result is memoised per path, so repeated ``generate_image`` calls reuse
    one model instead of rebuilding it and re-reading the checkpoint each time.
    A change to the file on disk is therefore not seen until the process restarts.
    """
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    # map_location='cpu' lets a checkpoint holding GPU tensors load on a CPU-only
    # host; load_state_dict then copies the values into the model's parameters.
    # The model is wrapped in DataParallel before loading, so the checkpoint keys
    # must carry DataParallel's "module." prefix.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    generator.load_state_dict(state_dict)
    generator.eval()
    return generator


def generate_image(text, n, checkpoint_path="./gen_125.pth"):
    """Generate images conditioned on the caption ``text``.

    The caption is embedded with the module-level text encoder. The embedding is
    then fed to the pretrained generator ``n`` times, each time with fresh
    Gaussian noise.

    Args:
        text: Caption to condition on.
        n: Number of generator forward passes (one noise vector each).
        checkpoint_path: Generator weights to load; loaded once and cached.

    Returns:
        list of PIL images, one per generated sample.
    """
    embedding = encode(model_name, model, tokenizer, text)
    # encode() returns one feature vector; add a leading batch dimension of 1.
    right_embed = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0).to(device)
    generator = _load_generator(checkpoint_path)

    images = []
    # Inference only: don't build an autograd graph for every sample.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Noise is drawn from the CPU RNG, so a given seed yields the same
            # noise whether the model runs on CPU or GPU.
            noise = torch.randn(1, 100).view(1, 100, 1, 1).to(device)
            for image in generator(right_embed, noise):
                # Affine map to [0, 255]. It assumes generator outputs in [-1, 1]
                # (presumably a tanh head; confirm in gan_cls_768). Clamp so any
                # out-of-range value saturates instead of wrapping on the uint8 cast.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # Assumes each image is (C, H, W); PIL wants (H, W, C).
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
if __name__ == '__main__':
    n = 10
    imgs = generate_image('Red images', n)

    # Two images per row; the row count follows the number of images, so
    # changing n never indexes past the grid or leaves it mismatched.
    ncols = 2
    nrows = max(1, (len(imgs) + ncols - 1) // ncols)
    # squeeze=False always yields a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for axis in axes.flat:
        axis.axis('off')
    for axis, img in zip(axes.flat, imgs):
        axis.imshow(img)
    fig.tight_layout()
    plt.show()

    # Interactive mode (disabled): prompt for captions in a loop.
    # while True:
    #     print('Type Caption: ')
    #     txt = input()
    #     print('Generating images...')
    #     generate_image(txt, n)
    #     print('Completed')