"""Image tokenizer (quantizer + optional unCLIP decoder) and a LLaMA tokenizer that embeds it."""
import os
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from PIL import Image
from torchvision import transforms
from transformers import LlamaTokenizer

from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline

# File name of the quantizer checkpoint looked up inside the tokenizer directory.
WEIGHTS_NAME = 'seed_quantizer.pt'
# Not referenced in this module; kept because other modules may import it.
DIFFUSION_NAME = 'diffusion_model'


class ImageTokenizer(nn.Module):
    """Maps images to discrete codebook indices and, optionally, back to images.

    Encoding uses a ``Blip2QformerQuantizer``; decoding additionally needs a
    ``StableUnCLIPImg2ImgPipeline``, which is only loaded when requested.
    """

    def __init__(self,
                 model_path,
                 diffusion_model_path=None,
                 load_diffusion=False,
                 image_size=224,
                 device='cuda',
                 fp16=True,
                 **kwargs):
        """Load the quantizer and, if requested, the diffusion decoder.

        Args:
            model_path: Checkpoint passed to ``Blip2QformerQuantizer.from_pretrained``.
            diffusion_model_path: Location of the unCLIP pipeline weights.
            load_diffusion: Load the pipeline only when this is true *and*
                ``diffusion_model_path`` is given; otherwise ``decode`` is unavailable.
            image_size: Side length images are resized to by ``self.processor``.
            device: Device that the models and cached noise tensors are placed on.
            fp16: Run quantizer and pipeline in half precision.
            **kwargs: Forwarded to ``Blip2QformerQuantizer.from_pretrained``.
        """
        super().__init__()
        # Imported inside the constructor, as the original code did.
        from .seed_qformer.qformer_quantizer import Blip2QformerQuantizer

        # One dtype for the pipeline and for the cached noise tensors, so the
        # fp32 configuration no longer feeds fp16 latents into fp32 weights.
        dtype = torch.float16 if fp16 else torch.float32

        model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path,
                                                      vit_precision='fp16' if fp16 else 'fp32',
                                                      **kwargs).eval()

        if diffusion_model_path is not None and load_diffusion:
            diffusion_model = StableUnCLIPImg2ImgPipeline.from_pretrained(diffusion_model_path,
                                                                          torch_dtype=dtype)
            self.diffusion_model = diffusion_model.to(device)
        else:
            self.diffusion_model = None

        model = model.to(device)
        if fp16:
            model = model.half()

        # interpolation=3 is the integer code PIL uses for bicubic resampling.
        processor = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

        # Sampled once here and reused by every decode() call.
        shape_latents = torch.Size([1, 4, 96, 96])
        self.latents = torch.randn(shape_latents, generator=None, device=device, dtype=dtype,
                                   layout=torch.strided)

        # Not passed to the pipeline by decode() in this module; kept as an attribute.
        shape_noise = torch.Size([1, 1024])
        self.noise = torch.randn(shape_noise, generator=None, device=device, dtype=dtype,
                                 layout=torch.strided)

        self.model = model
        self.processor = processor
        self.device = device
        self.fp16 = fp16

    def __len__(self):
        """Return the quantizer's ``n_embed`` (codebook size)."""
        return self.model.n_embed

    def to_device(self, device):
        """Move everything this tokenizer owns to ``device`` and return ``self``.

        ``nn.Module.to`` only moves registered parameters and buffers. The
        diffusion pipeline and the cached noise tensors are plain attributes,
        so they are moved explicitly here to keep ``decode`` on one device.
        """
        self.to(device=device)
        self.device = device
        self.latents = self.latents.to(device)
        self.noise = self.noise.to(device)
        if self.diffusion_model is not None:
            self.diffusion_model = self.diffusion_model.to(device)
        return self

    def encode(self, image_torch):
        """Convert an image or a batch of images to codebook indices.

        Args:
            image_torch: Tensor of shape ``[b, c, h, w]``; a 3-D tensor is
                treated as a single image and gets a batch axis prepended.
                The tensor is not moved to ``self.device`` here.

        Returns:
            Index tensor reshaped to ``[b, -1]``.
        """
        if len(image_torch.shape) == 3:
            image_torch = image_torch.unsqueeze(0)

        img = image_torch
        if self.fp16:
            img = img.half()
        with torch.no_grad():
            indices, _ = self.model.get_codebook_indices(img)
        return indices.view(img.shape[0], -1)

    def decode(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
        """Generate images from codebook indices with the diffusion pipeline.

        Args:
            indices: Codebook indices to decode.
            negative_indices: Optional indices for the negative image
                embedding; must have the same shape as ``indices``.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: If the diffusion pipeline was not loaded.
            AssertionError: If ``negative_indices`` has a different shape.
        """
        if self.diffusion_model is None:
            # Without this guard the failure is "'NoneType' object is not callable".
            raise RuntimeError('Diffusion model is not loaded; create the tokenizer with '
                               'load_diffusion=True and a diffusion model path to decode images')

        image_embeds = self.model.get_codebook_entry(indices)

        if negative_indices is not None:
            assert indices.shape == negative_indices.shape, 'Negative indices must have the same shape with indices'
            negative_image_embeds = self.model.get_codebook_entry(negative_indices)
        else:
            negative_image_embeds = None

        image = self.diffusion_model(
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            guidance_scale=guidance_scale,
            noise_level=0,
            num_inference_steps=num_inference_steps,
            latents=self.latents,
        ).images
        return image


class SeedLlamaTokenizer(LlamaTokenizer):
    """``LlamaTokenizer`` that additionally owns an :class:`ImageTokenizer`."""

    # NOTE(review): the three special-token defaults below are empty strings in
    # this file; confirm that is intended and not the result of "<...>" markup
    # being stripped from the source.
    def __init__(self,
                 vocab_file,
                 unk_token="",
                 bos_token="",
                 eos_token="",
                 pad_token=None,
                 sp_model_kwargs: Optional[Dict[str, Any]] = None,
                 add_bos_token=True,
                 add_eos_token=False,
                 clean_up_tokenization_spaces=False,
                 device='cuda',
                 fp16=True,
                 load_diffusion=False,
                 encoder_url=None,
                 diffusion_path=None,
                 **kwargs):
        """Build the text tokenizer, then eagerly load the image tokenizer.

        Args:
            vocab_file, unk_token, bos_token, eos_token, pad_token,
            sp_model_kwargs, add_bos_token, add_eos_token,
            clean_up_tokenization_spaces, **kwargs: Forwarded to ``LlamaTokenizer``.
            device: Device for the image tokenizer and for image tensors.
            fp16: Run the image tokenizer in half precision.
            load_diffusion: Also load the diffusion decoder.
            encoder_url: Quantizer checkpoint location; when ``None`` the
                checkpoint is looked up under ``self.name_or_path``.
            diffusion_path: Location of the diffusion pipeline weights.
        """
        # Keywords instead of nine positional arguments, so the call does not
        # depend on the parent's parameter order.
        super().__init__(vocab_file,
                         unk_token=unk_token,
                         bos_token=bos_token,
                         eos_token=eos_token,
                         pad_token=pad_token,
                         sp_model_kwargs=sp_model_kwargs,
                         add_bos_token=add_bos_token,
                         add_eos_token=add_eos_token,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         **kwargs)
        self.device = device
        self.fp16 = fp16
        # The pad token is always overwritten with the unk token.
        self.pad_token = self.unk_token
        self.load_diffusion = load_diffusion
        self.encoder_url = encoder_url
        self.diffusion_path = diffusion_path

        self.load_image_tokenizer()

    def load_image_tokenizer(self):
        """Create ``self._image_tokenizer`` if it does not exist yet.

        Raises:
            AssertionError: If ``encoder_url`` is ``None`` and
                ``self.name_or_path`` is missing or not an existing path.
        """
        if hasattr(self, '_image_tokenizer'):
            return

        if self.encoder_url is not None:
            model_path = self.encoder_url
        else:
            assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
            model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)

        self._image_tokenizer = ImageTokenizer(model_path=model_path,
                                               diffusion_model_path=self.diffusion_path,
                                               load_diffusion=self.load_diffusion,
                                               device=self.device,
                                               fp16=self.fp16)

    @property
    def image_tokenizer(self):
        """The :class:`ImageTokenizer`, loaded on first access if necessary."""
        # Single code path for loading; this property used to duplicate it.
        self.load_image_tokenizer()
        return self._image_tokenizer

    @property
    def num_image_tokens(self):
        """Number of image tokens."""
        # Hard-coded so the value is available without loading the image tokenizer.
        return 8192

    def to(self, device):
        """Set the target device and move the image tokenizer if it is loaded.

        Returns:
            ``self``, so calls can be chained.
        """
        self.device = device
        if hasattr(self, '_image_tokenizer'):
            self._image_tokenizer.to_device(device)
        return self

    def encode_image(
        self,
        image_path=None,
        image_pil=None,
        image_torch=None,
        image_size: int = 224,
    ):
        """Encode one image, given in exactly one of three forms, to indices.

        Args:
            image_path: Path of an image file; it is opened and converted to RGB.
            image_pil: A PIL image; it is run through the tokenizer's processor.
            image_torch: An already preprocessed tensor.
            image_size: Unused; kept for interface compatibility.

        Returns:
            The result of ``ImageTokenizer.encode`` for the image.

        Raises:
            AssertionError: Unless exactly one image argument is provided.
        """
        assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2, \
            'Exactly one of image_path, image_pil and image_torch must be provided'

        if image_path is not None:
            # Context manager closes the file as soon as the RGB copy exists.
            with Image.open(image_path) as opened:
                image_pil = opened.convert('RGB')

        if image_pil is not None:
            image_torch = self.image_tokenizer.processor(image_pil)

        image_torch = image_torch.to(self.device)
        return self.image_tokenizer.encode(image_torch)

    def decode_image(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
        """Decode codebook indices to images on ``self.device``.

        Args:
            indices: Codebook indices; moved to ``self.device``.
            negative_indices: Optional negative indices; moved to ``self.device``.
            guidance_scale: Forwarded to ``ImageTokenizer.decode``.
            num_inference_steps: Forwarded to ``ImageTokenizer.decode``; the
                default equals that method's own default.

        Returns:
            The images produced by ``ImageTokenizer.decode``.
        """
        indices = indices.to(self.device)
        if negative_indices is not None:
            negative_indices = negative_indices.to(self.device)
        image = self.image_tokenizer.decode(
            indices,
            negative_indices=negative_indices,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
        return image