Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import torch.nn as nn | |
from typing import List, Optional | |
import torch.utils.checkpoint | |
from torchvision.transforms import ToPILImage | |
from model_lib.moMA_generator import MoMA_generator | |
from transformers.activations import ACT2FN | |
from huggingface_hub import hf_hub_download | |
from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA | |
from llava.model.builder import load_pretrained_model | |
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path | |
from llava.constants import IMAGE_TOKEN_INDEX | |
def add_function(model):
    """Attach ``my_llava_forward`` to ``model`` as a plain function attribute.

    The function is stored on the instance without being bound, so callers
    must pass the model explicitly as the first argument, e.g.
    ``model.my_llava_forward(model, input_ids=..., images=...)``.
    """

    def my_llava_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Return element 0 of ``self.model``'s output for a multimodal input.

        Only ``input_ids``, ``attention_mask``, ``position_ids`` and ``images``
        are consumed; the remaining parameters are accepted but ignored, and
        fixed values are forwarded to ``self.model`` in their place.
        """
        # Of the six values returned, only positions, mask and embeddings are
        # needed here; past key values and labels are passed in as None.
        _, merged_positions, merged_mask, _, merged_embeds, _ = (
            self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, None, None, images
            )
        )
        backbone_kwargs = dict(
            input_ids=None,
            attention_mask=merged_mask,
            position_ids=merged_positions,
            past_key_values=None,
            inputs_embeds=merged_embeds,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        return self.model(**backbone_kwargs)[0]

    model.my_llava_forward = my_llava_forward
class LlamaMLP_mapping(nn.Module):
    """Gated MLP that projects ``hidden_size`` features to ``hidden_size_out``.

    The forward pass computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``
    using three bias-free linear layers.
    """

    def __init__(self, hidden_size, hidden_size_out):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_out = hidden_size_out
        self.gate_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.up_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.down_proj = nn.Linear(hidden_size_out, hidden_size_out, bias=False)
        self.act_fn = ACT2FN["silu"]
        # Stored on the module but not applied anywhere in forward().
        self.act_fn_output = ACT2FN["tanh"]
        self.init_linear()

    def forward(self, x):
        """Apply the gated projection to ``x`` and return the result."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

    def init_linear(self):
        """Re-initialise all three weights: Xavier-normal, then scaled by 1/4."""
        # Order (gate, up, down) is kept so the random draws are unchanged.
        for layer in (self.gate_proj, self.up_proj, self.down_proj):
            torch.nn.init.xavier_normal_(layer.weight)
            layer.weight.data = layer.weight.data / 4.0
class MoMA_main_modal(nn.Module):
    """Wrapper that couples a LLaVA model with the MoMA image generator.

    Construction builds the ``MoMA_generator``, loads the LLaVA model through
    ``load_pretrained_model``, creates a 4096 -> 1024 ``LlamaMLP_mapping``
    projector in float16, loads the saved adapter/projector weights and
    finally freezes every sub-module.
    """

    def __init__(self, args):
        """Build all sub-models from ``args``.

        Attributes read from ``args`` here: ``device``, ``model_path``,
        ``load_8bit`` and ``load_4bit`` (plus ``load_attn_adapters`` via
        ``load_saved_components``).
        """
        super().__init__()
        self.args = args
        self.device = args.device

        self.moMA_generator = MoMA_generator(self.device, args)
        self.unet = self.moMA_generator.pipe.unet
        self.vae = self.moMA_generator.pipe.vae

        print('Loading MoMA: its Multi-modal LLM...')
        model_name = get_model_name_from_path(args.model_path)
        (
            self.tokenizer_llava,
            self.model_llava,
            self.image_processor_llava,
            self.context_len_llava,
        ) = load_pretrained_model(
            args.model_path,
            None,
            model_name,
            load_8bit=self.args.load_8bit,
            load_4bit=self.args.load_4bit,
            device=args.device,
        )

        add_function(self.model_llava)

        self.mapping = LlamaMLP_mapping(4096, 1024).to(self.device, dtype=torch.float16)
        self.load_saved_components()
        self.freeze_modules()

    def load_saved_components(self):
        """Load adapter and projector weights from ``args.load_attn_adapters``.

        When that file does not exist locally, ``attn_adapters_projectors.th``
        is first downloaded from the ``args.model_path`` hub repository into
        the checkpoint's parent directory.
        """
        checkpoint_path = self.args.load_attn_adapters
        if not os.path.exists(checkpoint_path):
            print('Loading Attentions and LLM mappings...')
            hf_hub_download(
                repo_id=self.args.model_path,
                filename="attn_adapters_projectors.th",
                # os.path.dirname replaces the manual split on '/' so that
                # root-level and OS-specific paths resolve correctly.
                local_dir=os.path.dirname(checkpoint_path),
            )

        # NOTE(security): torch.load unpickles the file, and the file may have
        # just been downloaded; only use checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu")

        # Load attention adapters and self/cross attentions.
        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        attn_layers.load_state_dict(state_dict["self_cross_attentions"], strict=False)

        # Load LLM projectors (non-strict: keys not present are left as-is).
        self.load_state_dict(state_dict['llm_mapping'], strict=False)

    def freeze_modules(self):
        """Put every sub-model in eval mode and disable its gradients."""
        all_modules = [
            self.moMA_generator.pipe.vae,
            self.moMA_generator.pipe.text_encoder,
            self.unet,
            self.model_llava,
            self.mapping,
        ]
        for module in all_modules:
            # Assigning ``module.train = False`` would replace the
            # nn.Module.train method with a bool and leave ``training``
            # untouched; eval() is the call that actually switches modes.
            module.eval()
            module.requires_grad_(False)

    @staticmethod
    def _left_pad(sequences, padding_value):
        """Stack 1-D tensors into a batch, padding on the left.

        ``pad_sequence`` pads on the right, so each sequence is reversed
        before padding and the padded batch is reversed back afterwards.
        """
        reversed_sequences = [seq.flip(dims=[-1]) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(
            reversed_sequences, batch_first=True, padding_value=padding_value
        )
        return padded.flip(dims=[1])

    def forward_MLLM(self, batch):
        """Run LLaVA plus the mapping MLP and return the last-token features.

        Reads ``batch['llava_processed']`` (images, cast to half precision),
        ``batch['label']`` (subjects) and ``batch['text']`` (prompts). One
        prompt is built per (subject, prompt) pair, the prompts are
        left-padded to a common length, and the mapped output at the final
        sequence position is returned.
        """
        llava_processeds = batch['llava_processed'].half().to(self.device)
        subjects, prompts = batch['label'], batch['text']

        input_ids, attention_masks, position_ids = [], [], []
        for subject, prompt in zip(subjects, prompts):
            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
            token_ids = tokenizer_image_token(
                prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).to(self.device)
            input_ids.append(token_ids)
            attention_masks.append(
                torch.ones(token_ids.shape, dtype=torch.long, device=self.device)
            )
            position_ids.append(torch.arange(token_ids.shape[-1], device=self.device))

        input_ids = self._left_pad(input_ids, self.tokenizer_llava.pad_token_id)
        # The pad token id is a vocabulary index and only meaningful for
        # input_ids. Padding slots must be 0 in the attention mask (not
        # attended) and get position 0, regardless of the tokenizer's pad id.
        position_ids = self._left_pad(position_ids, 0)
        attention_masks = self._left_pad(attention_masks, 0)

        # my_llava_forward is stored unbound, so the model is passed explicitly.
        output = self.model_llava.my_llava_forward(
            self.model_llava,
            input_ids=input_ids,
            attention_mask=attention_masks,
            position_ids=position_ids,
            images=llava_processeds,
        )
        output = self.mapping(output)
        # Left padding keeps each prompt's final token at index -1.
        return output[:, -1, :]

    def reset(self):
        """Reset the generator by calling ``moMA_generator.reset_all()``."""
        self.moMA_generator.reset_all()

    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
        """Generate an image for ``subject`` under ``prompt`` and return it.

        Args:
            rgb_path: Passed to ``Dataset_evaluate_MoMA`` to build the batch.
            subject: Subject label passed to ``Dataset_evaluate_MoMA``.
            prompt: Text prompt passed to ``Dataset_evaluate_MoMA``.
            strength: Value handed to ``set_selfAttn_strength``.
            num: Currently unused; kept for interface compatibility.
            seed: Seed forwarded to ``generate_with_MoMA``.

        Returns:
            A PIL image converted from the first generated image tensor.
        """
        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, self)
        self.moMA_generator.set_selfAttn_strength(strength)

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
            with torch.no_grad():
                ### key steps
                llava_emb = self.forward_MLLM(batch).clone().detach()
                # The returned mask is not used here.
                img, mask = self.moMA_generator.generate_with_MoMA(
                    batch, llava_emb=llava_emb, seed=seed, device=self.args.device
                )
        self.reset()

        result = ToPILImage()(img[0])
        return result