Spaces:
Runtime error
Runtime error
import os | |
from functools import lru_cache | |
from typing import Any, Callable, Dict, List, Optional, Tuple | |
import attr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import yaml | |
from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer | |
from .encoders import ImageEncoder, TextEncoder | |
@lru_cache()
def default_config_path() -> str:
    """Return the absolute path of the ``config.yaml`` that sits next to this module.

    The result depends only on where this source file lives, so it is
    memoized with ``functools.lru_cache``, which is already imported at the
    top of the module. Repeated calls return the same string object.
    """
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
@attr.s
class CLIPModel:
    """CLIP-style wrapper bundling a text encoder, a timestep-conditioned image
    encoder, a tokenizer and a similarity temperature.

    The class must be decorated with ``@attr.s``. That decorator turns the
    ``attr.ib()`` fields below into ``__init__`` keyword arguments. Without
    it the fields are plain class attributes, and calls such as
    ``CLIPModel(config=..., text_encoder=..., ...)`` raise ``TypeError``.
    """

    # Parsed configuration dictionary the encoders were built from.
    config: Dict[str, Any] = attr.ib()
    # Called as text_encoder(tokens, lens); must expose ``max_text_len``.
    text_encoder: nn.Module = attr.ib()
    # Called as image_encoder(images, t) where t is the timestep tensor.
    image_encoder: nn.Module = attr.ib()
    # Temperature stored in log space; ``cond_fn`` exponentiates it.
    logit_scale: torch.Tensor = attr.ib()
    # Device the prompt token/length tensors are placed on.
    device: torch.device = attr.ib()
    # Provides ``encode`` and ``padded_tokens_and_len`` for prompts.
    tokenizer: SimpleTokenizer = attr.ib()

    def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize each prompt and pad it to the text encoder's context length.

        Args:
            prompts: Text prompts to encode.

        Returns:
            ``(tokens, lens)``, both ``torch.long`` tensors on ``self.device``.
            ``tokens`` holds one padded token sequence per prompt. ``lens``
            holds the per-prompt length reported by
            ``tokenizer.padded_tokens_and_len``.
        """
        # Loop-invariant: the padded length is the same for every prompt.
        max_len = self.text_encoder.max_text_len
        tokens = []
        lens = []
        for prompt in prompts:
            sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
                self.tokenizer.encode(prompt), max_len
            )
            tokens.append(sub_tokens)
            lens.append(sub_len)
        # Build directly with the target dtype/device instead of create-then-.to().
        return (
            torch.tensor(tokens, dtype=torch.long, device=self.device),
            torch.tensor(lens, dtype=torch.long, device=self.device),
        )

    def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
        """Return L2-normalized text embeddings (along the last dim) for ``prompts``."""
        tokens, lens = self.encode_prompts(prompts)
        z_t = self.text_encoder(tokens, lens)
        # The small epsilon guards against division by zero for an all-zero embedding.
        return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)

    def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return L2-normalized image embeddings (along the last dim).

        Args:
            images: Image batch. NOTE(review): the ``(x + 1) * 127.5`` rescale
                maps [-1, 1] onto [0, 255], so inputs presumably live in
                [-1, 1]. Confirm against callers.
            t: Timestep tensor forwarded unchanged to ``image_encoder``.
        """
        z_i = self.image_encoder((images + 1) * 127.5, t)
        return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)

    def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
        """Build a guidance function that pushes images toward ``prompts``.

        The text embeddings are computed once, without gradient tracking.

        The returned closure ``cond_fn(x, t, grad_scale=grad_scale, **kwargs)``
        computes the gradient of ``exp(logit_scale) * sum(z_t * z_i)`` with
        respect to ``x``, scaled by ``grad_scale``. Extra keyword arguments
        are accepted and ignored. NOTE(review): ``z_t * z_i`` relies on
        broadcasting between the prompt and image batches, so the number of
        prompts presumably matches the image batch (or is 1). Confirm.
        """
        with torch.no_grad():
            z_t = self.text_embeddings(prompts)

        def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
            # Gradients are re-enabled locally so this works even when the
            # caller (e.g. a sampling loop) runs under torch.no_grad().
            with torch.enable_grad():
                x_var = x.detach().requires_grad_(True)
                z_i = self.image_embeddings(x_var, t)
                loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
                grad = torch.autograd.grad(loss, x_var)[0].detach()
            return grad * grad_scale

        return cond_fn
def create_clip_model(
    config_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    tokenizer: Optional[SimpleTokenizer] = None,
) -> CLIPModel:
    """Construct a ``CLIPModel`` from a YAML hyperparameter file.

    Args:
        config_path: YAML file to read. When ``None``, ``default_config_path()``
            is used.
        device: Device on which the encoders and ``logit_scale`` are created.
            When ``None``, CUDA is used if available, otherwise CPU.
        tokenizer: Tokenizer to attach. When ``None``, a new
            ``SimpleTokenizer()`` is constructed.

    Returns:
        A ``CLIPModel`` holding the parsed config, the text and image encoders,
        ``log(config["logit_scale"])`` as a float32 tensor without gradient,
        the resolved device and the tokenizer.
    """
    # Resolve defaults in the same order as before: path, device, tokenizer.
    path = config_path if config_path is not None else default_config_path()
    target = (
        device
        if device is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    tok = tokenizer if tokenizer is not None else SimpleTokenizer()

    # yaml.safe_load(stream) is exactly yaml.load(stream, Loader=SafeLoader).
    with open(path, "r") as stream:
        cfg = yaml.safe_load(stream)

    text_enc = TextEncoder(
        n_bpe_vocab=cfg["n_vocab"],
        max_text_len=cfg["max_text_len"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head_text"],
        n_xf_blocks=cfg["n_xf_blocks_text"],
        n_head_state=cfg["n_head_state_text"],
        device=target,
    )
    image_enc = ImageEncoder(
        image_size=cfg["image_size"],
        patch_size=cfg["patch_size"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head_image"],
        n_xf_blocks=cfg["n_xf_blocks_image"],
        n_head_state=cfg["n_head_state_image"],
        n_timestep=cfg["n_timesteps"],
        device=target,
    )

    # The config stores the linear temperature; the model keeps its log.
    log_scale = np.log(cfg["logit_scale"])
    scale_tensor = torch.tensor(
        log_scale,
        dtype=torch.float32,
        device=target,
        requires_grad=False,
    )

    return CLIPModel(
        config=cfg,
        text_encoder=text_enc,
        image_encoder=image_enc,
        logit_scale=scale_tensor,
        device=target,
        tokenizer=tok,
    )