# Source: app.py, uploaded by AlexWortega
# Commit: 28e228f ("Create app.py")
# File size: 6.66 kB
import torch
from rudalle import get_tokenizer, get_vae
from rudalle.utils import seed_everything
import sys
import gradio as gr
from PIL import Image
device = 'cpu'
import clip
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm, trange
import PIL.Image
from IPython.display import Image
import transformers
# Prefer the GPU when CUDA is available; this replaces the 'cpu' value assigned
# to `device` earlier among the imports.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoint file handed to torch.load() further down when the caption model's
# weights are restored.
model_path = 'coco_prefix_latest.pt'
# Model definitions
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with an activation in between.

    Args:
        sizes: Sequence of layer widths. ``sizes[0]`` is the input width and
            ``sizes[-1]`` the output width; one ``nn.Linear`` is created for
            every consecutive pair.
        bias: Forwarded to every ``nn.Linear``.
        act: Zero-argument factory (typically an ``nn.Module`` class) that is
            instantiated after every linear layer except the last one.
    """

    def __init__(self, sizes, bias=True, act=nn.Tanh):
        super().__init__()
        last_layer = len(sizes) - 2
        layers = []
        for index, (in_features, out_features) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(in_features, out_features, bias=bias))
            # No activation after the output layer.
            if index < last_layer:
                layers.append(act())
        # The attribute name and layer order define the state_dict keys
        # ("model.0.weight", "model.2.weight", ...), so keep both stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)
class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a projected CLIP embedding.

    A CLIP feature vector is mapped by ``clip_project`` to ``prefix_length``
    embeddings of the GPT-2 embedding width, and those embeddings are placed
    in front of the caption token embeddings before calling GPT-2.

    Args:
        prefix_length: Number of prefix embeddings produced per input vector.
        prefix_size: Width of the incoming CLIP feature vector.
    """

    def __init__(self, prefix_length, prefix_size: int = 512):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length
        if prefix_length > 10:  # not enough memory
            # Long prefixes use a single linear map instead of the MLP.
            self.clip_project = nn.Linear(prefix_size, projected_size)
        else:
            self.clip_project = MLP((prefix_size, projected_size // 2, projected_size))

    # TODO: the source carried a disabled `functools.lru_cache` here marked
    # FIXME; caching was never enabled.
    def get_dummy_token(self, batch_size, device):
        """Return an all-zero int64 tensor of shape (batch_size, prefix_length)."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens, prefix, mask=None, labels=None):
        """Run GPT-2 on the projected prefix followed by the token embeddings.

        Args:
            tokens: Token ids; embedded with the GPT-2 input embedding table.
                Presumably shaped (batch, seq) -- confirm against callers.
            prefix: CLIP features fed to ``clip_project``.
            mask: Passed to GPT-2 as ``attention_mask``; ``None`` to omit.
            labels: When not ``None``, labels are rebuilt from ``tokens`` with
                ``prefix_length`` zero ids prepended so they line up with the
                concatenated input; the passed value itself is only used as a
                flag.

        Returns:
            Whatever the wrapped GPT-2 model returns.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        return self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
class ClipCaptionPrefix(ClipCaptionModel):
    """Variant of ClipCaptionModel that exposes only the projection weights.

    ``parameters()`` yields just the ``clip_project`` parameters and
    ``train()`` always leaves the GPT-2 submodule in eval mode.
    """

    def train(self, mode: bool = True):
        """Set the training mode on the module, then force GPT-2 back to eval."""
        super().train(mode)
        self.gpt.eval()
        return self

    def parameters(self, recurse: bool = True):
        """Return the parameters of ``clip_project``; ``recurse`` is not used."""
        return self.clip_project.parameters()
# Load the CLIP ViT-B/32 model onto `device`; `preprocess` is the callable
# applied to PIL images before they are passed to `clip_model.encode_image`.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
# Tokenizer from the same pretrained checkpoint name that ClipCaptionModel
# loads as its GPT-2 language model.
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# Number of prefix embeddings per image. Because this is not greater than 10,
# ClipCaptionModel builds its projection as an MLP rather than a single Linear.
prefix_length = 10
model = ClipCaptionModel(prefix_length)
# Read the checkpoint with tensors mapped to the CPU, then move the whole
# model to `device`.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.to(device)
def _nucleus_filter(logits, top_p, filter_value=-float("Inf")):
    """Set every logit outside the top-p nucleus to ``filter_value`` in place."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask one step right so the first token that crosses the
    # threshold is kept, and never drop the top-ranked token.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[:, indices_to_remove] = filter_value
    return logits


def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,
    top_p=0.98,
    temperature=1.,
    stop_token='',
):
    """Greedily decode text from ``model.gpt`` and return the first entry.

    Decoding starts from ``embed`` when it is given; otherwise from the
    embeddings of ``tokens`` (or of the encoded ``prompt`` when ``tokens`` is
    ``None``). The code assumes a batch of one sequence.

    Args:
        model: Object exposing ``gpt`` (called with ``inputs_embeds=`` and
            returning an object with ``.logits``), ``gpt.transformer.wte``,
            ``eval()`` and ``parameters()``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        tokens: Optional starting token ids, 1-D or shaped (1, n). The caller's
            tensor is never modified.
        prompt: Text encoded into starting tokens when ``tokens`` and ``embed``
            are both ``None``.
        embed: Optional starting embeddings; when given they are used as the
            model input instead of the embedded tokens.
        entry_count: Number of decoding passes. Decoding is deterministic, so
            every pass yields the same text; only the first is returned.
        entry_length: Maximum number of tokens generated per pass.
        top_p: Nucleus threshold applied to the logits before the pick.
        temperature: Logit divisor; non-positive values are treated as 1.0.
        stop_token: Text whose first encoded id ends decoding. If it encodes
            to no ids (as an empty string typically does), the tokenizer's
            ``eos_token_id`` is used when available.
            NOTE(review): the empty default looks unintended -- confirm which
            stop token the checkpoint was trained with.

    Returns:
        The decoded text of the first pass, or ``''`` when nothing was decoded.

    Raises:
        ValueError: If ``tokens``, ``prompt`` and ``embed`` are all ``None``.
    """
    if tokens is None and prompt is None and embed is None:
        raise ValueError("generate2 needs one of `tokens`, `prompt` or `embed`")

    model.eval()
    generated_list = []

    # Indexing [0] on an empty encoding would raise, so fall back explicitly.
    stop_ids = tokenizer.encode(stop_token)
    stop_token_index = stop_ids[0] if stop_ids else getattr(tokenizer, "eos_token_id", None)

    device = next(model.parameters()).device
    scale = temperature if temperature > 0 else 1.0
    wte = model.gpt.transformer.wte

    with torch.no_grad():
        for _ in range(entry_count):
            # Work on a per-pass copy of the reference so generated ids never
            # leak into the next pass or into the caller's tensor.
            entry_tokens = tokens
            if entry_tokens is None and embed is None:
                entry_tokens = torch.tensor(tokenizer.encode(prompt))
            if entry_tokens is not None:
                if entry_tokens.dim() == 1:
                    entry_tokens = entry_tokens.unsqueeze(0)
                entry_tokens = entry_tokens.to(device)

            generated = embed if embed is not None else wte(entry_tokens)

            for _ in range(entry_length):
                logits = model.gpt(inputs_embeds=generated).logits
                logits = logits[:, -1, :] / scale
                logits = _nucleus_filter(logits, top_p)
                # Greedy pick. The filter never masks the top-ranked token, so
                # it cannot change this argmax; it matters only if this line
                # is replaced by sampling.
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = wte(next_token)
                if entry_tokens is None:
                    entry_tokens = next_token
                else:
                    entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            if entry_tokens is None:
                # entry_length <= 0 with an embedding start: nothing to decode.
                generated_list.append('')
            else:
                # Row indexing (not squeeze) keeps a single generated token a
                # one-element list.
                generated_list.append(tokenizer.decode(entry_tokens[0].tolist()))

    return generated_list[0] if generated_list else ''
def _to_caption(pil_image):
    """Return the caption generated for ``pil_image``.

    The image goes through the module-level ``preprocess`` callable and the
    CLIP image encoder; the resulting features are projected by
    ``model.clip_project`` into ``prefix_length`` embeddings, which are then
    used as the starting embeddings for ``generate2``.
    """
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        clip_features = clip_model.encode_image(batch).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(clip_features).reshape(1, prefix_length, -1)
        return generate2(model, tokenizer, embed=prefix_embed)
def classify_image(inp):
    """Gradio callback: turn the submitted image array into a caption string.

    Args:
        inp: Image data in a form accepted by ``PIL.Image.fromarray``
            (presumably the numpy array Gradio passes for an image input --
            confirm against the Gradio version in use).

    Returns:
        The caption text produced by ``_to_caption``.
    """
    print(type(inp))
    # Use the fully qualified PIL.Image here. The bare name `Image` is bound
    # twice by this file's imports, and the later binding
    # (`from IPython.display import Image`) wins; that class has no
    # `fromarray`.
    pil_image = PIL.Image.fromarray(inp)
    texts = _to_caption(pil_image)
    print(texts)
    return texts
# Gradio UI wiring: one image input, plain-text output, one bundled example.
image = gr.inputs.Image(shape=(128, 128))
# NOTE(review): `label` is created but not passed to the interface below,
# which uses outputs="text".
label = gr.outputs.Label(num_top_classes=3)
iface = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs="text",
    description="https://github.com/AlexWortega/ruImageCaptioning RuImage Captioning trained for a image2text task to predict food calories by https://t.me/lovedeathtransformers Alex Wortega",
    examples=[
        ['b9c277a3.jpeg'],
    ],
)
# Start the web server for the interface.
iface.launch()