# Install required packages.
# NOTE: these are notebook shell commands; run them in a Colab/Jupyter cell.
# They are kept as comments so that this file remains valid Python.
# !pip install sentencepiece
# !pip install git+https://github.com/huggingface/transformers.git@cae78c46
# !pip install diffusers
# !pip install tokenizers==0.12.1
# !pip install datasets
# !pip install accelerate
# !pip install evaluate
# !pip install gradio==4.12.0
# !pip install gradio_client==0.8.0
# !pip install -i https://download.pytorch.org/whl/cu118 torch==2.0 torchvision==0.15 torchaudio==2.0

# ---------------------------------------------------------------------------
# conversation.py
# ---------------------------------------------------------------------------
import base64
import dataclasses
from enum import auto, Enum
from io import BytesIO
from typing import List, Optional, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


def _expand2square(pil_img, background_color=(122, 116, 104)):
    """Pad a PIL image to a square canvas, centring it on the short axis."""
    from PIL import Image  # local import: only the "Pad" mode needs PIL

    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        result.paste(pil_img, (0, (width - height) // 2))
    else:
        result.paste(pil_img, ((height - width) // 2, 0))
    return result


def _resize_for_display(image, max_len=800, min_len=400):
    """Resize a PIL image, keeping its aspect ratio.

    The short edge is capped at ``min_len`` (and never enlarged) and the long
    edge at ``max_len``.
    """
    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    W, H = image.size
    if H > W:
        H, W = longest_edge, shortest_edge
    else:
        H, W = shortest_edge, longest_edge
    return image.resize((W, H))


def _to_jpeg_b64(image):
    """Encode a PIL image as JPEG and return the base64 text."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` holds ``[role, message]`` pairs. A message is either a string,
    ``None`` (an open turn to be completed by the model), or a tuple
    ``(text, image, image_process_mode)``.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Render the system text and all messages as one prompt string.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Separators alternate: sep after even turns, sep2 after odd turns.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            # MPT roles carry their own delimiter, so no ": " is inserted.
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect the images attached to even-indexed messages after ``offset``.

        Args:
            return_pil: when true return PIL images, otherwise base64 JPEG text.

        Raises:
            ValueError: if a message carries an unknown ``image_process_mode``.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0 or not isinstance(msg, tuple):
                continue
            msg, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = _expand2square(image)
            elif image_process_mode == "Crop":
                pass
            elif image_process_mode == "Resize":
                image = image.resize((224, 224))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
            image = _resize_for_display(image)
            if return_pil:
                images.append(image)
            else:
                images.append(_to_jpeg_b64(image))
        return images

    def to_gradio_chatbot(self):
        """Convert the history after ``offset`` into ``[user, assistant]`` pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    msg, image, image_process_mode = msg
                    image = _resize_for_display(image)
                    # image = image.resize((224, 224))
                    img_b64_str = _to_jpeg_b64(image)
                    # NOTE(review): the HTML tag and the '<image>' placeholder
                    # were stripped from this file (the literals were
                    # 'user upload image' and ''); restored to the upstream
                    # LLaVA values -- confirm.
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy whose ``messages`` list can be mutated independently."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            # Previously dropped, which reset every copy to version "Unknown".
            version=self.version)

    def dict(self):
        """Return a plain-dict view; image tuples are reduced to their text."""
        if self.get_images():
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if isinstance(y, tuple) else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


# NOTE(review): in the templates below, every token written with angle brackets
# ("</s>", "<|im_start|>", "<|im_end|>") and the line breaks inside the MPT
# system prompts had been stripped from this file (sep2/sep were empty strings).
# They are restored to the upstream LLaVA values -- confirm against the
# checkpoint's tokenizer before relying on them.

conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
            "activities at least two days per week.\n"
            "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
            "vegetables, whole grains, lean proteins, and healthy fats can help support "
            "your overall health. Try to limit your intake of processed and high-sugar foods, "
            "and aim to drink plenty of water throughout the day.\n"
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
- You are a helpful language and vision assistant.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_mpt_text = Conversation(
    system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

simple_conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_multimodal = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there!  How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_mpt_multimodal = Conversation(
    system="""<|im_start|>system
- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

simple_conv_legacy = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there!  How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v1 = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_v1_2
conv_templates = {
    "default": conv_v1_2,
    "simple": simple_conv,
    "simple_legacy": simple_conv_legacy,
    "multimodal": simple_conv_multimodal,
    "mpt_multimodal": simple_conv_mpt_multimodal,
    "llava_v1": conv_llava_v1,

    # fastchat
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
    "vicuna_v1_1": conv_vicuna_v1_1,
    "mpt": conv_mpt,
    "mpt_text": conv_mpt_text,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())


# ---------------------------------------------------------------------------
# mgie_llava.py
# ---------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import (AutoConfig, AutoModelForCausalLM,
                          LlamaConfig, LlamaModel, LlamaForCausalLM,
                          CLIPVisionModel, CLIPImageProcessor)

from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

import os
import diffusers

# NOTE(review): these four literals were empty strings in this file (the
# angle-bracket tokens were stripped); restored to the upstream LLaVA values
# -- confirm they match the tokens the checkpoint was trained with.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class LlavaConfig(LlamaConfig):
    """Llama config registered under the ``llava`` model type."""
    model_type = "llava"


class LlavaLlamaModel(LlamaModel):
    """Llama backbone that splices projected CLIP patch features into the input embeddings."""
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # HACK: for FSDP
            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)

        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        """Return the vision tower (unwrapped from its list holder), or ``None``."""
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
                                  pretrain_mm_mlp_adapter=None, fsdp=None):
        """Load or reuse the CLIP tower, create the projector, and record both in the config.

        Returns:
            dict with ``image_processor``, ``image_token_len`` and ``vision_config``.
        """
        self.config.mm_vision_tower = vision_tower

        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        if not hasattr(self, 'vision_tower'):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
            # The tower may be stored bare or wrapped in a list (see below), so
            # unwrap via the accessor instead of indexing [0] unconditionally.
            vision_tower = self.get_vision_tower()
        vision_tower.requires_grad_(False)

        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        vision_config = vision_tower.config
        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            # Keep only the last dotted component ("weight"/"bias") of each key.
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        return dict(
            image_processor=image_processor,
            image_token_len=num_patches,
            vision_config=vision_config
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed ``input_ids``, replace image-patch positions with projected CLIP features, run Llama.

        Raises:
            ValueError: if the image start/end or patch tokens in ``input_ids``
                do not line up with the number of image patches.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        # if orig_embeds_params is not None:
        #     orig_embeds_params = orig_embeds_params[0]
        #     with torch.no_grad():
        #         self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = self.get_vision_tower()
        # Images are merged only on the full-prompt pass (or in training); the
        # `input_ids is not None` guard avoids dereferencing None when the
        # caller supplies inputs_embeds only.
        if (vision_tower is not None and input_ids is not None
                and (input_ids.shape[1] != 1 or self.training) and images is not None):
            # TODO: this is a modified multimodal LLM -- Haotian Liu
            with torch.no_grad():
                if type(images) is list:
                    # variable length images
                    image_features = []
                    for image in images:
                        image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
                        select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
                        select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                        image_feature = select_hidden_state[:, 1:]  # drop the first (class) token
                        image_features.append(image_feature)
                else:
                    image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
                    select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
                    select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
                    image_features = select_hidden_state[:, 1:].to(images.dtype)
            if type(images) is list:
                image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
            else:
                image_features = self.mm_projector(image_features)
            # The dummy features are multiplied by zero below; they only exist
            # so text-only samples still touch the projector. Their width must
            # match the projector input (was hard-coded to 1024).
            dummy_image_features = torch.zeros(256, self.mm_projector.in_features,
                                               device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)

            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    cur_image_idx += 1
                    continue
                if vision_tower.config.use_im_start_end:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
                    image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
                            raise ValueError("The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            # Only the start/end token embeddings keep gradients.
                            cur_new_input_embeds = torch.cat((
                                cur_input_embeds[:image_start_token_pos].detach(),
                                cur_input_embeds[image_start_token_pos:image_start_token_pos+1],
                                cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
                                cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
                            cur_new_input_embeds = torch.cat((
                                cur_input_embeds[:image_start_token_pos+1],
                                cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
                        raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
                    masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
                    mask_index_start = masked_indices[0]
                    if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
                        raise ValueError("The image patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat((
                            cur_input_embeds[:mask_index_start].detach(),
                            cur_image_features,
                            cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat((
                            cur_input_embeds[:mask_index_start],
                            cur_image_features,
                            cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(LlavaLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )


class EditMapper(nn.Module):
    """Map LLM hidden states to a 77 x 768 conditioning sequence through a small Transformer."""

    def __init__(self):
        super().__init__()

        self.llm2hid = nn.Linear(4096, 512)
        self.query = nn.Parameter(torch.randn(1, 77, 512))
        self.mapper = nn.Transformer(batch_first=True, norm_first=True,
                                     d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
                                     dim_feedforward=2048, dropout=0.0)
        self.hid2feat = nn.Linear(512, 768)

    def forward(self, llm, emb):
        """Project ``llm + emb`` to 512-d, decode against the learned queries, project to 768-d."""
        hid = self.llm2hid(llm+emb)
        hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
        feat = self.hid2feat(hid)

        return feat


class LlavaLlamaForCausalLM(LlamaForCausalLM):
    """LLaVA causal LM with an additional ``edit_head`` used for image editing."""
    config_class = LlavaConfig

    def __init__(self, config):
        # Deliberately calls the grandparent initialiser (skipping
        # LlamaForCausalLM.__init__) so the backbone below is the LLaVA one.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.edit_head = EditMapper()

        # Diffusion components, disabled in this file. The `labels` branch of
        # `forward` reads self.vae / self.scheduler / self.unet, so that branch
        # needs these lines re-enabled.
        # self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
        #                                        diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
        #                                        diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
        # self.vae.requires_grad_(False)
        # self.unet.register_to_config(in_channels=8)
        # with torch.no_grad():
        #     conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
        #     conv.weight.zero_()
        #     conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
        #     self.unet.conv_in = conv

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``LlavaLlamaModel``."""
        return self.model

    def get_vision_tower(self):
        """Return the backbone's vision tower (a second, shadowing definition was removed)."""
        return self.get_model().get_vision_tower()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        p2p_inp=None, p2p_ans=None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and LM head; with ``labels`` also compute the CE and edit losses.

        When ``labels`` is given the returned loss is ``loss_ce + 0.5 * loss_edit``.
        That path reads ``self.vae``, ``self.scheduler`` and ``self.unet``,
        which ``__init__`` does not create in this file.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if labels is not None:
            llm = []
            for i in range(labels.shape[0]):
                # Take the 8 hidden states starting one step before token id
                # 32003 (presumably '[IMG0]' -- TODO confirm); fall back to the
                # tail of the sequence when that id is absent.
                try:
                    p = labels[i].data.cpu().tolist().index(32003)-1
                except ValueError:
                    p = len(labels[i])-9
                p = min(len(hidden_states[i])-9, p)
                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
            llm = torch.cat(llm, dim=0)
            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            B, DROP = labels.shape[0], 0.05

            hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
                                      self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            with torch.no_grad():
                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                    torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]

            noise = torch.randn_like(lat_ans)
            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)

            # Condition dropout: prob < 2*DROP swaps in the null text condition,
            # DROP <= prob < 3*DROP zeroes the image condition.
            prob = torch.rand(B, device=lat_ans.device)
            mask = (prob<(DROP*2)).reshape(B, 1, 1)
            hid_edit = torch.where(mask, hid_null, hid_edit)
            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
            lat_inp *= mask

            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample

            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
            # LOCAL_RANK is only set by distributed launchers; default to rank 0
            # instead of raising KeyError in a single-process run.
            if int(os.environ.get('LOCAL_RANK', '0'))==0:
                print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
            loss = loss_ce+loss_edit*0.5

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build per-step model inputs for ``generate``, forwarding ``images`` from kwargs."""
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        """Add the image tokens to ``tokenizer``, resize embeddings, and store their ids on the vision config.

        Raises:
            ValueError: if the pretrained ``embed_tokens`` weight has an unexpected shape.
        """
        vision_config = self.get_vision_tower().config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialise the new rows with the mean of the existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")

        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]


AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)


# ---------------------------------------------------------------------------
# main.py
# ---------------------------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')

import os
from PIL import Image
import numpy as np
import torch as T
import transformers
import diffusers
import gradio as gr
import huggingface_hub

CKPT_DIR = '/content/drive/My Drive/_ckpt'


def crop_resize(f, sz=512):
    """Centre-crop a PIL image to a square, then resize it to ``sz`` x ``sz``."""
    w, h = f.size
    if w > h:
        p = (w - h) // 2
        f = f.crop([p, 0, p + h, h])
    elif h > w:
        p = (h - w) // 2
        f = f.crop([0, p, w, p + w])
    f = f.resize([sz, sz])
    return f


def remove_alter(s):
    """Reduce decoded model output to at most its first two sentences.

    Text up to and including 'ASSISTANT:' is dropped, the remainder is cut at
    the end-of-sequence marker, at the word 'alternative' and at '[IMG0]', and
    the result always ends with a period.
    """
    if 'ASSISTANT:' in s:
        s = s[s.index('ASSISTANT:') + 10:].strip()
    # NOTE(review): this marker was the empty string in this file, which made
    # the test always true and truncated the text to ''. Restored to '</s>'
    # (the upstream value) -- confirm it matches tokenizer.eos_token.
    if '</s>' in s:
        s = s[:s.index('</s>')].strip()
    if 'alternative' in s.lower():
        s = s[:s.lower().index('alternative')]
    if '[IMG0]' in s:
        s = s[:s.index('[IMG0]')]
    s = '.'.join([part.strip() for part in s.split('.')[:2]])
    # endswith() instead of s[-1]: an empty string must not raise IndexError.
    if not s.endswith('.'):
        s += '.'
    return s.strip()


# NOTE(review): restored from empty strings, as in the mgie_llava.py section
# above -- confirm against the checkpoint's tokenizer.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
PATH_LLAVA = f'{CKPT_DIR}/LLaVA-7B-v1'

tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)

tokenizer.padding_side = 'left'
tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
ckpt = T.load(f'{CKPT_DIR}/mgie_7b/mllm.pt', map_location='cpu')
model.load_state_dict(ckpt, strict=False)

mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Replace the vision tower created in the model constructor with a float16 copy on the GPU.
vision_tower = model.get_model().vision_tower[0]
vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
model.get_model().vision_tower[0] = vision_tower
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

_ = model.eval()
pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
pipe.set_progress_bar_config(disable=True)
pipe.unet.load_state_dict(T.load(f'{CKPT_DIR}/mgie_7b/unet.pt', map_location='cpu'))
print('--init MGIE--')


def go_mgie(img, txt, seed, cfg_txt, cfg_img):
    """Edit ``img`` according to the instruction ``txt``.

    Args:
        img: input image as an array accepted by ``Image.fromarray``.
        txt: the user's edit instruction.
        seed: seed for the CUDA generator (converted with ``int``).
        cfg_txt: passed to the pipeline as ``guidance_scale``.
        cfg_img: passed to the pipeline as ``image_guidance_scale``.

    Returns:
        ``(res, out)``: the edited image and the cleaned-up instruction text
        generated by the model.
    """
    EMB = ckpt['emb'].cuda()
    with T.inference_mode():
        NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)

    img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
    inp = img

    img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
    txt = "what will this image be like if '%s'" % (txt)
    txt = txt + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    conv = conv_templates['vicuna_v1_1'].copy()
    conv.append_message(conv.roles[0], txt)
    conv.append_message(conv.roles[1], None)
    txt = conv.get_prompt()
    txt = tokenizer(txt)
    txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])

    with T.inference_mode():
        _ = model.cuda()
        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
                             return_dict_in_generate=True, output_hidden_states=True)
        # Last-layer hidden state of every generation step, concatenated along the sequence axis.
        out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]

        # Same 8-state window selection as the training path in
        # LlavaLlamaForCausalLM.forward (32003 is presumably '[IMG0]' -- TODO confirm).
        if 32003 in out:
            p = out.index(32003) - 1
        else:
            p = len(hid) - 9
        p = min(p, len(hid) - 9)
        hid = hid[p:p + 8]

        out = remove_alter(tokenizer.decode(out))
        _ = model.cuda()
        emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
        res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
                   generator=T.Generator(device='cuda').manual_seed(seed),
                   guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]

    return res, out


with gr.Blocks() as app:
    gr.Markdown(
        """
        # MagiX: Edit Personalized Images using Gen AI by Ateeb Taser
        """
    )
    with gr.Row():
        inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
                    gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
    with gr.Row():
        txt, out = [gr.Textbox(label='Instruction', interactive=True),
                    gr.Textbox(label='Expressive Instruction', interactive=False)]
    with gr.Row():
        seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
                                  gr.Number(value=7.5, label='Text CFG', interactive=True),
                                  gr.Number(value=1.5, label='Image CFG', interactive=True)]
    with gr.Row():
        btn_sub = gr.Button('Submit')
    btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])

app.launch()