diff --git a/gradio_demo/app.py b/gradio_demo/app.py new file mode 100644 index 0000000000000000000000000000000000000000..271b4b8989a7ea7cb509b73c57befbb776dcc038 --- /dev/null +++ b/gradio_demo/app.py @@ -0,0 +1,545 @@ +import sys +sys.path.append('./') +import gradio as gr +import random +import numpy as np +from gradio_demo.character_template import character_man, lorapath_man +from gradio_demo.character_template import character_woman, lorapath_woman +from gradio_demo.character_template import styles, lorapath_styles +import torch +import os +from typing import Tuple, List +import copy +import argparse +from diffusers.utils import load_image +import cv2 +from PIL import Image +from transformers import DPTFeatureExtractor, DPTForDepthEstimation +from controlnet_aux import OpenposeDetector +from controlnet_aux.open_pose.body import Body + +try: + from inference.models import YOLOWorld + from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor + from src.efficientvit.sam_model_zoo import create_sam_model + import supervision as sv +except: + print("YoloWorld can not be load") + +try: + from groundingdino.models import build_model + from groundingdino.util import box_ops + from groundingdino.util.slconfig import SLConfig + from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap + from groundingdino.util.inference import annotate, predict + from segment_anything import build_sam, SamPredictor + import groundingdino.datasets.transforms as T +except: + print("groundingdino can not be load") + +from src.pipelines.lora_pipeline import LoraMultiConceptPipeline +from src.prompt_attention.p2p_attention import AttentionReplace +from diffusers import ControlNetModel, StableDiffusionXLPipeline +from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward + +CHARACTER_MAN_NAMES = list(character_man.keys()) +CHARACTER_WOMAN_NAMES = list(character_woman.keys()) +STYLE_NAMES = list(styles.keys()) +MAX_SEED = np.iinfo(np.int32).max + +### Description +title = r""" +
OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models
+""" + +description = r""" +Official 🤗 Gradio demo for OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models.
+ +How to use:
+1. Select two characters. +2. Enter a text prompt as done in normal text-to-image models. +3. Click the Submit button to start customizing. +4. Enjoy the generated image😊! +""" + +article = r""" +--- +📝 **Citation** +
+If our work is helpful for your research or applications, please cite us via: +```bibtex +@article{, +title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models}, +author={}, +journal={}, +year={} +} +``` +""" + +tips = r""" +### Usage tips of OMG +1. Input text prompts to describe a man and a woman +""" + +css = ''' +.gradio-container {width: 85% !important} +''' + +def sample_image(pipe, + input_prompt, + input_neg_prompt=None, + generator=None, + concept_models=None, + num_inference_steps=50, + guidance_scale=7.5, + controller=None, + stage=None, + region_masks=None, + lora_list = None, + styleL=None, + **extra_kargs +): + + spatial_condition = extra_kargs.pop('spatial_condition') + if spatial_condition is not None: + spatial_condition_input = [spatial_condition] * len(input_prompt) + else: + spatial_condition_input = None + + images = pipe( + prompt=input_prompt, + concept_models=concept_models, + negative_prompt=input_neg_prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + cross_attention_kwargs={"scale": 0.8}, + controller=controller, + stage=stage, + region_masks=region_masks, + lora_list=lora_list, + styleL=styleL, + image=spatial_condition_input, + **extra_kargs).images + + return images + +def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]: + image = np.asarray(image_source) + return image + +def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]: + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image = np.asarray(image_source) + image_transformed, _ = transform(image_source, None) + return image, image_transformed + +def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5): + if segmentType=='GroundingDINO': + image_source, image = load_image_dino(image) + boxes, logits, phrases = predict( + model=segmentmodel, + image=image, + caption=TEXT_PROMPT, + box_threshold=0.3, + text_threshold=0.25 + ) + sam.set_image(image_source) + H, W, _ = image_source.shape + boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H]) + + transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda() + masks, _, _ = sam.predict_torch( + point_coords=None, + point_labels=None, + boxes=transformed_boxes, + multimask_output=False, + ) + masks=masks[0].squeeze(0) + else: + image_source = load_image_yoloworld(image) + segmentmodel.set_classes([TEXT_PROMPT]) + results = segmentmodel.infer(image_source, confidence=confidence) + detections = sv.Detections.from_inference(results).with_nms( + class_agnostic=True, threshold=threshold + ) + masks = None + if len(detections) != 0: + print(TEXT_PROMPT + " detected!") + sam.set_image(image_source, image_format="RGB") + masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False) + masks = torch.from_numpy(masks.squeeze()) + + return masks + +def prepare_text(prompt, region_prompts): + ''' + Args: + prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text] + Returns: + full_prompt: subject1, attribute1 and subject2, attribute2, global text + context_prompt: subject1 and subject2, global text + entity_collection: [(subject1, attribute1), Location1] + ''' + region_collection = [] + + regions = region_prompts.split('|') + + for region in regions: + if region == '': + break + 
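+        # Each region entry follows "[prompt]-*-[negative prompt]"; split on the
+        # "-*-" delimiter and strip the square brackets before collecting the pair.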
prompt_region, neg_prompt_region = region.split('-*-') + prompt_region = prompt_region.replace('[', '').replace(']', '') + neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '') + + region_collection.append((prompt_region, neg_prompt_region)) + return (prompt, region_collection) + + +def build_model_sd(pretrained_model, controlnet_path, device, prompts): + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device) + pipe = LoraMultiConceptPipeline.from_pretrained( + pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device) + controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32) + revise_regionally_controlnet_forward(pipe.unet, controller) + pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16, + variant="fp16").to(device) + return pipe, controller, pipe_concept + +def build_model_lora(pipe_concept, lora_paths, style_path, condition, args): + pipe_list = [] + if condition == "Human pose": + controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + elif condition == "Canny Edge": + controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + elif condition == "Depth": + controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + + if style_path is not None and os.path.exists(style_path): + pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style') + + for lora_path in lora_paths.split('|'): + adapter_name = lora_path.split('/')[-1].split('.')[0] + pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name) + pipe_concept.enable_xformers_memory_efficient_attention() + pipe_list.append(adapter_name) + return pipe_list + +def build_yolo_segment_model(sam_path, device): + yolo_world = YOLOWorld(model_id="yolo_world/l") + sam = EfficientViTSamPredictor( + create_sam_model(name="xl1", weight_url=sam_path).to(device).eval() + ) + return yolo_world, sam + +def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'): + args = SLConfig.fromfile(ckpt_config_filename) + model = build_model(args) + args.device = device + + checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu') + log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) + print("Model loaded from {} \n => {}".format(filename, log)) + _ = model.eval() + return model + +def build_dino_segment_model(ckpt_repo_id, sam_checkpoint): + ckpt_filenmae = "groundingdino_swinb_cogcoor.pth" + ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py") + groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename) + sam = build_sam(checkpoint=sam_checkpoint) + sam.cuda() + sam_predictor = SamPredictor(sam) + return groundingdino_model, sam_predictor + + + +def main(device, segment_type): + pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp) + + if segment_type == 'GroundingDINO': + detect_model, sam = 
build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint) + else: + detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device) + + resolution_list = ["1440*728", + "1344*768", + "1216*832", + "1152*896", + "1024*1024", + "896*1152", + "832*1216", + "768*1344", + "728*1440"] + + condition_list = ["None", + "Human pose", + "Canny Edge", + "Depth"] + + depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda") + feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint) + body_model = Body(args.pose_detector_checkpoint) + openpose = OpenposeDetector(body_model) + + def remove_tips(): + return gr.update(visible=False) + + def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + def get_humanpose(img): + openpose_image = openpose(img) + return openpose_image + + def get_cannyedge(image): + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = Image.fromarray(image) + return canny_image + + def get_depth(image): + image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + with torch.no_grad(), torch.autocast("cuda"): + depth_map = depth_estimator(image).predicted_depth + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(1024, 1024), + mode="bicubic", + align_corners=False, + ) + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = (depth_map - depth_min) / (depth_max - depth_min) + image = torch.cat([depth_map] * 3, dim=1) + image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + return image + + def generate_image(prompt1, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style): + try: + path1 = lorapath_man[man] + path2 = lorapath_woman[woman] + pipe_concept.unload_lora_weights() + pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args) + + if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]): + styleL = True + else: + styleL = False + + input_list = [prompt1, prompt2, prompt3, prompt4] + condition_list = [condition_img1, condition_img2, condition_img3, condition_img4] + output_list = [] + + width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1]) + + kwargs = { + 'height': height, + 'width': width, + } + + for prompt, condition_img in zip(input_list, condition_list): + if prompt!='': + input_prompt = [] + p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.' 
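+                    # Prefix the shared photo template with the chosen style prompt (if any),
+                    # then append the global prompt pair followed by the per-character
+                    # prompts and their negative prompts for regional generation.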
+ if styleL: + p = styles[style] + p + input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)]) + input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])]) + + if condition == 'Human pose' and condition_img is not None: + spatial_condition = get_humanpose(condition_img).resize((width, height)) + elif condition == 'Canny Edge' and condition_img is not None: + spatial_condition = get_cannyedge(condition_img).resize((width, height)) + elif condition == 'Depth' and condition_img is not None: + spatial_condition = get_depth(condition_img).resize((width, height)) + else: + spatial_condition = None + + kwargs['spatial_condition'] = spatial_condition + + controller.reset() + image = sample_image( + pipe, + input_prompt=input_prompt, + concept_models=pipe_concept, + input_neg_prompt=[negative_prompt] * len(input_prompt), + generator=torch.Generator(device).manual_seed(seed), + controller=controller, + stage=1, + lora_list=pipe_list, + styleL=styleL, + **kwargs) + + controller.reset() + if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]: + mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15, + threshold=0.5) + else: + mask1 = None + + if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]: + mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15, + threshold=0.5) + else: + mask2 = None + + if mask1 is None and mask2 is None: + output_list.append(image[1]) + else: + image = sample_image( + pipe, + input_prompt=input_prompt, + concept_models=pipe_concept, + input_neg_prompt=[negative_prompt] * len(input_prompt), + generator=torch.Generator(device).manual_seed(seed), + controller=controller, + stage=2, + region_masks=[mask1, mask2], + lora_list=pipe_list, + styleL=styleL, + **kwargs) + output_list.append(image[1]) + else: + output_list.append(None) + return output_list + except: + print("error") + return None, None, None, None + + def get_local_value_man(input): + return character_man[input][0] + + def get_local_value_woman(input): + return character_woman[input][0] + + + with gr.Blocks(css=css) as demo: + # description + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + gallery = gr.Image(label="Generated Images", height=512, width=512) + gallery2 = gr.Image(label="Generated Images", height=512, width=512) + gallery3 = gr.Image(label="Generated Images", height=512, width=512) + gallery4 = gr.Image(label="Generated Images", height=512, width=512) + usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False) + + with gr.Row(): + condition_img1 = gr.Image(label="Input condition", height=128, width=128) + condition_img2 = gr.Image(label="Input condition", height=128, width=128) + condition_img3 = gr.Image(label="Input condition", height=128, width=128) + condition_img4 = gr.Image(label="Input condition", height=128, width=128) + + # character choose + with gr.Row(): + man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)") + woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)") + resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024") + condition = gr.Dropdown(label="Input condition type", 
choices=condition_list, value="None") + style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None") + + with gr.Row(): + local_prompt1 = gr.Textbox(label="Character1_prompt", + info="Describe the Character 1, this prompt should include the identifier of character 1", + value="Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.") + local_prompt2 = gr.Textbox(label="Character2_prompt", + info="Describe the Character 2, this prompt should include the identifier of character2", + value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.") + + man.change(get_local_value_man, man, local_prompt1) + woman.change(get_local_value_woman, woman, local_prompt2) + + # prompt + with gr.Column(): + prompt = gr.Textbox(label="Prompt 1", + info="Give a simple prompt to describe the first image content", + placeholder="Required", + value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling") + prompt2 = gr.Textbox(label="Prompt 2", + info="Give a simple prompt to describe the second image content", + placeholder="optional", + value="") + prompt3 = gr.Textbox(label="Prompt 3", + info="Give a simple prompt to describe the third image content", + placeholder="optional", + value="") + prompt4 = gr.Textbox(label="Prompt 4", + info="Give a simple prompt to describe the fourth image content", + placeholder="optional", + value="") + + with gr.Accordion(open=False, label="Advanced Options"): + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=42, + ) + negative_prompt = gr.Textbox(label="Negative Prompt", + placeholder="noisy, blurry, soft, deformed, ugly", + value="noisy, blurry, soft, deformed, ugly") + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + + submit = gr.Button("Submit", variant="primary") + + submit.click( + fn=remove_tips, + outputs=usage_tips, + ).then( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=generate_image, + inputs=[prompt, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style], + outputs=[gallery, gallery2, gallery3, gallery4] + ) + demo.launch(server_name='0.0.0.0',server_port=7861, debug=True) + +def parse_args(): + parser = argparse.ArgumentParser('', add_help=False) + parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str) + parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str) + parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str) + parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str) + parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str) + parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str) + parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str) + parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str) + parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str) + parser.add_argument('--prompt', 
default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str) + parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str) + parser.add_argument('--seed', default=22, type=int) + parser.add_argument('--suffix', default='', type=str) + parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str) + return parser.parse_args() + +if __name__ == '__main__': + args = parse_args() + + prompts = [args.prompt]*2 + prompts_tmp = copy.deepcopy(prompts) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + main(device, args.segment_type) \ No newline at end of file diff --git a/gradio_demo/app_generateOne.py b/gradio_demo/app_generateOne.py new file mode 100644 index 0000000000000000000000000000000000000000..f716e314740ea0356d85021a261d5a1ee697a218 --- /dev/null +++ b/gradio_demo/app_generateOne.py @@ -0,0 +1,529 @@ +import sys +sys.path.append('./') +import gradio as gr +import random +import numpy as np +from gradio_demo.character_template import character_man, lorapath_man +from gradio_demo.character_template import character_woman, lorapath_woman +from gradio_demo.character_template import styles, lorapath_styles +import torch +import os +from typing import Tuple, List +import copy +import argparse +from diffusers.utils import load_image +import cv2 +from PIL import Image +from transformers import DPTFeatureExtractor, DPTForDepthEstimation +from controlnet_aux import OpenposeDetector +from controlnet_aux.open_pose.body import Body + +try: + from inference.models import YOLOWorld + from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor + from src.efficientvit.sam_model_zoo import create_sam_model + import supervision as sv +except: + print("YoloWorld can not be load") + +try: + from groundingdino.models import build_model + from groundingdino.util import box_ops + from groundingdino.util.slconfig import SLConfig + from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap + from groundingdino.util.inference import annotate, predict + from segment_anything import build_sam, SamPredictor + import groundingdino.datasets.transforms as T +except: + print("groundingdino can not be load") + +from src.pipelines.lora_pipeline import LoraMultiConceptPipeline +from src.prompt_attention.p2p_attention import AttentionReplace +from diffusers import ControlNetModel, StableDiffusionXLPipeline +from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward + +CHARACTER_MAN_NAMES = list(character_man.keys()) +CHARACTER_WOMAN_NAMES = list(character_woman.keys()) +STYLE_NAMES = list(styles.keys()) +MAX_SEED = np.iinfo(np.int32).max + +### Description +title = r""" +
OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models
+""" + +description = r""" +Official 🤗 Gradio demo for OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models.
+ +How to use:
+1. Select two characters. +2. Enter a text prompt as done in normal text-to-image models. +3. Click the Submit button to start customizing. +4. Enjoy the generated image😊! +""" + +article = r""" +--- +📝 **Citation** +
+If our work is helpful for your research or applications, please cite us via: +```bibtex +@article{, +title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models}, +author={}, +journal={}, +year={} +} +``` +""" + +tips = r""" +### Usage tips of OMG +1. Input text prompts to describe a man and a woman +""" + +css = ''' +.gradio-container {width: 85% !important} +''' + +def sample_image(pipe, + input_prompt, + input_neg_prompt=None, + generator=None, + concept_models=None, + num_inference_steps=50, + guidance_scale=7.5, + controller=None, + stage=None, + region_masks=None, + lora_list = None, + styleL=None, + **extra_kargs +): + + spatial_condition = extra_kargs.pop('spatial_condition') + if spatial_condition is not None: + spatial_condition_input = [spatial_condition] * len(input_prompt) + else: + spatial_condition_input = None + + images = pipe( + prompt=input_prompt, + concept_models=concept_models, + negative_prompt=input_neg_prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + cross_attention_kwargs={"scale": 0.8}, + controller=controller, + stage=stage, + region_masks=region_masks, + lora_list=lora_list, + styleL=styleL, + image=spatial_condition_input, + **extra_kargs).images + + return images + +def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]: + image = np.asarray(image_source) + return image + +def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]: + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image = np.asarray(image_source) + image_transformed, _ = transform(image_source, None) + return image, image_transformed + +def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5): + if segmentType=='GroundingDINO': + image_source, image = load_image_dino(image) + boxes, logits, phrases = predict( + model=segmentmodel, + image=image, + caption=TEXT_PROMPT, + box_threshold=0.3, + text_threshold=0.25 + ) + sam.set_image(image_source) + H, W, _ = image_source.shape + boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H]) + + transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda() + masks, _, _ = sam.predict_torch( + point_coords=None, + point_labels=None, + boxes=transformed_boxes, + multimask_output=False, + ) + masks=masks[0].squeeze(0) + else: + image_source = load_image_yoloworld(image) + segmentmodel.set_classes([TEXT_PROMPT]) + results = segmentmodel.infer(image_source, confidence=confidence) + detections = sv.Detections.from_inference(results).with_nms( + class_agnostic=True, threshold=threshold + ) + masks = None + if len(detections) != 0: + print(TEXT_PROMPT + " detected!") + sam.set_image(image_source, image_format="RGB") + masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False) + masks = torch.from_numpy(masks.squeeze()) + + return masks + +def prepare_text(prompt, region_prompts): + ''' + Args: + prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text] + Returns: + full_prompt: subject1, attribute1 and subject2, attribute2, global text + context_prompt: subject1 and subject2, global text + entity_collection: [(subject1, attribute1), Location1] + ''' + region_collection = [] + + regions = region_prompts.split('|') + + for region in regions: + if region == '': + break + 
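+        # A region entry has the form "[prompt]-*-[negative prompt]"; split it on
+        # "-*-" and drop the brackets before storing the (prompt, negative) pair.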
prompt_region, neg_prompt_region = region.split('-*-') + prompt_region = prompt_region.replace('[', '').replace(']', '') + neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '') + + region_collection.append((prompt_region, neg_prompt_region)) + return (prompt, region_collection) + + +def build_model_sd(pretrained_model, controlnet_path, device, prompts): + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device) + pipe = LoraMultiConceptPipeline.from_pretrained( + pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device) + controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32) + revise_regionally_controlnet_forward(pipe.unet, controller) + pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16, + variant="fp16").to(device) + return pipe, controller, pipe_concept + +def build_model_lora(pipe_concept, lora_paths, style_path, condition, args): + pipe_list = [] + if condition == "Human pose": + controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + elif condition == "Canny Edge": + controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + elif condition == "Depth": + controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device) + pipe_concept.controlnet = controlnet + + if style_path is not None and os.path.exists(style_path): + pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style') + + for lora_path in lora_paths.split('|'): + adapter_name = lora_path.split('/')[-1].split('.')[0] + pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name) + pipe_concept.enable_xformers_memory_efficient_attention() + pipe_list.append(adapter_name) + return pipe_list + +def build_yolo_segment_model(sam_path, device): + yolo_world = YOLOWorld(model_id="yolo_world/l") + sam = EfficientViTSamPredictor( + create_sam_model(name="xl1", weight_url=sam_path).to(device).eval() + ) + return yolo_world, sam + +def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'): + args = SLConfig.fromfile(ckpt_config_filename) + model = build_model(args) + args.device = device + + checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu') + log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) + print("Model loaded from {} \n => {}".format(filename, log)) + _ = model.eval() + return model + +def build_dino_segment_model(ckpt_repo_id, sam_checkpoint): + ckpt_filenmae = "groundingdino_swinb_cogcoor.pth" + ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py") + groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename) + sam = build_sam(checkpoint=sam_checkpoint) + sam.cuda() + sam_predictor = SamPredictor(sam) + return groundingdino_model, sam_predictor + + + +def main(device, segment_type): + pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp) + + if segment_type == 'GroundingDINO': + detect_model, sam = 
build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint) + else: + detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device) + + resolution_list = ["1440*728", + "1344*768", + "1216*832", + "1152*896", + "1024*1024", + "896*1152", + "832*1216", + "768*1344", + "728*1440"] + + condition_list = ["None", + "Human pose", + "Canny Edge", + "Depth"] + + depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda") + feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint) + body_model = Body(args.pose_detector_checkpoint) + openpose = OpenposeDetector(body_model) + + def remove_tips(): + return gr.update(visible=False) + + def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + def get_humanpose(img): + openpose_image = openpose(img) + return openpose_image + + def get_cannyedge(image): + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = Image.fromarray(image) + return canny_image + + def get_depth(image): + image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + with torch.no_grad(), torch.autocast("cuda"): + depth_map = depth_estimator(image).predicted_depth + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(1024, 1024), + mode="bicubic", + align_corners=False, + ) + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = (depth_map - depth_min) / (depth_max - depth_min) + image = torch.cat([depth_map] * 3, dim=1) + image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + return image + + def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style): + try: + path1 = lorapath_man[man] + path2 = lorapath_woman[woman] + pipe_concept.unload_lora_weights() + pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args) + + if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]): + styleL = True + else: + styleL = False + + input_list = [prompt1] + condition_list = [condition_img1] + output_list = [] + + width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1]) + + kwargs = { + 'height': height, + 'width': width, + } + + for prompt, condition_img in zip(input_list, condition_list): + if prompt!='': + input_prompt = [] + p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.' 
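+                    # Apply the optional style prefix to the photo template, then build the
+                    # global prompt pair and the per-character (regional) prompts with
+                    # their negative prompts.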
+ if styleL: + p = styles[style] + p + input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)]) + input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])]) + + if condition == 'Human pose' and condition_img is not None: + spatial_condition = get_humanpose(condition_img).resize((width, height)) + elif condition == 'Canny Edge' and condition_img is not None: + spatial_condition = get_cannyedge(condition_img).resize((width, height)) + elif condition == 'Depth' and condition_img is not None: + spatial_condition = get_depth(condition_img).resize((width, height)) + else: + spatial_condition = None + + kwargs['spatial_condition'] = spatial_condition + controller.reset() + image = sample_image( + pipe, + input_prompt=input_prompt, + concept_models=pipe_concept, + input_neg_prompt=[negative_prompt] * len(input_prompt), + generator=torch.Generator(device).manual_seed(seed), + controller=controller, + stage=1, + lora_list=pipe_list, + styleL=styleL, + **kwargs) + + controller.reset() + if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]: + mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15, + threshold=0.5) + else: + mask1 = None + + if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]: + mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15, + threshold=0.5) + else: + mask2 = None + + if mask1 is None and mask2 is None: + output_list.append(image[1]) + else: + image = sample_image( + pipe, + input_prompt=input_prompt, + concept_models=pipe_concept, + input_neg_prompt=[negative_prompt] * len(input_prompt), + generator=torch.Generator(device).manual_seed(seed), + controller=controller, + stage=2, + region_masks=[mask1, mask2], + lora_list=pipe_list, + styleL=styleL, + **kwargs) + output_list.append(image[1]) + else: + output_list.append(None) + output_list.append(spatial_condition) + return output_list + except: + print("error") + return + + def get_local_value_man(input): + return character_man[input][0] + + def get_local_value_woman(input): + return character_woman[input][0] + + + with gr.Blocks(css=css) as demo: + # description + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + gallery = gr.Image(label="Generated Images", height=512, width=512) + gen_condition = gr.Image(label="Spatial Condition", height=512, width=512) + usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False) + + with gr.Row(): + condition_img1 = gr.Image(label="Input an RGB image for condition", height=128, width=128) + + # character choose + with gr.Row(): + man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)") + woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)") + resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024") + condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None") + style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None") + + with gr.Row(): + local_prompt1 = gr.Textbox(label="Character1_prompt", + info="Describe the Character 1, this prompt should include the identifier of character 1", + value="Close-up photo of the Harry Potter, 35mm photograph, 
film, professional, 4k, highly detailed.") + local_prompt2 = gr.Textbox(label="Character2_prompt", + info="Describe the Character 2, this prompt should include the identifier of character2", + value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.") + + man.change(get_local_value_man, man, local_prompt1) + woman.change(get_local_value_woman, woman, local_prompt2) + + # prompt + with gr.Column(): + prompt = gr.Textbox(label="Prompt 1", + info="Give a simple prompt to describe the first image content", + placeholder="Required", + value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling") + + + with gr.Accordion(open=False, label="Advanced Options"): + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=42, + ) + negative_prompt = gr.Textbox(label="Negative Prompt", + placeholder="noisy, blurry, soft, deformed, ugly", + value="noisy, blurry, soft, deformed, ugly") + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + + submit = gr.Button("Submit", variant="primary") + + submit.click( + fn=remove_tips, + outputs=usage_tips, + ).then( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=generate_image, + inputs=[prompt, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style], + outputs=[gallery, gen_condition] + ) + demo.launch(server_name='0.0.0.0',server_port=7861, debug=True) + +def parse_args(): + parser = argparse.ArgumentParser('', add_help=False) + parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str) + parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str) + parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str) + parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str) + parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str) + parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str) + parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str) + parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str) + parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str) + parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str) + parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str) + parser.add_argument('--seed', default=22, type=int) + parser.add_argument('--suffix', default='', type=str) + parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str) + return parser.parse_args() + +if __name__ == '__main__': + args = parse_args() + + prompts = [args.prompt]*2 + prompts_tmp = copy.deepcopy(prompts) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + main(device, args.segment_type) \ No newline at end of file diff --git 
a/gradio_demo/character_template.py b/gradio_demo/character_template.py new file mode 100644 index 0000000000000000000000000000000000000000..07bb4873a325f1fe31eee8fa383298dbb3585623 --- /dev/null +++ b/gradio_demo/character_template.py @@ -0,0 +1,62 @@ +character_list_man = [ + { + "name": "Harry Potter (identifier: Harry Potter)", + "prompt": "Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/Harry_Potter.safetensors", + }, + { + "name": "Chris Evans (identifier: Chris Evans)", + "prompt": "Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/chris-evans.safetensors", + }, + { + "name": "Jordan Torres (identifier: jordan_torres)", + "prompt": "Close-up photo of the jordan_torres man, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/jordan_torres_v2_xl.safetensors", + }, +] + +character_list_woman = [ + { + "name": "Hermione Granger (identifier: Hermione Granger)", + "prompt": "Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/Hermione_Granger.safetensors", + }, + { + "name": "Taylor Swift (identifier: TaylorSwift)", + "prompt": "Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/TaylorSwiftSDXL.safetensors", + }, + { + "name": "Keira Knightley (identifier: ohwx woman)", + "prompt": "Close-up photo of the ohwx woman, 35mm photograph, film, professional, 4k, highly detailed.", + "negative_prompt": "noisy, blurry, soft, deformed, ugly", + "path": "./checkpoint/lora/keira_lora_sdxl_v1-000008.safetensors", + }, +] + +style_list = [ + { + "name": "None", + "prompt": "", + "path": "", + }, + { + "name": "Anime sketch style", + "prompt": "Pencil_Sketch:1.2, messy lines, greyscale, traditional media, sketch, ", + "path": "./checkpoint/style/Anime_Sketch_SDXL.safetensors", + } +] + +character_man = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_man} +character_woman = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_woman} +styles = {k["name"]: (k["prompt"]) for k in style_list} + +lorapath_man = {k["name"]: (k["path"]) for k in character_list_man} +lorapath_woman = {k["name"]: (k["path"]) for k in character_list_woman} +lorapath_styles = {k["name"]: (k["path"]) for k in style_list} \ No newline at end of file diff --git a/src/efficientvit/__init__.py b/src/efficientvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/efficientvit/apps/__init__.py b/src/efficientvit/apps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/efficientvit/apps/data_provider/__init__.py b/src/efficientvit/apps/data_provider/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9a5dfa34097fdf24730a203a9f24c5c4ac0a74 --- /dev/null +++ b/src/efficientvit/apps/data_provider/__init__.py @@ -0,0 +1,7 @@ +# EfficientViT: Multi-Scale Linear Attention 
for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .augment import * +from .base import * +from .random_resolution import * diff --git a/src/efficientvit/apps/data_provider/augment/__init__.py b/src/efficientvit/apps/data_provider/augment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ea4d65f7f5a471cc433fbd68a58d4853b217d2 --- /dev/null +++ b/src/efficientvit/apps/data_provider/augment/__init__.py @@ -0,0 +1,6 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .bbox import * +from .color_aug import * diff --git a/src/efficientvit/apps/data_provider/augment/bbox.py b/src/efficientvit/apps/data_provider/augment/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..b9f089a3f70881313a5ce4308d1f74fbf1fa0c31 --- /dev/null +++ b/src/efficientvit/apps/data_provider/augment/bbox.py @@ -0,0 +1,30 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np + +__all__ = ["rand_bbox"] + + +def rand_bbox( + h: int, + w: int, + lam: float, + rand_func: callable = np.random.uniform, +) -> tuple[int, int, int, int]: + """randomly sample bbox, used in cutmix""" + cut_rat = np.sqrt(1.0 - lam) + cut_w = w * cut_rat + cut_h = h * cut_rat + + # uniform + cx = rand_func(0, w) + cy = rand_func(0, h) + + bbx1 = int(np.clip(cx - cut_w / 2, 0, w)) + bby1 = int(np.clip(cy - cut_h / 2, 0, h)) + bbx2 = int(np.clip(cx + cut_w / 2, 0, w)) + bby2 = int(np.clip(cy + cut_h / 2, 0, h)) + + return bbx1, bby1, bbx2, bby2 diff --git a/src/efficientvit/apps/data_provider/augment/color_aug.py b/src/efficientvit/apps/data_provider/augment/color_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e1dcc6998374738c300414b06e4fdb2ed8af95 --- /dev/null +++ b/src/efficientvit/apps/data_provider/augment/color_aug.py @@ -0,0 +1,84 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torchvision.transforms as transforms +from PIL import Image +from timm.data.auto_augment import rand_augment_transform + +__all__ = ["ColorAug", "RandAug"] + + +class ImageAug: + def aug_image(self, image: Image.Image) -> Image.Image: + raise NotImplementedError + + def __call__( + self, feed_dict: dict or np.ndarray or Image.Image + ) -> dict or np.ndarray or Image.Image: + if isinstance(feed_dict, dict): + output_dict = feed_dict + image = feed_dict[self.key] + else: + output_dict = None + image = feed_dict + is_ndarray = isinstance(image, np.ndarray) + if is_ndarray: + image = Image.fromarray(image) + + image = self.aug_image(image) + + if is_ndarray: + image = np.array(image) + + if output_dict is None: + return image + else: + output_dict[self.key] = image + return output_dict + + +class ColorAug(transforms.ColorJitter, ImageAug): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"): + super().__init__( + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue, + ) + self.key = key + + def aug_image(self, image: Image.Image) -> 
Image.Image: + return transforms.ColorJitter.forward(self, image) + + def forward( + self, feed_dict: dict or np.ndarray or Image.Image + ) -> dict or np.ndarray or Image.Image: + return ImageAug.__call__(self, feed_dict) + + +class RandAug(ImageAug): + def __init__( + self, config: dict[str, any], mean: tuple[float, float, float], key="data" + ): + n = config.get("n", 2) + m = config.get("m", 9) + mstd = config.get("mstd", 1.0) + inc = config.get("inc", 1) + tpct = config.get("tpct", 0.45) + config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}" + + aa_params = dict( + translate_pct=tpct, + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + interpolation=Image.BICUBIC, + ) + self.aug_op = rand_augment_transform(config_str, aa_params) + self.key = key + + def aug_image(self, image: Image.Image) -> Image.Image: + return self.aug_op(image) + + def __repr__(self): + return self.aug_op.__repr__() diff --git a/src/efficientvit/apps/data_provider/base.py b/src/efficientvit/apps/data_provider/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f57679f681a6c95c48bd66216e48241c391f209 --- /dev/null +++ b/src/efficientvit/apps/data_provider/base.py @@ -0,0 +1,223 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +import warnings + +import torch.utils.data +from torch.utils.data.distributed import DistributedSampler + +from src.efficientvit.apps.data_provider.random_resolution import RRSController +from src.efficientvit.models.utils import val2tuple + +__all__ = ["parse_image_size", "random_drop_data", "DataProvider"] + + +def parse_image_size(size: int or str) -> tuple[int, int]: + if isinstance(size, str): + size = [int(val) for val in size.split("-")] + return size[0], size[1] + else: + return val2tuple(size, 2) + + +def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)): + g = torch.Generator() + g.manual_seed(seed) # set random seed before sampling validation set + rand_indexes = torch.randperm(len(dataset), generator=g).tolist() + + dropped_indexes = rand_indexes[:drop_size] + remaining_indexes = rand_indexes[drop_size:] + + dropped_dataset = copy.deepcopy(dataset) + for key in keys: + setattr( + dropped_dataset, + key, + [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes], + ) + setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes]) + return dataset, dropped_dataset + + +class DataProvider: + data_keys = ("samples",) + mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} + SUB_SEED = 937162211 # random seed for sampling subset + VALID_SEED = 2147483647 # random seed for the validation set + + name: str + + def __init__( + self, + train_batch_size: int, + test_batch_size: int or None, + valid_size: int or float or None, + n_worker: int, + image_size: int or list[int] or str or list[str], + num_replicas: int or None = None, + rank: int or None = None, + train_ratio: float or None = None, + drop_last: bool = False, + ): + warnings.filterwarnings("ignore") + super().__init__() + + # batch_size & valid_size + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size or self.train_batch_size + self.valid_size = valid_size + + # image size + if isinstance(image_size, list): + self.image_size = [parse_image_size(size) for size in image_size] + self.image_size.sort() # e.g., 160 -> 224 + 
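+            # Hand the sorted resolution list to the random-resolution controller and
+            # use the largest resolution as the default active image size.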
RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size) + self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1] + else: + self.image_size = parse_image_size(image_size) + RRSController.IMAGE_SIZE_LIST = [self.image_size] + self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size + + # distributed configs + self.num_replicas = num_replicas + self.rank = rank + + # build datasets + train_dataset, val_dataset, test_dataset = self.build_datasets() + + if train_ratio is not None and train_ratio < 1.0: + assert 0 < train_ratio < 1 + _, train_dataset = random_drop_data( + train_dataset, + int(train_ratio * len(train_dataset)), + self.SUB_SEED, + self.data_keys, + ) + + # build data loader + self.train = self.build_dataloader( + train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True + ) + self.valid = self.build_dataloader( + val_dataset, test_batch_size, n_worker, drop_last=False, train=False + ) + self.test = self.build_dataloader( + test_dataset, test_batch_size, n_worker, drop_last=False, train=False + ) + if self.valid is None: + self.valid = self.test + self.sub_train = None + + @property + def data_shape(self) -> tuple[int, ...]: + return 3, self.active_image_size[0], self.active_image_size[1] + + def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any: + raise NotImplementedError + + def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any: + raise NotImplementedError + + def build_datasets(self) -> tuple[any, any, any]: + raise NotImplementedError + + def build_dataloader( + self, + dataset: any or None, + batch_size: int, + n_worker: int, + drop_last: bool, + train: bool, + ): + if dataset is None: + return None + if isinstance(self.image_size, list) and train: + from efficientvit.apps.data_provider.random_resolution._data_loader import \ + RRSDataLoader + + dataloader_class = RRSDataLoader + else: + dataloader_class = torch.utils.data.DataLoader + if self.num_replicas is None: + return dataloader_class( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=True, + drop_last=drop_last, + ) + else: + sampler = DistributedSampler(dataset, self.num_replicas, self.rank) + return dataloader_class( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=n_worker, + pin_memory=True, + drop_last=drop_last, + ) + + def set_epoch(self, epoch: int) -> None: + RRSController.set_epoch(epoch, len(self.train)) + if isinstance(self.train.sampler, DistributedSampler): + self.train.sampler.set_epoch(epoch) + + def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None: + self.active_image_size = val2tuple(new_size, 2) + new_transform = self.build_valid_transform(self.active_image_size) + # change the transform of the valid and test set + self.valid.dataset.transform = self.test.dataset.transform = new_transform + + def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]: + if self.valid_size is not None: + if 0 < self.valid_size < 1: + valid_size = int(self.valid_size * len(train_dataset)) + else: + assert self.valid_size >= 1 + valid_size = int(self.valid_size) + train_dataset, val_dataset = random_drop_data( + train_dataset, + valid_size, + self.VALID_SEED, + self.data_keys, + ) + val_dataset.transform = valid_transform + else: + val_dataset = None + return train_dataset, val_dataset + + def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any: + # used for resetting BN 
running statistics + if self.sub_train is None: + self.sub_train = {} + if self.active_image_size in self.sub_train: + return self.sub_train[self.active_image_size] + + # construct dataset and dataloader + train_dataset = copy.deepcopy(self.train.dataset) + if n_samples < len(train_dataset): + _, train_dataset = random_drop_data( + train_dataset, + n_samples, + self.SUB_SEED, + self.data_keys, + ) + RRSController.ACTIVE_SIZE = self.active_image_size + train_dataset.transform = self.build_train_transform( + image_size=self.active_image_size + ) + data_loader = self.build_dataloader( + train_dataset, batch_size, self.train.num_workers, True, False + ) + + # pre-fetch data + self.sub_train[self.active_image_size] = [ + data + for data in data_loader + for _ in range(max(1, n_samples // len(train_dataset))) + ] + + return self.sub_train[self.active_image_size] diff --git a/src/efficientvit/apps/data_provider/random_resolution/__init__.py b/src/efficientvit/apps/data_provider/random_resolution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b831fa9d3e933e76cf78120947143e8a19133ea2 --- /dev/null +++ b/src/efficientvit/apps/data_provider/random_resolution/__init__.py @@ -0,0 +1,7 @@ +"""Random resolution data loader compatible with multi-processing and distributed training. + +Replace Pytorch's DataLoader with RRSDataLoader to support random resolution +at the training time, resolution sampling is controlled by RRSController +""" + +from .controller import * diff --git a/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py b/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..df06e2b95a4468cbeecc6b56c6f97deb8c35ff2c --- /dev/null +++ b/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py @@ -0,0 +1,1598 @@ +r"""This file is based on torch/utils/data/data_loader.py + +Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence, + TypeVar, Union) + +import torch +import torch.distributed as dist +import torch.multiprocessing as multiprocessing +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import (BatchSampler, Dataset, IterableDataset, + IterDataPipe, MapDataPipe, RandomSampler, + Sampler, SequentialSampler, _utils) +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper) + +from ._data_worker import _worker_loop + +__all__ = ["RRSDataLoader"] + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[List[T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. 
Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + Used as sampler for :class:`~torch.utils.data.IterableDataset`. + + Args: + data_source (Dataset): dataset to sample from + """ + + def __init__(self): + super().__init__(None) + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + assert info is not None + total_workers = info.num_workers + datapipe = info.dataset + assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + datapipe, total_workers, global_worker_id + ) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class RRSDataLoader(Generic[T_co]): + r""" + Data loader. Combines a dataset and a sampler, and provides an iterable over + the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. 
``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + `base_seed` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise if value of num_workers>0 default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shutdown + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): the data loader will copy Tensors + into device pinned memory before returning them if pin_memory is set to true. + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. 
warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + """ + + dataset: Dataset[T_co] + batch_size: Optional[int] + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Union[Sampler, Iterable] + pin_memory_device: str + prefetch_factor: Optional[int] + _iterator: Optional["_BaseDataLoaderIter"] + __initialized = False + + def __init__( + self, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = None, + sampler: Union[Sampler, Iterable, None] = None, + batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "" + ): + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. 
This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "shuffle option, but got shuffle={}".format(shuffle) + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "sampler option, but got sampler={}".format(sampler) + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "batch_sampler option, but got batch_sampler={}".format( + batch_sampler + ) + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert 
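+ # With auto-batching enabled, default_collate stacks the samples of each batch + # into batched tensors; with batch_size=None, default_convert merely converts + # NumPy arrays to tensors and leaves samples un-batched.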
+ + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + ( + "multiprocessing_context option " + "should specify a valid start method in {!r}, but got " + "multiprocessing_context={!r}" + ).format(valid_start_methods, multiprocessing_context) + ) + multiprocessing_context = multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise TypeError( + ( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + "multiprocessing_context={}" + ).format(multiprocessing_context) + ) + else: + raise ValueError( + ( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + "num_workers={}" + ).format(self.num_workers) + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + "{} attribute should not be set after {} is " + "initialized".format(attr, self.__class__.__name__) + ) + + super().__setattr__(attr, val) + + # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up + # since '_BaseDataLoaderIter' references 'DataLoader'. + def __iter__(self) -> "_BaseDataLoaderIter": + # When using a single worker the returned iterator should be + # created everytime to avoid reseting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. 
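+ # With auto-collation the batch_sampler yields one list of dataset indices per + # batch; without it, the plain sampler yields individual indices, so __len__ + # below counts batches in the first case and samples in the second.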
+ if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] + if ( + self.batch_size is not None + ): # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + + suggested_max_worker_msg = ( + ( + ( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create." 
+ ).format( + num_worker_suggest, + ( + "" + if cpuset_checked + else " (`cpuset` is not taken into account)" + ), + ) + ) + if num_worker_suggest is not None + else ( + "DataLoader is not able to compute a suggested max number of worker in current system." + ) + ) + + warn_msg = ( + "This DataLoader will create {} worker processes in total. {} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary." + ).format(num_worker_created, suggested_max_worker_msg) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satify mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + + +class _BaseDataLoaderIter: + def __init__(self, loader: RRSDataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + # for other backends, pin_memory_device need to set. if not set + # default behaviour is CUDA device. if pin_memory_device is selected + # and pin_memory is not set, the default behaviour false. 
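+ # In short: with no pin_memory_device given, pinning is enabled only when + # pin_memory=True and CUDA is available; with an explicit device string, that + # device is used for pinning and a warning is emitted if pin_memory is False.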
+ if len(loader.pin_memory_device) == 0: + self._pin_memory = loader.pin_memory and torch.cuda.is_available() + self._pin_memory_device = None + else: + if not loader.pin_memory: + warn_msg = ( + "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used" + "please set pin_memory to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = ( + torch.empty((), dtype=torch.int64) + .random_(generator=loader.generator) + .item() + ) + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = "enumerate(DataLoader)#{}.__next__".format( + self.__class__.__name__ + ) + + def __iter__(self) -> "_BaseDataLoaderIter": + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. " + ).format( + self._dataset, self._IterableDataset_len_called, self._num_yielded + ) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # across multiple threads for HOGWILD. 
+ # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super().__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may alreay be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. 
Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. 
All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. 
+ # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `index_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. 
+ # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timeing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. 
+ # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + if self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() # type: ignore[attr-defined] + else: + current_device = torch.cuda.current_device() # choose cuda for default + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False): + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. 
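+ # Illustrative (hypothetical) state with 2 workers and prefetch_factor=2, right + # after the priming loop below: _send_idx == 4, _rcvd_idx == 0 and _task_info is + # {0: (0,), 1: (1,), 2: (0,), 3: (1,)}; an entry grows to (worker_id, data) once + # its batch arrives out of order.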
+ self._workers_status = [True for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError( + "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str) + ) from e + if isinstance(e, queue.Empty): + return (False, None) + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. + # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. 
When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. + # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. 
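+ # Three cases follow: an explicit timeout raises after a single failed attempt, + # pin_memory polls only while the pinning thread is still alive, and otherwise + # we keep retrying until a batch arrives or a worker failure is raised.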
+ if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + "DataLoader timed out after {} seconds".format(self._timeout) + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return self._process_data(data) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + # store out-of-order samples + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return self._process_data(data) + + def _try_put_index(self): + assert self._tasks_outstanding < self._prefetch_factor * self._num_workers + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. 
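+ # Sanity check below: a worker may only be marked unavailable while it is still + # flagged as active, except when persistent workers are being shut down.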
+ + assert self._workers_status[worker_id] or ( + self._persistent_workers and shutdown + ) + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. + q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joinning is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + + self._workers_status[worker_id] = False + + assert self._workers_done_event.is_set() == shutdown + + def _shutdown_workers(self): + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, "_pin_memory_thread"): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + # If we are using workers_status with persistent_workers + # we have to shut it down because the worker is paused + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. 
+ # + if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() + + # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` + @staticmethod + def _clean_up_worker(w): + try: + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + finally: + if w.is_alive(): + w.terminate() + + def __del__(self): + self._shutdown_workers() diff --git a/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py b/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..be41a3ec24d4d934a06ffa82eac24e432d27d619 --- /dev/null +++ b/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py @@ -0,0 +1,377 @@ +r""""This file is based on torch/utils/data/_utils/worker.py + +Contains definitions of the methods used by the _BaseDataLoaderIter workers. +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import os +import queue +import random +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch +from torch._utils import ExceptionWrapper +from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS, + MP_STATUS_CHECK_INTERVAL, signal_handling) + +if TYPE_CHECKING: + from torch.utils.data import Dataset + +from .controller import RRSController + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import BOOL, DWORD, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. 
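+    # WaitForSingleObject(handle, 0) returns WAIT_OBJECT_0 (0) once the
+    # handle is signaled, which for a process handle means the process has
+    # exited; ManagerWatchdog.is_alive() below relies on exactly this check.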
+ class ManagerWatchdog: + def __init__(self): + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess( + SYNCHRONIZE, 0, self.manager_pid + ) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = ( + self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + ) + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError( + "Cannot assign attributes to {} objects".format(self.__class__.__name__) + ) + return super().__setattr__(key, val) + + def __repr__(self): + items = [] + for k in self.__keys: + items.append("{}={}".format(k, getattr(self, k))) + return "{}({})".format(self.__class__.__name__, ", ".join(items)) + + +def get_worker_info() -> Optional[WorkerInfo]: + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. 
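+
+    Example (illustrative sketch of the sharding pattern described in the
+    note above; ``set_shard`` is a hypothetical method on a user-defined
+    sharded dataset, not part of this module)::
+
+        def worker_init_fn(worker_id):
+            info = get_worker_info()
+            # Restrict this worker's dataset replica to its own shard.
+            info.dataset.set_shard(shard=info.id, num_shards=info.num_workers)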
+ """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: Optional[int] = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. 
+ for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + torch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + from torch.utils.data import IterDataPipe + from torch.utils.data.graph_settings import apply_random_seed + + shared_rng = torch.Generator() + if isinstance(dataset, IterDataPipe): + assert shared_seed is not None + shared_rng.manual_seed(shared_seed) + dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset + ) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + except Exception: + init_exception = ExceptionWrapper( + where="in DataLoader worker process {}".format(worker_id) + ) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. 
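+        # Main loop protocol (as implemented below): pull a request from
+        # `index_queue` with a timeout; a `_ResumeIteration` re-seeds and
+        # recreates the fetcher, `None` is the final exit signal, and an
+        # `(idx, index)` pair is fetched and the result (or a wrapped
+        # exception) is put on `data_queue`.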
+ iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + if isinstance(dataset, IterDataPipe): + assert r.seed is not None + shared_rng.manual_seed(r.seed) + dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + continue + elif r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + """ Added """ + RRSController.sample_resolution(batch_id=idx) + """ Added """ + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) + except Exception as e: + if ( + isinstance(e, StopIteration) + and dataset_kind == _DatasetKind.Iterable + ): + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper( + where="in DataLoader worker process {}".format(worker_id) + ) + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. 
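+            # so it is safe to swallow it here and fall through to the
+            # queue cleanup below.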
+ pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/src/efficientvit/apps/data_provider/random_resolution/controller.py b/src/efficientvit/apps/data_provider/random_resolution/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..b62134a5e829c8322806366d98760f1d01c30678 --- /dev/null +++ b/src/efficientvit/apps/data_provider/random_resolution/controller.py @@ -0,0 +1,94 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy + +import torch +import torchvision.transforms as transforms +import torchvision.transforms.functional as F + +from src.efficientvit.models.utils import torch_random_choices + +__all__ = [ + "RRSController", + "get_interpolate", + "MyRandomResizedCrop", +] + + +class RRSController: + ACTIVE_SIZE = (224, 224) + IMAGE_SIZE_LIST = [(224, 224)] + + CHOICE_LIST = None + + @staticmethod + def get_candidates() -> list[tuple[int, int]]: + return copy.deepcopy(RRSController.IMAGE_SIZE_LIST) + + @staticmethod + def sample_resolution(batch_id: int) -> None: + RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id] + + @staticmethod + def set_epoch(epoch: int, batch_per_epoch: int) -> None: + g = torch.Generator() + g.manual_seed(epoch) + RRSController.CHOICE_LIST = torch_random_choices( + RRSController.get_candidates(), + g, + batch_per_epoch, + ) + + +def get_interpolate(name: str) -> F.InterpolationMode: + mapping = { + "nearest": F.InterpolationMode.NEAREST, + "bilinear": F.InterpolationMode.BILINEAR, + "bicubic": F.InterpolationMode.BICUBIC, + "box": F.InterpolationMode.BOX, + "hamming": F.InterpolationMode.HAMMING, + "lanczos": F.InterpolationMode.LANCZOS, + } + if name in mapping: + return mapping[name] + elif name == "random": + return torch_random_choices( + [ + F.InterpolationMode.NEAREST, + F.InterpolationMode.BILINEAR, + F.InterpolationMode.BICUBIC, + F.InterpolationMode.BOX, + F.InterpolationMode.HAMMING, + F.InterpolationMode.LANCZOS, + ], + ) + else: + raise NotImplementedError + + +class MyRandomResizedCrop(transforms.RandomResizedCrop): + def __init__( + self, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation: str = "random", + ): + super(MyRandomResizedCrop, self).__init__(224, scale, ratio) + self.interpolation = interpolation + + def forward(self, img: torch.Tensor) -> torch.Tensor: + i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio)) + target_size = RRSController.ACTIVE_SIZE + return F.resized_crop( + img, i, j, h, w, list(target_size), get_interpolate(self.interpolation) + ) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + format_string += f"(\n\tsize={RRSController.get_candidates()},\n" + format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n" + format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n" + format_string += f"\tinterpolation={self.interpolation})" + return format_string diff --git a/src/efficientvit/apps/setup.py b/src/efficientvit/apps/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd36c4205ad64f4ccb8dfc1efd9e5ebb0ce8e3d --- /dev/null +++ b/src/efficientvit/apps/setup.py @@ -0,0 +1,141 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 
2023 + +import os +import time +from copy import deepcopy + +import torch.backends.cudnn +import torch.distributed +import torch.nn as nn + +from src.efficientvit.apps.data_provider import DataProvider +from src.efficientvit.apps.trainer.run_config import RunConfig +from src.efficientvit.apps.utils import (dist_init, dump_config, + get_dist_local_rank, get_dist_rank, + get_dist_size, init_modules, is_master, + load_config, partial_update_config, + zero_last_gamma) +from src.efficientvit.models.utils import (build_kwargs_from_config, + load_state_dict_from_file) + +__all__ = [ + "save_exp_config", + "setup_dist_env", + "setup_seed", + "setup_exp_config", + "setup_data_provider", + "setup_run_config", + "init_model", +] + + +def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None: + if not is_master(): + return + dump_config(exp_config, os.path.join(path, name)) + + +def setup_dist_env(gpu: str or None = None) -> None: + if gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = gpu + if not torch.distributed.is_initialized(): + dist_init() + torch.backends.cudnn.benchmark = True + torch.cuda.set_device(get_dist_local_rank()) + + +def setup_seed(manual_seed: int, resume: bool) -> None: + if resume: + manual_seed = int(time.time()) + manual_seed = get_dist_rank() + manual_seed + torch.manual_seed(manual_seed) + torch.cuda.manual_seed_all(manual_seed) + + +def setup_exp_config( + config_path: str, recursive=True, opt_args: dict or None = None +) -> dict: + # load config + if not os.path.isfile(config_path): + raise ValueError(config_path) + + fpaths = [config_path] + if recursive: + extension = os.path.splitext(config_path)[1] + while os.path.dirname(config_path) != config_path: + config_path = os.path.dirname(config_path) + fpath = os.path.join(config_path, "default" + extension) + if os.path.isfile(fpath): + fpaths.append(fpath) + fpaths = fpaths[::-1] + + default_config = load_config(fpaths[0]) + exp_config = deepcopy(default_config) + for fpath in fpaths[1:]: + partial_update_config(exp_config, load_config(fpath)) + # update config via args + if opt_args is not None: + partial_update_config(exp_config, opt_args) + + return exp_config + + +def setup_data_provider( + exp_config: dict, + data_provider_classes: list[type[DataProvider]], + is_distributed: bool = True, +) -> DataProvider: + dp_config = exp_config["data_provider"] + dp_config["num_replicas"] = get_dist_size() if is_distributed else None + dp_config["rank"] = get_dist_rank() if is_distributed else None + dp_config["test_batch_size"] = ( + dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2 + ) + dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[ + "base_batch_size" + ] + + data_provider_lookup = { + provider.name: provider for provider in data_provider_classes + } + data_provider_class = data_provider_lookup[dp_config["dataset"]] + + data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class) + data_provider = data_provider_class(**data_provider_kwargs) + return data_provider + + +def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig: + exp_config["run_config"]["init_lr"] = ( + exp_config["run_config"]["base_lr"] * get_dist_size() + ) + + run_config = run_config_cls(**exp_config["run_config"]) + + return run_config + + +def init_model( + network: nn.Module, + init_from: str or None = None, + backbone_init_from: str or None = None, + rand_init="trunc_normal", + last_gamma=None, +) -> None: + # initialization + 
init_modules(network, init_type=rand_init) + # zero gamma of last bn in each block + if last_gamma is not None: + zero_last_gamma(network, last_gamma) + + # load weight + if init_from is not None and os.path.isfile(init_from): + network.load_state_dict(load_state_dict_from_file(init_from)) + print(f"Loaded init from {init_from}") + elif backbone_init_from is not None and os.path.isfile(backbone_init_from): + network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from)) + print(f"Loaded backbone init from {backbone_init_from}") + else: + print(f"Random init ({rand_init}) with last gamma {last_gamma}") diff --git a/src/efficientvit/apps/trainer/__init__.py b/src/efficientvit/apps/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9219c0c05c23e46926de0988c658b79b72388b --- /dev/null +++ b/src/efficientvit/apps/trainer/__init__.py @@ -0,0 +1,6 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .base import * +from .run_config import * diff --git a/src/efficientvit/apps/trainer/base.py b/src/efficientvit/apps/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4f190e460bdb11f55357b5aef069486ced98939b --- /dev/null +++ b/src/efficientvit/apps/trainer/base.py @@ -0,0 +1,297 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os + +import torch +import torch.nn as nn + +from src.efficientvit.apps.data_provider import DataProvider, parse_image_size +from src.efficientvit.apps.trainer.run_config import RunConfig +from src.efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank, + is_master) +from src.efficientvit.models.nn.norm import reset_bn +from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file + +__all__ = ["Trainer"] + + +class Trainer: + def __init__(self, path: str, model: nn.Module, data_provider: DataProvider): + self.path = os.path.realpath(os.path.expanduser(path)) + self.model = model.cuda() + self.data_provider = data_provider + + self.ema = None + + self.checkpoint_path = os.path.join(self.path, "checkpoint") + self.logs_path = os.path.join(self.path, "logs") + for path in [self.path, self.checkpoint_path, self.logs_path]: + os.makedirs(path, exist_ok=True) + + self.best_val = 0.0 + self.start_epoch = 0 + + @property + def network(self) -> nn.Module: + return self.model.module if is_parallel(self.model) else self.model + + @property + def eval_network(self) -> nn.Module: + if self.ema is None: + model = self.model + else: + model = self.ema.shadows + model = model.module if is_parallel(model) else model + return model + + def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None: + if is_master(): + fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode) + fout.write(log_str + "\n") + fout.flush() + fout.close() + if print_log: + print(log_str) + + def save_model( + self, + checkpoint=None, + only_state_dict=True, + epoch=0, + model_name=None, + ) -> None: + if is_master(): + if checkpoint is None: + if only_state_dict: + checkpoint = {"state_dict": self.network.state_dict()} + else: + checkpoint = { + "state_dict": self.network.state_dict(), + "epoch": epoch, + "best_val": self.best_val, + "optimizer": 
self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "ema": self.ema.state_dict() if self.ema is not None else None, + "scaler": self.scaler.state_dict() if self.fp16 else None, + } + + model_name = model_name or "checkpoint.pt" + + latest_fname = os.path.join(self.checkpoint_path, "latest.txt") + model_path = os.path.join(self.checkpoint_path, model_name) + with open(latest_fname, "w") as _fout: + _fout.write(model_path + "\n") + torch.save(checkpoint, model_path) + + def load_model(self, model_fname=None) -> None: + latest_fname = os.path.join(self.checkpoint_path, "latest.txt") + if model_fname is None and os.path.exists(latest_fname): + with open(latest_fname, "r") as fin: + model_fname = fin.readline() + if len(model_fname) > 0 and model_fname[-1] == "\n": + model_fname = model_fname[:-1] + try: + if model_fname is None: + model_fname = f"{self.checkpoint_path}/checkpoint.pt" + elif not os.path.exists(model_fname): + model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}" + if not os.path.exists(model_fname): + model_fname = f"{self.checkpoint_path}/checkpoint.pt" + print(f"=> loading checkpoint {model_fname}") + checkpoint = load_state_dict_from_file(model_fname, False) + except Exception: + self.write_log(f"fail to load checkpoint from {self.checkpoint_path}") + return + + # load checkpoint + self.network.load_state_dict(checkpoint["state_dict"], strict=False) + log = [] + if "epoch" in checkpoint: + self.start_epoch = checkpoint["epoch"] + 1 + self.run_config.update_global_step(self.start_epoch) + log.append(f"epoch={self.start_epoch - 1}") + if "best_val" in checkpoint: + self.best_val = checkpoint["best_val"] + log.append(f"best_val={self.best_val:.2f}") + if "optimizer" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + log.append("optimizer") + if "lr_scheduler" in checkpoint: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + log.append("lr_scheduler") + if "ema" in checkpoint and self.ema is not None: + self.ema.load_state_dict(checkpoint["ema"]) + log.append("ema") + if "scaler" in checkpoint and self.fp16: + self.scaler.load_state_dict(checkpoint["scaler"]) + log.append("scaler") + self.write_log("Loaded: " + ", ".join(log)) + + """ validate """ + + def reset_bn( + self, + network: nn.Module or None = None, + subset_size: int = 16000, + subset_batch_size: int = 100, + data_loader=None, + progress_bar=False, + ) -> None: + network = network or self.network + if data_loader is None: + data_loader = [] + for data in self.data_provider.build_sub_train_loader( + subset_size, subset_batch_size + ): + if isinstance(data, list): + data_loader.append(data[0]) + elif isinstance(data, dict): + data_loader.append(data["data"]) + elif isinstance(data, torch.Tensor): + data_loader.append(data) + else: + raise NotImplementedError + + network.eval() + reset_bn( + network, + data_loader, + sync=True, + progress_bar=progress_bar, + ) + + def _validate(self, model, data_loader, epoch) -> dict[str, any]: + raise NotImplementedError + + def validate( + self, model=None, data_loader=None, is_test=True, epoch=0 + ) -> dict[str, any]: + model = model or self.eval_network + if data_loader is None: + if is_test: + data_loader = self.data_provider.test + else: + data_loader = self.data_provider.valid + + model.eval() + return self._validate(model, data_loader, epoch) + + def multires_validate( + self, + model=None, + data_loader=None, + is_test=True, + epoch=0, + eval_image_size=None, + ) -> dict[str, dict[str, any]]: 
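+        # Evaluate at each resolution in `eval_image_size`: switch the data
+        # provider's active image size, optionally re-estimate BN statistics
+        # (`run_config.reset_bn`), then run the standard `validate` pass and
+        # collect the results keyed by resolution.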
+ eval_image_size = eval_image_size or self.run_config.eval_image_size + eval_image_size = eval_image_size or self.data_provider.image_size + model = model or self.eval_network + + if not isinstance(eval_image_size, list): + eval_image_size = [eval_image_size] + + output_dict = {} + for r in eval_image_size: + self.data_provider.assign_active_image_size(parse_image_size(r)) + if self.run_config.reset_bn: + self.reset_bn( + network=model, + subset_size=self.run_config.reset_bn_size, + subset_batch_size=self.run_config.reset_bn_batch_size, + progress_bar=True, + ) + output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch) + return output_dict + + """ training """ + + def prep_for_training( + self, run_config: RunConfig, ema_decay: float or None = None, fp16=False + ) -> None: + self.run_config = run_config + self.model = nn.parallel.DistributedDataParallel( + self.model.cuda(), + device_ids=[get_dist_local_rank()], + static_graph=True, + ) + + self.run_config.global_step = 0 + self.run_config.batch_per_epoch = len(self.data_provider.train) + assert self.run_config.batch_per_epoch > 0, "Training set is empty" + + # build optimizer + self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model) + + if ema_decay is not None: + self.ema = EMA(self.network, ema_decay) + + # fp16 + self.fp16 = fp16 + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + def sync_model(self): + print("Sync model") + self.save_model(model_name="sync.pt") + dist_barrier() + checkpoint = torch.load( + os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu" + ) + dist_barrier() + if is_master(): + os.remove(os.path.join(self.checkpoint_path, "sync.pt")) + dist_barrier() + + # load checkpoint + self.network.load_state_dict(checkpoint["state_dict"], strict=False) + if "optimizer" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if "lr_scheduler" in checkpoint: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + if "ema" in checkpoint and self.ema is not None: + self.ema.load_state_dict(checkpoint["ema"]) + if "scaler" in checkpoint and self.fp16: + self.scaler.load_state_dict(checkpoint["scaler"]) + + def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + for key in feed_dict: + if isinstance(feed_dict[key], torch.Tensor): + feed_dict[key] = feed_dict[key].cuda() + return feed_dict + + def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + raise NotImplementedError + + def after_step(self) -> None: + self.scaler.unscale_(self.optimizer) + # gradient clip + if self.run_config.grad_clip is not None: + torch.nn.utils.clip_grad_value_( + self.model.parameters(), self.run_config.grad_clip + ) + # update + self.scaler.step(self.optimizer) + self.scaler.update() + + self.lr_scheduler.step() + self.run_config.step() + # update ema + if self.ema is not None: + self.ema.step(self.network, self.run_config.global_step) + + def _train_one_epoch(self, epoch: int) -> dict[str, any]: + raise NotImplementedError + + def train_one_epoch(self, epoch: int) -> dict[str, any]: + self.model.train() + + self.data_provider.set_epoch(epoch) + + train_info_dict = self._train_one_epoch(epoch) + + return train_info_dict + + def train(self) -> None: + raise NotImplementedError diff --git a/src/efficientvit/apps/trainer/run_config.py b/src/efficientvit/apps/trainer/run_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6a233f3908aaf0cf7829a02c0073f348f14eed10 --- /dev/null +++ 
b/src/efficientvit/apps/trainer/run_config.py @@ -0,0 +1,121 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import json + +import numpy as np +import torch.nn as nn + +from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer + +__all__ = ["Scheduler", "RunConfig"] + + +class Scheduler: + PROGRESS = 0 + + +class RunConfig: + n_epochs: int + init_lr: float + warmup_epochs: int + warmup_lr: float + lr_schedule_name: str + lr_schedule_param: dict + optimizer_name: str + optimizer_params: dict + weight_decay: float + no_wd_keys: list + grad_clip: float # allow none to turn off grad clipping + reset_bn: bool + reset_bn_size: int + reset_bn_batch_size: int + eval_image_size: list # allow none to use image_size in data_provider + + @property + def none_allowed(self): + return ["grad_clip", "eval_image_size"] + + def __init__(self, **kwargs): # arguments must be passed as kwargs + for k, val in kwargs.items(): + setattr(self, k, val) + + # check that all relevant configs are there + annotations = {} + for clas in type(self).mro(): + if hasattr(clas, "__annotations__"): + annotations.update(clas.__annotations__) + for k, k_type in annotations.items(): + assert hasattr( + self, k + ), f"Key {k} with type {k_type} required for initialization." + attr = getattr(self, k) + if k in self.none_allowed: + k_type = (k_type, type(None)) + assert isinstance( + attr, k_type + ), f"Key {k} must be type {k_type}, provided={attr}." + + self.global_step = 0 + self.batch_per_epoch = 1 + + def build_optimizer(self, network: nn.Module) -> tuple[any, any]: + r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler""" + param_dict = {} + for name, param in network.named_parameters(): + if param.requires_grad: + opt_config = [self.weight_decay, self.init_lr] + if self.no_wd_keys is not None and len(self.no_wd_keys) > 0: + if np.any([key in name for key in self.no_wd_keys]): + opt_config[0] = 0 + opt_key = json.dumps(opt_config) + param_dict[opt_key] = param_dict.get(opt_key, []) + [param] + + net_params = [] + for opt_key, param_list in param_dict.items(): + wd, lr = json.loads(opt_key) + net_params.append({"params": param_list, "weight_decay": wd, "lr": lr}) + + optimizer = build_optimizer( + net_params, self.optimizer_name, self.optimizer_params, self.init_lr + ) + # build lr scheduler + if self.lr_schedule_name == "cosine": + decay_steps = [] + for epoch in self.lr_schedule_param.get("step", []): + decay_steps.append(epoch * self.batch_per_epoch) + decay_steps.append(self.n_epochs * self.batch_per_epoch) + decay_steps.sort() + lr_scheduler = CosineLRwithWarmup( + optimizer, + self.warmup_epochs * self.batch_per_epoch, + self.warmup_lr, + decay_steps, + ) + else: + raise NotImplementedError + return optimizer, lr_scheduler + + def update_global_step(self, epoch, batch_id=0) -> None: + self.global_step = epoch * self.batch_per_epoch + batch_id + Scheduler.PROGRESS = self.progress + + @property + def progress(self) -> float: + warmup_steps = self.warmup_epochs * self.batch_per_epoch + steps = max(0, self.global_step - warmup_steps) + return steps / (self.n_epochs * self.batch_per_epoch) + + def step(self) -> None: + self.global_step += 1 + Scheduler.PROGRESS = self.progress + + def get_remaining_epoch(self, epoch, post=True) -> int: + return self.n_epochs + self.warmup_epochs - epoch - int(post) + + def epoch_format(self, 
epoch: int) -> str: + epoch_format = f"%.{len(str(self.n_epochs))}d" + epoch_format = f"[{epoch_format}/{epoch_format}]" + epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs) + return epoch_format diff --git a/src/efficientvit/apps/utils/__init__.py b/src/efficientvit/apps/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c826a22544285746c588741f3f20fbe3802ccd50 --- /dev/null +++ b/src/efficientvit/apps/utils/__init__.py @@ -0,0 +1,12 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .dist import * +from .ema import * +from .export import * +from .init import * +from .lr import * +from .metric import * +from .misc import * +from .opt import * diff --git a/src/efficientvit/apps/utils/dist.py b/src/efficientvit/apps/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..9b55142678fa168d67dc33fff6907d3b8c87a485 --- /dev/null +++ b/src/efficientvit/apps/utils/dist.py @@ -0,0 +1,73 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os + +import torch +import torch.distributed + +from src.efficientvit.models.utils.list import list_mean, list_sum + +__all__ = [ + "dist_init", + "get_dist_rank", + "get_dist_size", + "is_master", + "dist_barrier", + "get_dist_local_rank", + "sync_tensor", +] + + +def dist_init() -> None: + try: + torch.distributed.init_process_group(backend="nccl") + assert torch.distributed.is_initialized() + except Exception: + # use torchpack + from torchpack import distributed as dist + + dist.init() + os.environ["RANK"] = f"{dist.rank()}" + os.environ["WORLD_SIZE"] = f"{dist.size()}" + os.environ["LOCAL_RANK"] = f"{dist.local_rank()}" + + +def get_dist_rank() -> int: + return int(os.environ["RANK"]) + + +def get_dist_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def is_master() -> bool: + return get_dist_rank() == 0 + + +def dist_barrier() -> None: + torch.distributed.barrier() + + +def get_dist_local_rank() -> int: + return int(os.environ["LOCAL_RANK"]) + + +def sync_tensor( + tensor: torch.Tensor or float, reduce="mean" +) -> torch.Tensor or list[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.Tensor(1).fill_(tensor).cuda() + tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())] + torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) + if reduce == "mean": + return list_mean(tensor_list) + elif reduce == "sum": + return list_sum(tensor_list) + elif reduce == "cat": + return torch.cat(tensor_list, dim=0) + elif reduce == "root": + return tensor_list[0] + else: + return tensor_list diff --git a/src/efficientvit/apps/utils/ema.py b/src/efficientvit/apps/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..9de7f7fee67840f98ec97bf759dd0a390618a576 --- /dev/null +++ b/src/efficientvit/apps/utils/ema.py @@ -0,0 +1,50 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +import math + +import torch +import torch.nn as nn + +from src.efficientvit.models.utils import is_parallel + +__all__ = ["EMA"] + + +def 
update_ema( + ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float +) -> None: + for k, v in ema.state_dict().items(): + if v.dtype.is_floating_point: + v -= (1.0 - decay) * (v - new_state_dict[k].detach()) + + +class EMA: + def __init__(self, model: nn.Module, decay: float, warmup_steps=2000): + self.shadows = copy.deepcopy( + model.module if is_parallel(model) else model + ).eval() + self.decay = decay + self.warmup_steps = warmup_steps + + for p in self.shadows.parameters(): + p.requires_grad = False + + def step(self, model: nn.Module, global_step: int) -> None: + with torch.no_grad(): + msd = (model.module if is_parallel(model) else model).state_dict() + update_ema( + self.shadows, + msd, + self.decay * (1 - math.exp(-global_step / self.warmup_steps)), + ) + + def state_dict(self) -> dict[float, dict[str, torch.Tensor]]: + return {self.decay: self.shadows.state_dict()} + + def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None: + for decay in state_dict: + if decay == self.decay: + self.shadows.load_state_dict(state_dict[decay]) diff --git a/src/efficientvit/apps/utils/export.py b/src/efficientvit/apps/utils/export.py new file mode 100644 index 0000000000000000000000000000000000000000..d611f957a6ff22b98210d611e7344426e091d3df --- /dev/null +++ b/src/efficientvit/apps/utils/export.py @@ -0,0 +1,47 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import io +import os + +import onnx +import torch +import torch.nn as nn +from onnxsim import simplify as simplify_func + +__all__ = ["export_onnx"] + + +def export_onnx( + model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11 +) -> None: + """Export a model to a platform-specific onnx format. + + Args: + model: a torch.nn.Module object. + export_path: export location. + sample_inputs: Any. 
+ simplify: a flag to turn on onnx-simplifier + opset: int + """ + model.eval() + + buffer = io.BytesIO() + with torch.no_grad(): + torch.onnx.export(model, sample_inputs, buffer, opset_version=opset) + buffer.seek(0, 0) + if simplify: + onnx_model = onnx.load_model(buffer) + onnx_model, success = simplify_func(onnx_model) + assert success + new_buffer = io.BytesIO() + onnx.save(onnx_model, new_buffer) + buffer = new_buffer + buffer.seek(0, 0) + + if buffer.getbuffer().nbytes > 0: + save_dir = os.path.dirname(export_path) + os.makedirs(save_dir, exist_ok=True) + with open(export_path, "wb") as f: + f.write(buffer.read()) diff --git a/src/efficientvit/apps/utils/init.py b/src/efficientvit/apps/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2ebe26ff45a7ee1de614a39e0db24198097152 --- /dev/null +++ b/src/efficientvit/apps/utils/init.py @@ -0,0 +1,68 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +__all__ = ["init_modules", "zero_last_gamma"] + + +def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None: + _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02} + + if isinstance(model, list): + for sub_module in model: + init_modules(sub_module, init_type) + else: + init_params = init_type.split("@") + init_params = float(init_params[1]) if len(init_params) > 1 else None + + if init_type.startswith("trunc_normal"): + init_func = lambda param: nn.init.trunc_normal_( + param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"]) + ) + else: + raise NotImplementedError + + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)): + init_func(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Embedding): + init_func(m.weight) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + m.weight.data.fill_(1) + m.bias.data.zero_() + else: + weight = getattr(m, "weight", None) + bias = getattr(m, "bias", None) + if isinstance(weight, torch.nn.Parameter): + init_func(weight) + if isinstance(bias, torch.nn.Parameter): + bias.data.zero_() + + +def zero_last_gamma(model: nn.Module, init_val=0) -> None: + import efficientvit.models.nn.ops as ops + + for m in model.modules(): + if isinstance(m, ops.ResidualBlock) and isinstance( + m.shortcut, ops.IdentityLayer + ): + if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)): + parent_module = m.main.point_conv + elif isinstance(m.main, ops.ResBlock): + parent_module = m.main.conv2 + elif isinstance(m.main, ops.ConvLayer): + parent_module = m.main + elif isinstance(m.main, (ops.LiteMLA)): + parent_module = m.main.proj + else: + parent_module = None + if parent_module is not None: + norm = getattr(parent_module, "norm", None) + if norm is not None: + nn.init.constant_(norm.weight, init_val) diff --git a/src/efficientvit/apps/utils/lr.py b/src/efficientvit/apps/utils/lr.py new file mode 100644 index 0000000000000000000000000000000000000000..fe10c6aee6ce3be5afdc766bbc028829d492903a --- /dev/null +++ b/src/efficientvit/apps/utils/lr.py @@ -0,0 +1,48 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import math + +import torch + +from 
src.efficientvit.models.utils.list import val2list + +__all__ = ["CosineLRwithWarmup"] + + +class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: int, + warmup_lr: float, + decay_steps: int or list[int], + last_epoch: int = -1, + ) -> None: + self.warmup_steps = warmup_steps + self.warmup_lr = warmup_lr + self.decay_steps = val2list(decay_steps) + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + if self.last_epoch < self.warmup_steps: + return [ + (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + + self.warmup_lr + for base_lr in self.base_lrs + ] + else: + current_steps = self.last_epoch - self.warmup_steps + decay_steps = [0] + self.decay_steps + idx = len(decay_steps) - 2 + for i, decay_step in enumerate(decay_steps[:-1]): + if decay_step <= current_steps < decay_steps[i + 1]: + idx = i + break + current_steps -= decay_steps[idx] + decay_step = decay_steps[idx + 1] - decay_steps[idx] + return [ + 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) + for base_lr in self.base_lrs + ] diff --git a/src/efficientvit/apps/utils/metric.py b/src/efficientvit/apps/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..9e656679c8fcf4d3c6320fc10de72d47310f7cf2 --- /dev/null +++ b/src/efficientvit/apps/utils/metric.py @@ -0,0 +1,37 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +from src.efficientvit.apps.utils.dist import sync_tensor + +__all__ = ["AverageMeter"] + + +class AverageMeter: + """Computes and stores the average and current value.""" + + def __init__(self, is_distributed=True): + self.is_distributed = is_distributed + self.sum = 0 + self.count = 0 + + def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float: + return sync_tensor(val, reduce="sum") if self.is_distributed else val + + def update(self, val: torch.Tensor or int or float, delta_n=1): + self.count += self._sync(delta_n) + self.sum += self._sync(val * delta_n) + + def get_count(self) -> torch.Tensor or int or float: + return ( + self.count.item() + if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 + else self.count + ) + + @property + def avg(self): + avg = -1 if self.count == 0 else self.sum / self.count + return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg diff --git a/src/efficientvit/apps/utils/misc.py b/src/efficientvit/apps/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..111b6618ab20bd02b5b6d8785091122c82fc8a24 --- /dev/null +++ b/src/efficientvit/apps/utils/misc.py @@ -0,0 +1,111 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os + +import yaml + +__all__ = [ + "parse_with_yaml", + "parse_unknown_args", + "partial_update_config", + "resolve_and_load_config", + "load_config", + "dump_config", +] + + +def parse_with_yaml(config_str: str) -> str or dict: + try: + # add space manually for dict + if "{" in config_str and "}" in config_str and ":" in config_str: + out_str = config_str.replace(":", ": ") + else: + out_str = config_str + return yaml.safe_load(out_str) + except ValueError: + # return raw string 
if parsing fails + return config_str + + +def parse_unknown_args(unknown: list) -> dict: + """Parse unknown args.""" + index = 0 + parsed_dict = {} + while index < len(unknown): + key, val = unknown[index], unknown[index + 1] + index += 2 + if not key.startswith("--"): + continue + key = key[2:] + + # try parsing with either dot notation or full yaml notation + # Note that the vanilla case "--key value" will be parsed the same + if "." in key: + # key == a.b.c, val == val --> parsed_dict[a][b][c] = val + keys = key.split(".") + dict_to_update = parsed_dict + for key in keys[:-1]: + if not ( + key in dict_to_update and isinstance(dict_to_update[key], dict) + ): + dict_to_update[key] = {} + dict_to_update = dict_to_update[key] + dict_to_update[keys[-1]] = parse_with_yaml( + val + ) # so we can parse lists, bools, etc... + else: + parsed_dict[key] = parse_with_yaml(val) + return parsed_dict + + +def partial_update_config(config: dict, partial_config: dict) -> dict: + for key in partial_config: + if ( + key in config + and isinstance(partial_config[key], dict) + and isinstance(config[key], dict) + ): + partial_update_config(config[key], partial_config[key]) + else: + config[key] = partial_config[key] + return config + + +def resolve_and_load_config(path: str, config_name="config.yaml") -> dict: + path = os.path.realpath(os.path.expanduser(path)) + if os.path.isdir(path): + config_path = os.path.join(path, config_name) + else: + config_path = path + if os.path.isfile(config_path): + pass + else: + raise Exception(f"Cannot find a valid config at {path}") + config = load_config(config_path) + return config + + +class SafeLoaderWithTuple(yaml.SafeLoader): + """A yaml safe loader with python tuple loading capabilities.""" + + def construct_python_tuple(self, node): + return tuple(self.construct_sequence(node)) + + +SafeLoaderWithTuple.add_constructor( + "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple +) + + +def load_config(filename: str) -> dict: + """Load a yaml file.""" + filename = os.path.realpath(os.path.expanduser(filename)) + return yaml.load(open(filename), Loader=SafeLoaderWithTuple) + + +def dump_config(config: dict, filename: str) -> None: + """Dump a config file""" + filename = os.path.realpath(os.path.expanduser(filename)) + yaml.dump(config, open(filename, "w"), sort_keys=False) diff --git a/src/efficientvit/apps/utils/opt.py b/src/efficientvit/apps/utils/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..79a03507c8b0aa8ad6e7210657630d5af6555521 --- /dev/null +++ b/src/efficientvit/apps/utils/opt.py @@ -0,0 +1,31 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"] + +# register optimizer here +# name: optimizer, kwargs with default values +REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = { + "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}), + "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), + "adamw": ( + torch.optim.AdamW, + {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}, + ), +} + + +def build_optimizer( + net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float +) -> torch.optim.Optimizer: + optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name] + optimizer_params = optimizer_params or {} 
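+    # Overlay the user-provided optimizer kwargs on top of the registered
+    # defaults, then construct the optimizer with `init_lr` as the base lr.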
+ + for key in default_params: + if key in optimizer_params: + default_params[key] = optimizer_params[key] + optimizer = optimizer_class(net_params, init_lr, **default_params) + return optimizer diff --git a/src/efficientvit/models/__init__.py b/src/efficientvit/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/efficientvit/models/efficientvit/__init__.py b/src/efficientvit/models/efficientvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cea677f24763b605249c05ea37483b579c507cbc --- /dev/null +++ b/src/efficientvit/models/efficientvit/__init__.py @@ -0,0 +1,8 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .backbone import * +from .cls import * +from .sam import * +from .seg import * diff --git a/src/efficientvit/models/efficientvit/backbone.py b/src/efficientvit/models/efficientvit/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..2c861a5f1486d243bd8eb0d120f6d646dfe8615e --- /dev/null +++ b/src/efficientvit/models/efficientvit/backbone.py @@ -0,0 +1,372 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn + +from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock, + FusedMBConv, IdentityLayer, MBConv, + OpSequential, ResBlock, ResidualBlock) +from src.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTBackbone", + "efficientvit_backbone_b0", + "efficientvit_backbone_b1", + "efficientvit_backbone_b2", + "efficientvit_backbone_b3", + "EfficientViTLargeBackbone", + "efficientvit_backbone_l0", + "efficientvit_backbone_l1", + "efficientvit_backbone_l2", + "efficientvit_backbone_l3", +] + + +class EfficientViTBackbone(nn.Module): + def __init__( + self, + width_list: list[int], + depth_list: list[int], + in_channels=3, + dim=32, + expand_ratio=4, + norm="bn2d", + act_func="hswish", + ) -> None: + super().__init__() + + self.width_list = [] + # input stem + self.input_stem = [ + ConvLayer( + in_channels=3, + out_channels=width_list[0], + stride=2, + norm=norm, + act_func=act_func, + ) + ] + for _ in range(depth_list[0]): + block = self.build_local_block( + in_channels=width_list[0], + out_channels=width_list[0], + stride=1, + expand_ratio=1, + norm=norm, + act_func=act_func, + ) + self.input_stem.append(ResidualBlock(block, IdentityLayer())) + in_channels = width_list[0] + self.input_stem = OpSequential(self.input_stem) + self.width_list.append(in_channels) + + # stages + self.stages = [] + for w, d in zip(width_list[1:3], depth_list[1:3]): + stage = [] + for i in range(d): + stride = 2 if i == 0 else 1 + block = self.build_local_block( + in_channels=in_channels, + out_channels=w, + stride=stride, + expand_ratio=expand_ratio, + norm=norm, + act_func=act_func, + ) + block = ResidualBlock(block, IdentityLayer() if stride == 1 else None) + stage.append(block) + in_channels = w + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + + for w, d in zip(width_list[3:], depth_list[3:]): + stage = [] + block = self.build_local_block( + in_channels=in_channels, + out_channels=w, + stride=2, + expand_ratio=expand_ratio, + 
norm=norm, + act_func=act_func, + fewer_norm=True, + ) + stage.append(ResidualBlock(block, None)) + in_channels = w + + for _ in range(d): + stage.append( + EfficientViTBlock( + in_channels=in_channels, + dim=dim, + expand_ratio=expand_ratio, + norm=norm, + act_func=act_func, + ) + ) + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + self.stages = nn.ModuleList(self.stages) + + @staticmethod + def build_local_block( + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm: str, + act_func: str, + fewer_norm: bool = False, + ) -> nn.Module: + if expand_ratio == 1: + block = DSConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm else norm, + act_func=(act_func, None), + ) + else: + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm=(None, None, norm) if fewer_norm else norm, + act_func=(act_func, act_func, None), + ) + return block + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + output_dict = {"input": x} + output_dict["stage0"] = x = self.input_stem(x) + for stage_id, stage in enumerate(self.stages, 1): + output_dict["stage%d" % stage_id] = x = stage(x) + output_dict["stage_final"] = x + return output_dict + + +def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[8, 16, 32, 64, 128], + depth_list=[1, 2, 2, 2, 2], + dim=16, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[16, 32, 64, 128, 256], + depth_list=[1, 2, 3, 3, 4], + dim=16, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[24, 48, 96, 192, 384], + depth_list=[1, 3, 4, 4, 6], + dim=32, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 4, 6, 6, 9], + dim=32, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +class EfficientViTLargeBackbone(nn.Module): + def __init__( + self, + width_list: list[int], + depth_list: list[int], + block_list: list[str] or None = None, + expand_list: list[float] or None = None, + fewer_norm_list: list[bool] or None = None, + in_channels=3, + qkv_dim=32, + norm="bn2d", + act_func="gelu", + ) -> None: + super().__init__() + block_list = block_list or ["res", "fmb", "fmb", "mb", "att"] + expand_list = expand_list or [1, 4, 4, 4, 6] + fewer_norm_list = fewer_norm_list or [False, False, False, True, True] + + self.width_list = [] + self.stages = [] + # stage 0 + stage0 = [ + ConvLayer( + in_channels=3, + out_channels=width_list[0], + stride=2, + norm=norm, + act_func=act_func, + ) + ] + for _ in range(depth_list[0]): + block = self.build_local_block( + block=block_list[0], + in_channels=width_list[0], + out_channels=width_list[0], + stride=1, + expand_ratio=expand_list[0], + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[0], + ) + stage0.append(ResidualBlock(block, IdentityLayer())) + in_channels = 
width_list[0] + self.stages.append(OpSequential(stage0)) + self.width_list.append(in_channels) + + for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1): + stage = [] + block = self.build_local_block( + block=( + "mb" + if block_list[stage_id] not in ["mb", "fmb"] + else block_list[stage_id] + ), + in_channels=in_channels, + out_channels=w, + stride=2, + expand_ratio=expand_list[stage_id] * 4, + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[stage_id], + ) + stage.append(ResidualBlock(block, None)) + in_channels = w + + for _ in range(d): + if block_list[stage_id].startswith("att"): + stage.append( + EfficientViTBlock( + in_channels=in_channels, + dim=qkv_dim, + expand_ratio=expand_list[stage_id], + scales=(3,) if block_list[stage_id] == "att@3" else (5,), + norm=norm, + act_func=act_func, + ) + ) + else: + block = self.build_local_block( + block=block_list[stage_id], + in_channels=in_channels, + out_channels=in_channels, + stride=1, + expand_ratio=expand_list[stage_id], + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[stage_id], + ) + block = ResidualBlock(block, IdentityLayer()) + stage.append(block) + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + self.stages = nn.ModuleList(self.stages) + + @staticmethod + def build_local_block( + block: str, + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm: str, + act_func: str, + fewer_norm: bool = False, + ) -> nn.Module: + if block == "res": + block = ResBlock( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm else norm, + act_func=(act_func, None), + ) + elif block == "fmb": + block = FusedMBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm else norm, + act_func=(act_func, None), + ) + elif block == "mb": + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm=(None, None, norm) if fewer_norm else norm, + act_func=(act_func, act_func, None), + ) + else: + raise ValueError(block) + return block + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + output_dict = {"input": x} + for stage_id, stage in enumerate(self.stages): + output_dict["stage%d" % stage_id] = x = stage(x) + output_dict["stage_final"] = x + return output_dict + + +def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 1, 1, 4, 4], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 1, 1, 6, 6], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 2, 2, 8, 8], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + 
width_list=[64, 128, 256, 512, 1024], + depth_list=[1, 2, 2, 8, 8], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone diff --git a/src/efficientvit/models/efficientvit/cls.py b/src/efficientvit/models/efficientvit/cls.py new file mode 100644 index 0000000000000000000000000000000000000000..98fac3c4be1ac504585197dbee6cea6f087ffddd --- /dev/null +++ b/src/efficientvit/models/efficientvit/cls.py @@ -0,0 +1,174 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn + +from src.efficientvit.models.efficientvit.backbone import ( + EfficientViTBackbone, EfficientViTLargeBackbone) +from src.efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential +from src.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTCls", + ###################### + "efficientvit_cls_b0", + "efficientvit_cls_b1", + "efficientvit_cls_b2", + "efficientvit_cls_b3", + ###################### + "efficientvit_cls_l1", + "efficientvit_cls_l2", + "efficientvit_cls_l3", +] + + +class ClsHead(OpSequential): + def __init__( + self, + in_channels: int, + width_list: list[int], + n_classes=1000, + dropout=0.0, + norm="bn2d", + act_func="hswish", + fid="stage_final", + ): + ops = [ + ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func), + nn.AdaptiveAvgPool2d(output_size=1), + LinearLayer( + width_list[0], width_list[1], False, norm="ln", act_func=act_func + ), + LinearLayer(width_list[1], n_classes, True, dropout, None, None), + ] + super().__init__(ops) + + self.fid = fid + + def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor: + x = feed_dict[self.fid] + return OpSequential.forward(self, x) + + +class EfficientViTCls(nn.Module): + def __init__( + self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: ClsHead + ) -> None: + super().__init__() + self.backbone = backbone + self.head = head + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + output = self.head(feed_dict) + return output + + +def efficientvit_cls_b0(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b0 + + backbone = efficientvit_backbone_b0(**kwargs) + + head = ClsHead( + in_channels=128, + width_list=[1024, 1280], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b1(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b1 + + backbone = efficientvit_backbone_b1(**kwargs) + + head = ClsHead( + in_channels=256, + width_list=[1536, 1600], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b2(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b2 + + backbone = efficientvit_backbone_b2(**kwargs) + + head = ClsHead( + in_channels=384, + width_list=[2304, 2560], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b3(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b3 + + backbone = efficientvit_backbone_b3(**kwargs) + + head = ClsHead( + 
in_channels=512, + width_list=[2304, 2560], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l1(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + head = ClsHead( + in_channels=512, + width_list=[3072, 3200], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l2(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + head = ClsHead( + in_channels=512, + width_list=[3072, 3200], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l3(**kwargs) -> EfficientViTCls: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l3 + + backbone = efficientvit_backbone_l3(**kwargs) + + head = ClsHead( + in_channels=1024, + width_list=[6144, 6400], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model diff --git a/src/efficientvit/models/efficientvit/sam.py b/src/efficientvit/models/efficientvit/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..fe331ef112b64c09ae338033074737645e67e11d --- /dev/null +++ b/src/efficientvit/models/efficientvit/sam.py @@ -0,0 +1,653 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +from segment_anything import SamAutomaticMaskGenerator +from segment_anything.modeling import (MaskDecoder, PromptEncoder, + TwoWayTransformer) +from segment_anything.modeling.mask_decoder import MaskDecoder +from segment_anything.modeling.prompt_encoder import PromptEncoder +from segment_anything.utils.amg import build_all_layer_point_grids +from segment_anything.utils.transforms import ResizeLongestSide +from torchvision.transforms.functional import resize, to_pil_image + +from src.efficientvit.models.efficientvit.backbone import ( + EfficientViTBackbone, EfficientViTLargeBackbone) +from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv, + IdentityLayer, MBConv, OpSequential, + ResBlock, ResidualBlock, UpSampleLayer, + build_norm) +from src.efficientvit.models.utils import build_kwargs_from_config, get_device + +__all__ = [ + "SamPad", + "SamResize", + "SamNeck", + "EfficientViTSamImageEncoder", + "EfficientViTSam", + "EfficientViTSamPredictor", + "EfficientViTSamAutomaticMaskGenerator", + "efficientvit_sam_l0", + "efficientvit_sam_l1", + "efficientvit_sam_l2", + "efficientvit_sam_xl0", + "efficientvit_sam_xl1", +] + + +class SamPad: + def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None: + self.size = size + self.fill = fill + self.pad_mode = pad_mode + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + h, w = image.shape[-2:] + th, tw = self.size, self.size + assert th >= h and tw >= w + if self.pad_mode == "corner": + image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill) + else: + raise NotImplementedError + 
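+        # "corner" mode pads only on the right and bottom, so the original pixels stay
+        # anchored at the top-left and prompt coordinates need no offset after padding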
return image + + def __repr__(self) -> str: + return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})" + + +class SamResize: + def __init__(self, size: int) -> None: + self.size = size + + def __call__(self, image: np.ndarray) -> np.ndarray: + h, w, _ = image.shape + long_side = max(h, w) + if long_side != self.size: + return self.apply_image(image) + else: + return image + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.size + ) + return np.array(resize(to_pil_image(image), target_size)) + + @staticmethod + def get_preprocess_shape( + oldh: int, oldw: int, long_side_length: int + ) -> tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + def __repr__(self) -> str: + return f"{type(self).__name__}(size={self.size})" + + +class SamNeck(DAGBlock): + def __init__( + self, + fid_list: list[str], + in_channel_list: list[int], + head_width: int, + head_depth: int, + expand_ratio: float, + middle_op: str, + out_dim: int = 256, + norm="bn2d", + act_func="gelu", + ): + inputs = {} + for fid, in_channel in zip(fid_list, in_channel_list): + inputs[fid] = OpSequential( + [ + ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None), + UpSampleLayer(size=(64, 64)), + ] + ) + + middle = [] + for _ in range(head_depth): + if middle_op == "mb": + block = MBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, act_func, None), + ) + elif middle_op == "fmb": + block = FusedMBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + elif middle_op == "res": + block = ResBlock( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + else: + raise NotImplementedError + middle.append(ResidualBlock(block, IdentityLayer())) + middle = OpSequential(middle) + + outputs = { + "sam_encoder": OpSequential( + [ + ConvLayer( + head_width, + out_dim, + 1, + use_bias=True, + norm=None, + act_func=None, + ), + ] + ) + } + + super(SamNeck, self).__init__( + inputs, "add", None, middle=middle, outputs=outputs + ) + + +class EfficientViTSamImageEncoder(nn.Module): + def __init__( + self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, neck: SamNeck + ): + super().__init__() + self.backbone = backbone + self.neck = neck + + self.norm = build_norm("ln2d", 256) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + feed_dict = self.neck(feed_dict) + + output = feed_dict["sam_encoder"] + output = self.norm(output) + return output + + +class EfficientViTSam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: EfficientViTSamImageEncoder, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + image_size: tuple[int, int] = (1024, 512), + ) -> None: + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + + self.image_size = image_size + + self.transform = transforms.Compose( + [ + SamResize(self.image_size[1]), + transforms.ToTensor(), + transforms.Normalize( + mean=[123.675 / 255, 116.28 / 
255, 103.53 / 255], + std=[58.395 / 255, 57.12 / 255, 57.375 / 255], + ), + SamPad(self.image_size[1]), + ] + ) + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: tuple[int, ...], + original_size: tuple[int, ...], + ) -> torch.Tensor: + masks = F.interpolate( + masks, + (self.image_size[0], self.image_size[0]), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate( + masks, original_size, mode="bilinear", align_corners=False + ) + return masks + + +class EfficientViTSamPredictor: + def __init__(self, sam_model: EfficientViTSam) -> None: + self.model = sam_model + self.reset_image() + + @property + def transform(self): + return self + + @property + def device(self): + return get_device(self.model) + + def reset_image(self) -> None: + self.is_image_set = False + self.features = None + self.original_size = None + self.input_size = None + + def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray: + old_h, old_w = self.original_size + new_h, new_w = self.input_size + coords = copy.deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray: + boxes = self.apply_coords(boxes.reshape(-1, 2, 2)) + return boxes.reshape(-1, 4) + + @torch.inference_mode() + def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None: + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + self.reset_image() + + self.original_size = image.shape[:2] + self.input_size = ResizeLongestSide.get_preprocess_shape( + *self.original_size, long_side_length=self.model.image_size[0] + ) + + torch_data = ( + self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model)) + ) + self.features = self.model.image_encoder(torch_data) + self.is_image_set = True + + def predict( + self, + point_coords: np.ndarray or None = None, + point_labels: np.ndarray or None = None, + box: np.ndarray or None = None, + mask_input: np.ndarray or None = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. 
+ + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + device = get_device(self.model) + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.apply_coords(point_coords) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=device + ) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.apply_boxes(box) + box_torch = torch.as_tensor(box, dtype=torch.float, device=device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=device + ) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.inference_mode() + def predict_torch( + self, + point_coords: torch.Tensor or None = None, + point_labels: torch.Tensor or None = None, + boxes: torch.Tensor or None = None, + mask_input: torch.Tensor or None = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. 
+ return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + +class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator): + def __init__( + self, + model: EfficientViTSam, + points_per_side: int or None = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: list[np.ndarray] or None = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
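+        # optional dependencies are imported lazily below: pycocotools only when COCO RLE
+        # output is requested, cv2 only when small mask regions must be filtered out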
+ if output_mode == "coco_rle": + from pycocotools import \ + mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = EfficientViTSamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + +def build_efficientvit_sam( + image_encoder: EfficientViTSamImageEncoder, image_size: int +) -> EfficientViTSam: + return EfficientViTSam( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=256, + image_embedding_size=(64, 64), + input_image_size=(1024, 1024), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=256, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=256, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + image_size=(1024, image_size), + ) + + +def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l0 + + backbone = efficientvit_backbone_l0(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=4, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=8, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=12, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam: + from efficientvit.models.efficientvit.backbone import \ + EfficientViTLargeBackbone + + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512, 1024], + depth_list=[0, 1, 1, 2, 3, 3], + block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"], + expand_list=[1, 4, 4, 4, 4, 6], + fewer_norm_list=[False, False, False, False, True, True], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + + neck = SamNeck( + fid_list=["stage5", "stage4", "stage3"], + in_channel_list=[1024, 512, 256], + head_width=256, + head_depth=6, + expand_ratio=4, + middle_op="fmb", 
+ ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam: + from src.efficientvit.models.efficientvit.backbone import \ + EfficientViTLargeBackbone + + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512, 1024], + depth_list=[1, 2, 2, 4, 6, 6], + block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"], + expand_list=[1, 4, 4, 4, 4, 6], + fewer_norm_list=[False, False, False, False, True, True], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + + neck = SamNeck( + fid_list=["stage5", "stage4", "stage3"], + in_channel_list=[1024, 512, 256], + head_width=256, + head_depth=12, + expand_ratio=4, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) diff --git a/src/efficientvit/models/efficientvit/seg.py b/src/efficientvit/models/efficientvit/seg.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e4bf44a8e13a66b63cbf39f74dbee02b9d7045 --- /dev/null +++ b/src/efficientvit/models/efficientvit/seg.py @@ -0,0 +1,355 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn + +from src.efficientvit.models.efficientvit.backbone import ( + EfficientViTBackbone, EfficientViTLargeBackbone) +from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv, + IdentityLayer, MBConv, OpSequential, + ResidualBlock, UpSampleLayer) +from src.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTSeg", + "efficientvit_seg_b0", + "efficientvit_seg_b1", + "efficientvit_seg_b2", + "efficientvit_seg_b3", + "efficientvit_seg_l1", + "efficientvit_seg_l2", +] + + +class SegHead(DAGBlock): + def __init__( + self, + fid_list: list[str], + in_channel_list: list[int], + stride_list: list[int], + head_stride: int, + head_width: int, + head_depth: int, + expand_ratio: float, + middle_op: str, + final_expand: float or None, + n_classes: int, + dropout=0, + norm="bn2d", + act_func="hswish", + ): + inputs = {} + for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list): + factor = stride // head_stride + if factor == 1: + inputs[fid] = ConvLayer( + in_channel, head_width, 1, norm=norm, act_func=None + ) + else: + inputs[fid] = OpSequential( + [ + ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None), + UpSampleLayer(factor=factor), + ] + ) + + middle = [] + for _ in range(head_depth): + if middle_op == "mbconv": + block = MBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, act_func, None), + ) + elif middle_op == "fmbconv": + block = FusedMBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + else: + raise NotImplementedError + middle.append(ResidualBlock(block, IdentityLayer())) + middle = OpSequential(middle) + + outputs = { + "segout": OpSequential( + [ + ( + None + if final_expand is None + else ConvLayer( + head_width, + head_width * final_expand, + 1, + norm=norm, + act_func=act_func, + ) + ), + ConvLayer( + head_width * (final_expand or 1), + n_classes, + 1, + use_bias=True, + dropout=dropout, + norm=None, + act_func=None, + ), + ] + ) + } + + super(SegHead, 
self).__init__( + inputs, "add", None, middle=middle, outputs=outputs + ) + + +class EfficientViTSeg(nn.Module): + def __init__( + self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: SegHead + ) -> None: + super().__init__() + self.backbone = backbone + self.head = head + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + feed_dict = self.head(feed_dict) + + return feed_dict["segout"] + + +def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b0 + + backbone = efficientvit_backbone_b0(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[128, 64, 32], + stride_list=[32, 16, 8], + head_stride=8, + head_width=32, + head_depth=1, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b1 + + backbone = efficientvit_backbone_b1(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[256, 128, 64], + stride_list=[32, 16, 8], + head_stride=8, + head_width=64, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[256, 128, 64], + stride_list=[32, 16, 8], + head_stride=8, + head_width=64, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b2 + + backbone = efficientvit_backbone_b2(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[384, 192, 96], + stride_list=[32, 16, 8], + head_stride=8, + head_width=96, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[384, 192, 96], + stride_list=[32, 16, 8], + head_stride=8, + head_width=96, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_b3 + + backbone = efficientvit_backbone_b3(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == 
"ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=256, + head_depth=3, + expand_ratio=1, + middle_op="fmbconv", + final_expand=None, + n_classes=19, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="fmbconv", + final_expand=8, + n_classes=150, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg: + from efficientvit.models.efficientvit.backbone import \ + efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=256, + head_depth=5, + expand_ratio=1, + middle_op="fmbconv", + final_expand=None, + n_classes=19, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="fmbconv", + final_expand=8, + n_classes=150, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model diff --git a/src/efficientvit/models/nn/__init__.py b/src/efficientvit/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6152158a1a8a0b4d2fc53622bdf338fbf34809d --- /dev/null +++ b/src/efficientvit/models/nn/__init__.py @@ -0,0 +1,8 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .act import * +from .drop import * +from .norm import * +from .ops import * diff --git a/src/efficientvit/models/nn/act.py b/src/efficientvit/models/nn/act.py new file mode 100644 index 0000000000000000000000000000000000000000..31d439e24a3453222265c63593537942657ff8eb --- /dev/null +++ b/src/efficientvit/models/nn/act.py @@ -0,0 +1,30 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from functools import partial + +import torch.nn as nn + +from src.efficientvit.models.utils import 
build_kwargs_from_config + +__all__ = ["build_act"] + + +# register activation function here +REGISTERED_ACT_DICT: dict[str, type] = { + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "hswish": nn.Hardswish, + "silu": nn.SiLU, + "gelu": partial(nn.GELU, approximate="tanh"), +} + + +def build_act(name: str, **kwargs) -> nn.Module or None: + if name in REGISTERED_ACT_DICT: + act_cls = REGISTERED_ACT_DICT[name] + args = build_kwargs_from_config(kwargs, act_cls) + return act_cls(**args) + else: + return None diff --git a/src/efficientvit/models/nn/drop.py b/src/efficientvit/models/nn/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..0c10aa3dbc89a360fd4c19b1cb7172c2ceea71eb --- /dev/null +++ b/src/efficientvit/models/nn/drop.py @@ -0,0 +1,98 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torch +import torch.nn as nn + +from src.efficientvit.apps.trainer.run_config import Scheduler +from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock +from src.efficientvit.models.utils import build_kwargs_from_config + +__all__ = ["apply_drop_func"] + + +def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None: + if drop_config is None: + return + + drop_lookup_table = { + "droppath": apply_droppath, + } + + drop_func = drop_lookup_table[drop_config["name"]] + drop_kwargs = build_kwargs_from_config(drop_config, drop_func) + + drop_func(network, **drop_kwargs) + + +def apply_droppath( + network: nn.Module, + drop_prob: float, + linear_decay=True, + scheduled=True, + skip=0, +) -> None: + all_valid_blocks = [] + for m in network.modules(): + for name, sub_module in m.named_children(): + if isinstance(sub_module, ResidualBlock) and isinstance( + sub_module.shortcut, IdentityLayer + ): + all_valid_blocks.append((m, name, sub_module)) + all_valid_blocks = all_valid_blocks[skip:] + for i, (m, name, sub_module) in enumerate(all_valid_blocks): + prob = ( + drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob + ) + new_module = DropPathResidualBlock( + sub_module.main, + sub_module.shortcut, + sub_module.post_act, + sub_module.pre_norm, + prob, + scheduled, + ) + m._modules[name] = new_module + + +class DropPathResidualBlock(ResidualBlock): + def __init__( + self, + main: nn.Module, + shortcut: nn.Module or None, + post_act=None, + pre_norm: nn.Module or None = None, + ###################################### + drop_prob: float = 0, + scheduled=True, + ): + super().__init__(main, shortcut, post_act, pre_norm) + + self.drop_prob = drop_prob + self.scheduled = scheduled + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if ( + not self.training + or self.drop_prob == 0 + or not isinstance(self.shortcut, IdentityLayer) + ): + return ResidualBlock.forward(self, x) + else: + drop_prob = self.drop_prob + if self.scheduled: + drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1) + keep_prob = 1 - drop_prob + + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device + ) + random_tensor.floor_() # binarize + + res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x) + if self.post_act: + res = self.post_act(res) + return res diff --git a/src/efficientvit/models/nn/norm.py b/src/efficientvit/models/nn/norm.py new file mode 100644 index 
0000000000000000000000000000000000000000..03fcacba84e6b2258b4fc4b893f6e2e5151257bb --- /dev/null +++ b/src/efficientvit/models/nn/norm.py @@ -0,0 +1,157 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from src.efficientvit.models.utils import build_kwargs_from_config + +__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"] + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = x - torch.mean(x, dim=1, keepdim=True) + out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) + if self.elementwise_affine: + out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + return out + + +# register normalization function here +REGISTERED_NORM_DICT: dict[str, type] = { + "bn2d": nn.BatchNorm2d, + "ln": nn.LayerNorm, + "ln2d": LayerNorm2d, +} + + +def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module or None: + if name in ["ln", "ln2d"]: + kwargs["normalized_shape"] = num_features + else: + kwargs["num_features"] = num_features + if name in REGISTERED_NORM_DICT: + norm_cls = REGISTERED_NORM_DICT[name] + args = build_kwargs_from_config(kwargs, norm_cls) + return norm_cls(**args) + else: + return None + + +def reset_bn( + model: nn.Module, + data_loader: list, + sync=True, + progress_bar=False, +) -> None: + import copy + + import torch.nn.functional as F + from tqdm import tqdm + + from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor + from efficientvit.models.utils import get_device, list_join + + bn_mean = {} + bn_var = {} + + tmp_model = copy.deepcopy(model) + for name, m in tmp_model.named_modules(): + if isinstance(m, _BatchNorm): + bn_mean[name] = AverageMeter(is_distributed=False) + bn_var[name] = AverageMeter(is_distributed=False) + + def new_forward(bn, mean_est, var_est): + def lambda_forward(x): + x = x.contiguous() + if sync: + batch_mean = ( + x.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) # 1, C, 1, 1 + batch_mean = sync_tensor(batch_mean, reduce="cat") + batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) + + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = ( + batch_var.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) + batch_var = sync_tensor(batch_var, reduce="cat") + batch_var = torch.mean(batch_var, dim=0, keepdim=True) + else: + batch_mean = ( + x.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) # 1, C, 1, 1 + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = ( + batch_var.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) + + batch_mean = torch.squeeze(batch_mean) + batch_var = torch.squeeze(batch_var) + + mean_est.update(batch_mean.data, x.size(0)) + var_est.update(batch_var.data, x.size(0)) + + # bn forward using calculated mean & var + _feature_dim = batch_mean.shape[0] + return F.batch_norm( + x, + batch_mean, + batch_var, + bn.weight[:_feature_dim], + bn.bias[:_feature_dim], + False, + 0.0, + bn.eps, + ) + + return lambda_forward + + m.forward = new_forward(m, bn_mean[name], bn_var[name]) + + # skip if there is no batch normalization layers in the network + if len(bn_mean) == 0: + return + + tmp_model.eval() + with torch.no_grad(): + with tqdm( + 
total=len(data_loader), + desc="reset bn", + disable=not progress_bar or not is_master(), + ) as t: + for images in data_loader: + images = images.to(get_device(tmp_model)) + tmp_model(images) + t.set_postfix( + { + "bs": images.size(0), + "res": list_join(images.shape[-2:], "x"), + } + ) + t.update() + + for name, m in model.named_modules(): + if name in bn_mean and bn_mean[name].count > 0: + feature_dim = bn_mean[name].avg.size(0) + assert isinstance(m, _BatchNorm) + m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) + m.running_var.data[:feature_dim].copy_(bn_var[name].avg) + + +def set_norm_eps(model: nn.Module, eps: float or None = None) -> None: + for m in model.modules(): + if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): + if eps is not None: + m.eps = eps diff --git a/src/efficientvit/models/nn/ops.py b/src/efficientvit/models/nn/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd0716834c091142168f56037aead9527180222 --- /dev/null +++ b/src/efficientvit/models/nn/ops.py @@ -0,0 +1,585 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from src.efficientvit.models.nn.act import build_act +from src.efficientvit.models.nn.norm import build_norm +from src.efficientvit.models.utils import (get_same_padding, list_sum, resize, + val2list, val2tuple) + +__all__ = [ + "ConvLayer", + "UpSampleLayer", + "LinearLayer", + "IdentityLayer", + "DSConv", + "MBConv", + "FusedMBConv", + "ResBlock", + "LiteMLA", + "EfficientViTBlock", + "ResidualBlock", + "DAGBlock", + "OpSequential", +] + + +################################################################################# +# Basic Layers # +################################################################################# + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + use_bias=False, + dropout=0, + norm="bn2d", + act_func="relu", + ): + super(ConvLayer, self).__init__() + + padding = get_same_padding(kernel_size) + padding *= dilation + + self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(kernel_size, kernel_size), + stride=(stride, stride), + padding=padding, + dilation=(dilation, dilation), + groups=groups, + bias=use_bias, + ) + self.norm = build_norm(norm, num_features=out_channels) + self.act = build_act(act_func) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dropout is not None: + x = self.dropout(x) + x = self.conv(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class UpSampleLayer(nn.Module): + def __init__( + self, + mode="bicubic", + size: int or tuple[int, int] or list[int] or None = None, + factor=2, + align_corners=False, + ): + super(UpSampleLayer, self).__init__() + self.mode = mode + self.size = val2list(size, 2) if size is not None else None + self.factor = None if self.size is not None else factor + self.align_corners = align_corners + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if ( + self.size is not None and tuple(x.shape[-2:]) == self.size + ) or self.factor == 1: + return x + return resize(x, self.size, self.factor, self.mode, self.align_corners) + + 
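+# Illustrative usage sketch (not part of the upstream EfficientViT sources): UpSampleLayer
+# either resizes to a fixed `size` (as SamNeck does with size=(64, 64)) or scales by
+# `factor`, delegating to `resize` from src.efficientvit.models.utils:
+#
+#   up = UpSampleLayer(mode="bicubic", size=(64, 64))
+#   y = up(torch.randn(1, 256, 32, 32))   # bicubic-resized to (1, 256, 64, 64)
+#   UpSampleLayer(factor=2)(x)            # alternatively, scale H and W by a factor of 2
+
+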
+class LinearLayer(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + use_bias=True, + dropout=0, + norm=None, + act_func=None, + ): + super(LinearLayer, self).__init__() + + self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None + self.linear = nn.Linear(in_features, out_features, use_bias) + self.norm = build_norm(norm, num_features=out_features) + self.act = build_act(act_func) + + def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._try_squeeze(x) + if self.dropout: + x = self.dropout(x) + x = self.linear(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class IdentityLayer(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +################################################################################# +# Basic Blocks # +################################################################################# + + +class DSConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super(DSConv, self).__init__() + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.depth_conv = ConvLayer( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.point_conv = ConvLayer( + in_channels, + out_channels, + 1, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm=("bn2d", "bn2d", "bn2d"), + act_func=("relu6", "relu6", None), + ): + super(MBConv, self).__init__() + + use_bias = val2tuple(use_bias, 3) + norm = val2tuple(norm, 3) + act_func = val2tuple(act_func, 3) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.inverted_conv = ConvLayer( + in_channels, + mid_channels, + 1, + stride=1, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.depth_conv = ConvLayer( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + groups=mid_channels, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + norm=norm[2], + act_func=act_func[2], + use_bias=use_bias[2], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class FusedMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + groups=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.spatial_conv = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + groups=groups, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.point_conv 
= ConvLayer( + mid_channels, + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.spatial_conv(x) + x = self.point_conv(x) + return x + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.conv1 = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.conv2 = ConvLayer( + mid_channels, + out_channels, + kernel_size, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + return x + + +class LiteMLA(nn.Module): + r"""Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int or None = None, + heads_ratio: float = 1.0, + dim=8, + use_bias=False, + norm=(None, "bn2d"), + act_func=(None, None), + kernel_func="relu", + scales: tuple[int, ...] = (5,), + eps=1.0e-15, + ): + super(LiteMLA, self).__init__() + self.eps = eps + heads = heads or int(in_channels // dim * heads_ratio) + + total_dim = heads * dim + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.dim = dim + self.qkv = ConvLayer( + in_channels, + 3 * total_dim, + 1, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.aggreg = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + 1, + groups=3 * heads, + bias=use_bias[0], + ), + ) + for scale in scales + ] + ) + self.kernel_func = build_act(kernel_func, inplace=False) + + self.proj = ConvLayer( + total_dim * (1 + len(scales)), + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + @autocast(enabled=False) + def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor: + B, _, H, W = list(qkv.size()) + + if qkv.dtype == torch.float16: + qkv = qkv.float() + + qkv = torch.reshape( + qkv, + ( + B, + -1, + 3 * self.dim, + H * W, + ), + ) + qkv = torch.transpose(qkv, -1, -2) + q, k, v = ( + qkv[..., 0 : self.dim], + qkv[..., self.dim : 2 * self.dim], + qkv[..., 2 * self.dim :], + ) + + # lightweight linear attention + q = self.kernel_func(q) + k = self.kernel_func(k) + + # linear matmul + trans_k = k.transpose(-1, -2) + + v = F.pad(v, (0, 1), mode="constant", value=1) + kv = torch.matmul(trans_k, v) + out = torch.matmul(q, kv) + out = out[..., :-1] / (out[..., -1:] + self.eps) + + out = torch.transpose(out, -1, -2) + out = torch.reshape(out, (B, -1, H, W)) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) + + out = self.relu_linear_att(multi_scale_qkv) + out = self.proj(out) + + return out + + +class EfficientViTBlock(nn.Module): + def __init__( + self, + 
in_channels: int, + heads_ratio: float = 1.0, + dim=32, + expand_ratio: float = 4, + scales=(5,), + norm="bn2d", + act_func="hswish", + ): + super(EfficientViTBlock, self).__init__() + self.context_module = ResidualBlock( + LiteMLA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=dim, + norm=(None, norm), + scales=scales, + ), + IdentityLayer(), + ) + local_module = MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm=(None, None, norm), + act_func=(act_func, act_func, None), + ) + self.local_module = ResidualBlock(local_module, IdentityLayer()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.context_module(x) + x = self.local_module(x) + return x + + +################################################################################# +# Functional Blocks # +################################################################################# + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: nn.Module or None, + shortcut: nn.Module or None, + post_act=None, + pre_norm: nn.Module or None = None, + ): + super(ResidualBlock, self).__init__() + + self.pre_norm = pre_norm + self.main = main + self.shortcut = shortcut + self.post_act = build_act(post_act) + + def forward_main(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm is None: + return self.main(x) + else: + return self.main(self.pre_norm(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.main is None: + res = x + elif self.shortcut is None: + res = self.forward_main(x) + else: + res = self.forward_main(x) + self.shortcut(x) + if self.post_act: + res = self.post_act(res) + return res + + +class DAGBlock(nn.Module): + def __init__( + self, + inputs: dict[str, nn.Module], + merge: str, + post_input: nn.Module or None, + middle: nn.Module, + outputs: dict[str, nn.Module], + ): + super(DAGBlock, self).__init__() + + self.input_keys = list(inputs.keys()) + self.input_ops = nn.ModuleList(list(inputs.values())) + self.merge = merge + self.post_input = post_input + + self.middle = middle + + self.output_keys = list(outputs.keys()) + self.output_ops = nn.ModuleList(list(outputs.values())) + + def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + feat = [ + op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops) + ] + if self.merge == "add": + feat = list_sum(feat) + elif self.merge == "cat": + feat = torch.concat(feat, dim=1) + else: + raise NotImplementedError + if self.post_input is not None: + feat = self.post_input(feat) + feat = self.middle(feat) + for key, op in zip(self.output_keys, self.output_ops): + feature_dict[key] = op(feat) + return feature_dict + + +class OpSequential(nn.Module): + def __init__(self, op_list: list[nn.Module or None]): + super(OpSequential, self).__init__() + valid_op_list = [] + for op in op_list: + if op is not None: + valid_op_list.append(op) + self.op_list = nn.ModuleList(valid_op_list) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for op in self.op_list: + x = op(x) + return x diff --git a/src/efficientvit/models/utils/__init__.py b/src/efficientvit/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0aab6b0a576b33e1e72029210f7b4232c9b7b8b6 --- /dev/null +++ b/src/efficientvit/models/utils/__init__.py @@ -0,0 +1,7 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, 
Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .list import * +from .network import * +from .random import * diff --git a/src/efficientvit/models/utils/list.py b/src/efficientvit/models/utils/list.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2c3291a88ab1d3cc77f7bc7d5eb475e9670a28 --- /dev/null +++ b/src/efficientvit/models/utils/list.py @@ -0,0 +1,57 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +__all__ = [ + "list_sum", + "list_mean", + "weighted_list_sum", + "list_join", + "val2list", + "val2tuple", + "squeeze_list", +] + + +def list_sum(x: list) -> any: + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def list_mean(x: list) -> any: + return list_sum(x) / len(x) + + +def weighted_list_sum(x: list, weights: list) -> any: + assert len(x) == len(weights) + return ( + x[0] * weights[0] + if len(x) == 1 + else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) + ) + + +def list_join(x: list, sep="\t", format_str="%s") -> str: + return sep.join([format_str % val for val in x]) + + +def val2list(x: list or tuple or any, repeat_time=1) -> list: + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: + x = val2list(x) + + # repeat elements if necessary + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def squeeze_list(x: list or None) -> list or any: + if x is not None and len(x) == 1: + return x[0] + else: + return x diff --git a/src/efficientvit/models/utils/network.py b/src/efficientvit/models/utils/network.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba96ec255dc7543be2a7995fed58f7d139d2c75 --- /dev/null +++ b/src/efficientvit/models/utils/network.py @@ -0,0 +1,77 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os +from inspect import signature + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "is_parallel", + "get_device", + "get_same_padding", + "resize", + "build_kwargs_from_config", + "load_state_dict_from_file", +] + + +def is_parallel(model: nn.Module) -> bool: + return isinstance( + model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) + ) + + +def get_device(model: nn.Module) -> torch.device: + return model.parameters().__next__().device + + +def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: + if isinstance(kernel_size, tuple): + return tuple([get_same_padding(ks) for ks in kernel_size]) + else: + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def resize( + x: torch.Tensor, + size: any or None = None, + scale_factor: list[float] or None = None, + mode: str = "bicubic", + align_corners: bool or None = False, +) -> torch.Tensor: + if mode in {"bilinear", "bicubic"}: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) + elif mode in {"nearest", "area"}: + return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) + else: + raise 
NotImplementedError(f"resize(mode={mode}) not implemented.") + + +def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]: + valid_keys = list(signature(target_func).parameters) + kwargs = {} + for key in config: + if key in valid_keys: + kwargs[key] = config[key] + return kwargs + + +def load_state_dict_from_file( + file: str, only_state_dict=True +) -> dict[str, torch.Tensor]: + file = os.path.realpath(os.path.expanduser(file)) + checkpoint = torch.load(file, map_location="cpu") + if only_state_dict and "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + return checkpoint diff --git a/src/efficientvit/models/utils/random.py b/src/efficientvit/models/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..0257f7ab93a3781c159a917823c36d8ada976292 --- /dev/null +++ b/src/efficientvit/models/utils/random.py @@ -0,0 +1,73 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torch + +__all__ = [ + "torch_randint", + "torch_random", + "torch_shuffle", + "torch_uniform", + "torch_random_choices", +] + + +def torch_randint( + low: int, high: int, generator: torch.Generator or None = None +) -> int: + """uniform: [low, high)""" + if low == high: + return low + else: + assert low < high + return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) + + +def torch_random(generator: torch.Generator or None = None) -> float: + """uniform distribution on the interval [0, 1)""" + return float(torch.rand(1, generator=generator)) + + +def torch_shuffle( + src_list: list[any], generator: torch.Generator or None = None +) -> list[any]: + rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() + return [src_list[i] for i in rand_indexes] + + +def torch_uniform( + low: float, high: float, generator: torch.Generator or None = None +) -> float: + """uniform distribution on the interval [low, high)""" + rand_val = torch_random(generator) + return (high - low) * rand_val + low + + +def torch_random_choices( + src_list: list[any], + generator: torch.Generator or None = None, + k=1, + weight_list: list[float] or None = None, +) -> any or list: + if weight_list is None: + rand_idx = torch.randint( + low=0, high=len(src_list), generator=generator, size=(k,) + ) + out_list = [src_list[i] for i in rand_idx] + else: + assert len(weight_list) == len(src_list) + accumulate_weight_list = np.cumsum(weight_list) + + out_list = [] + for _ in range(k): + val = torch_uniform(0, accumulate_weight_list[-1], generator) + active_id = 0 + for i, weight_val in enumerate(accumulate_weight_list): + active_id = i + if weight_val > val: + break + out_list.append(src_list[active_id]) + + return out_list[0] if k == 1 else out_list diff --git a/src/efficientvit/sam_model_zoo.py b/src/efficientvit/sam_model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc0c1c290a924ddf93e5dae326b6f3a5d17c7a1 --- /dev/null +++ b/src/efficientvit/sam_model_zoo.py @@ -0,0 +1,53 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from src.efficientvit.models.efficientvit import (EfficientViTSam, + efficientvit_sam_l0, + efficientvit_sam_l1, + efficientvit_sam_l2, + efficientvit_sam_xl0, + 
efficientvit_sam_xl1) +from src.efficientvit.models.nn.norm import set_norm_eps +from src.efficientvit.models.utils import load_state_dict_from_file + +__all__ = ["create_sam_model"] + + +REGISTERED_SAM_MODEL: dict[str, str] = { + "l0": "assets/checkpoints/sam/l0.pt", + "l1": "assets/checkpoints/sam/l1.pt", + "l2": "assets/checkpoints/sam/l2.pt", + "xl0": "assets/checkpoints/sam/xl0.pt", + "xl1": "assets/checkpoints/sam/xl1.pt", +} + + +def create_sam_model( + name: str, pretrained=True, weight_url: str or None = None, **kwargs +) -> EfficientViTSam: + model_dict = { + "l0": efficientvit_sam_l0, + "l1": efficientvit_sam_l1, + "l2": efficientvit_sam_l2, + "xl0": efficientvit_sam_xl0, + "xl1": efficientvit_sam_xl1, + } + + model_id = name.split("-")[0] + if model_id not in model_dict: + raise ValueError( + f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}" + ) + else: + model = model_dict[model_id](**kwargs) + set_norm_eps(model, 1e-6) + + if pretrained: + weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None) + if weight_url is None: + raise ValueError(f"Do not find the pretrained weight of {name}.") + else: + weight = load_state_dict_from_file(weight_url) + model.load_state_dict(weight) + return model diff --git a/src/ip_adapter/attention_processor.py b/src/ip_adapter/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c38f0e4886895d987ea2b4b1f0694546a7728287 --- /dev/null +++ b/src/ip_adapter/attention_processor.py @@ -0,0 +1,424 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + import xformers + import xformers.ops + + xformers_available = True +except Exception as e: + xformers_available = False + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
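+
+    A minimal usage sketch (illustrative only; assumes a recent `diffusers` release in which
+    `Attention` routes its forward pass through the installed processor):
+
+    ```py
+    >>> from diffusers.models.attention_processor import Attention
+    >>> attn = Attention(query_dim=320)
+    >>> attn.set_processor(AttnProcessor())  # subsequent attn(...) calls use this processor
+    ```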
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, + end_pos:, :] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + if xformers_available: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + if xformers_available: + ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) + else: + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + 
return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
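+
+    Illustrative sketch of installing these processors on a diffusers UNet (`unet` and the
+    `layer_channels` mapping of layer name -> channel width are assumed here; the pipelines in
+    this repo derive the per-block widths before doing the equivalent wiring):
+
+    ```py
+    >>> procs = {}
+    >>> for name, hidden_size in layer_channels.items():
+    ...     if name.endswith("attn1.processor"):   # self-attention: plain processor
+    ...         procs[name] = AttnProcessor2_0()
+    ...     else:                                   # cross-attention: text + image-prompt tokens
+    ...         procs[name] = IPAttnProcessor2_0(hidden_size=hidden_size,
+    ...                                          cross_attention_dim=unet.config.cross_attention_dim,
+    ...                                          num_tokens=4)
+    >>> unet.set_attn_processor(procs)
+    ```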
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + # print(self.attn_map.shape) + + ip_hidden_states = 
ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/src/ip_adapter/resampler.py b/src/ip_adapter/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149ffa5c031bb18a948f471a83c99e031bea14a6 --- /dev/null +++ b/src/ip_adapter/resampler.py @@ -0,0 +1,120 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] 
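+                    # one PerceiverAttention + FeedForward pair per depth level; forward() applies
+                    # both with residual connections to refine the learnable latent queries against x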
+ ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/src/ip_adapter/utils.py b/src/ip_adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a105f3701c15e8d3bbf838d79bacc51e91d0696 --- /dev/null +++ b/src/ip_adapter/utils.py @@ -0,0 +1,5 @@ +import torch.nn.functional as F + + +def is_torch2_available(): + return hasattr(F, "scaled_dot_product_attention") diff --git a/src/pipelines/instantid_pipeline.py b/src/pipelines/instantid_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..029536a35c149d06c8b2da35c6a2cbc66180a683 --- /dev/null +++ b/src/pipelines/instantid_pipeline.py @@ -0,0 +1,720 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers import StableDiffusionXLControlNetPipeline +from PIL import Image +from torchvision.transforms.functional import to_tensor +from einops import rearrange +from torch import einsum +import math +from torchvision.utils import save_image +from diffusers.utils import load_image +import cv2 + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RegionControlNet_AttnProcessor: + def __init__(self, attention_op=None, controller=None, place_in_unet=None): + self.attention_op = attention_op + self.controller = controller + self.place_in_unet = place_in_unet + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + **cross_attention_kwargs + ) -> torch.Tensor: + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if 
attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + is_cross = True + if encoder_hidden_states is None: + is_cross = False + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet) + hidden_states = torch.bmm(attention_probs, value) + + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +def revise_regionally_controlnet_forward(unet, controller): + def change_forward(unet, count, place_in_unet): + for name, layer in unet.named_children(): + if layer.__class__.__name__ == 'Attention': + layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet)) + if 'attn2' in name: + count += 1 + else: + count = change_forward(layer, count, place_in_unet) + return count + + # use this to ensure the order + cross_attention_idx = change_forward(unet.down_blocks, 0, "down") + cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up") + cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid") + print(f'Number of attention layer registered {cross_attention_idx}') + controller.num_att_layers = cross_attention_idx*2 + +class InstantidMultiConceptPipeline(StableDiffusionXLControlNetPipeline): + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: 
CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + controller=None, + concept_models=None, + indices_to_alter=None, + face_app=None, + stage=None, + region_masks=None, + **kwargs, + ): + # revise_regionally_controlnet_forward(self.unet, controller) + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if 
is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + batch_size = 2 + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + global_prompt = prompt[0] + global_negative_prompt = negative_prompt + region_prompts = [pt[0] for pt in prompt[1]] + region_negative_prompts = [pt[1] for pt in prompt[1]] + ref_images = [pt[2] for pt in prompt[1]] + + concat_prompts = global_prompt + region_prompts + concat_negative_prompts = global_negative_prompt + region_negative_prompts + + ( + concat_prompt_embeds, + concat_negative_prompt_embeds, + concat_pooled_prompt_embeds, + concat_negative_pooled_prompt_embeds, + ) = self.encode_prompt( + concat_prompts, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + concat_negative_prompts, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + prompt_embeds = concat_prompt_embeds[:2] + negative_prompt_embeds = concat_negative_prompt_embeds[:2] + pooled_prompt_embeds = concat_pooled_prompt_embeds[:2] + negative_pooled_prompt_embeds = concat_negative_pooled_prompt_embeds[:2] + + region_prompt_embeds_list = [] + region_add_text_embeds_list = [] + for region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds in zip(concat_prompt_embeds[2:], concat_negative_prompt_embeds[2:], concat_pooled_prompt_embeds[2:], concat_negative_pooled_prompt_embeds[2:]): + region_prompt_embeds_list.append( + torch.concat([region_negative_prompt_embeds.unsqueeze(0), 
region_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device)) + region_add_text_embeds_list.append( + torch.concat([region_negative_pooled_prompt_embeds.unsqueeze(0), region_pooled_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device)) + + + if stage==2: + mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks] + image_embedding_list = get_face_embedding(face_app, ref_images) + image_prompt_image_emb_list = [] + for image_embeds in image_embedding_list: + prompt_image_emb = concept_models._encode_prompt_image_emb(image_embeds, + concept_models._execution_device, + num_images_per_prompt, + concept_models.unet.dtype, + True) + image_prompt_image_emb_list.append(prompt_image_emb) + + + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel) and image is not None: + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=1 * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel) and image is not None: + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size//2 * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 repeat latent + latents = torch.cat([latents, latents.clone()]) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + add_time_ids_list = [] + region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim) + for _ in range(len(prompt[1])): + add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device)) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + # hyper-parameters + scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps)) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
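+                    # In guess mode only the conditional inputs reach ControlNet: the un-duplicated
+                    # latents plus the conditional half of the prompt/added embeddings; the missing
+                    # unconditional residuals are zero-filled after the ControlNet call.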
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if i > 15 and stage == 2: + region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3]) + edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0) + new_noise_pred = torch.zeros_like(edit_noise) + new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0] + replace_ratio = 1.0 + new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0] + + for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, region_prompt_image_emb in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, image_prompt_image_emb_list): + if concept_mask is not None: + concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0), + size=(noise_pred.shape[2], noise_pred.shape[3]), + mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device) + + region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device) + + region_latent_model_input = torch.cat([region_latent_model_input] * 2) + region_added_cond_kwargs = {"text_embeds": region_add_text_embeds, + "time_ids": region_add_time_ids} + + if image is not None: + down_block_res_samples, mid_block_res_sample = self.controlnet( + region_latent_model_input, + t, + encoder_hidden_states=region_prompt_image_emb, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=region_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
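+                                # zeros are prepended because the CFG batch is laid out as
+                                # [unconditional, conditional] throughout this pipeline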
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in + down_block_res_samples] + mid_block_res_sample = torch.cat( + [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + else: + down_block_res_samples = None + mid_block_res_sample = None + + region_encoder_hidden_states = torch.cat([region_prompt_embeds, region_prompt_image_emb], dim=1) + + region_noise_pred = concept_models.unet( + region_latent_model_input, + t, + encoder_hidden_states=region_encoder_hidden_states, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=region_added_cond_kwargs, + return_dict=False, + )[0] + + + new_noise_pred = new_noise_pred.to(concept_models._execution_device) + new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device))) + + + new_noise_pred = new_noise_pred.to(noise_pred.device) + noise_pred[1, :, :, :] = new_noise_pred[0] + noise_pred[3, :, :, :] = new_noise_pred[1] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + def check_image(self, image, prompt, prompt_embeds): + pass + + def get_region_mask(self, mask_list, feat_height, feat_width): + exclusive_mask = torch.zeros((feat_height, feat_width)) + for 
mask in mask_list: + if mask is not None: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width), + mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device) + exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype) + return exclusive_mask + +def get_face_embedding(face_app, ref_images): + emb_list = [] + for img_path in ref_images: + face_image = load_image(img_path) + + # prepare face emb + face_info = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * x['bbox'][3] - x['bbox'][1])[0] # only use the maximum face + face_emb = face_info['embedding'] + emb_list.append(face_emb) + # face_kps = draw_kps(face_image, face_info['kps']) + return emb_list \ No newline at end of file diff --git a/src/pipelines/instantid_single_pieline.py b/src/pipelines/instantid_single_pieline.py new file mode 100644 index 0000000000000000000000000000000000000000..133944e2008ae327a85b3a5d9e8244042a4e82e8 --- /dev/null +++ b/src/pipelines/instantid_single_pieline.py @@ -0,0 +1,772 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import math + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F + +from diffusers.image_processor import PipelineImageInput + +from diffusers.models import ControlNetModel + +from diffusers.utils import ( + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from diffusers import StableDiffusionXLControlNetPipeline +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.utils.import_utils import is_xformers_available + +from src.ip_adapter.resampler import Resampler +from src.ip_adapter.utils import is_torch2_available + +if is_torch2_available(): + from src.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor +else: + from src.ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate insightface + >>> import diffusers + >>> from diffusers.utils import load_image + >>> from diffusers.models import ControlNetModel + + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from insightface.app import FaceAnalysis + >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + + >>> # download 'antelopev2' under ./models + >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 
'CPUExecutionProvider']) + >>> app.prepare(ctx_id=0, det_size=(640, 640)) + + >>> # download models under ./checkpoints + >>> face_adapter = f'./checkpoints/ip-adapter.bin' + >>> controlnet_path = f'./checkpoints/ControlNetModel' + + >>> # load IdentityNet + >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.cuda() + + >>> # load adapter + >>> pipe.load_ip_adapter_instantid(face_adapter) + + >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + >>> # load an image + >>> image = load_image("your-example.jpg") + + >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] + >>> face_emb = face_info['embedding'] + >>> face_kps = draw_kps(face_image, face_info['kps']) + + >>> pipe.set_ip_adapter_scale(0.8) + + >>> # generate image + >>> image = pipe( + ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 + ... ).images[0] + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, + 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +class InstantidSingleConceptPipeline(StableDiffusionXLControlNetPipeline): + + def cuda(self, dtype=torch.float16, use_xformers=False): + self.to('cuda', dtype) + + if hasattr(self, 'image_proj_model'): + self.image_proj_model.to(self.unet.device).to(self.unet.dtype) + + if use_xformers: + if is_xformers_available(): + import xformers + from packaging import version + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): + self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) + self.set_ip_adapter(model_ckpt, num_tokens, scale) + + def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): + + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.unet.config.cross_attention_dim, + ff_mult=4, + ) + + image_proj_model.eval() + + self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) + state_dict = torch.load(model_ckpt, map_location="cpu") + if 'image_proj' in state_dict: + state_dict = state_dict["image_proj"] + self.image_proj_model.load_state_dict(state_dict) + + self.image_proj_model_in_features = image_emb_dim + + def set_ip_adapter(self, model_ckpt, num_tokens, scale): + + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) + else: + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=scale, + num_tokens=num_tokens).to(unet.device, dtype=unet.dtype) + unet.set_attn_processor(attn_procs) + + state_dict = torch.load(model_ckpt, map_location="cpu") + ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + if 'ip_adapter' in state_dict: + state_dict = state_dict['ip_adapter'] + ip_layers.load_state_dict(state_dict) + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, + do_classifier_free_guidance): + + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype) + prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) + + if do_classifier_free_guidance: + prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) + else: + prompt_image_emb = torch.cat([prompt_image_emb], dim=0) + + prompt_image_emb = self.image_proj_model(prompt_image_emb) + + bs_embed, seq_len, _ = prompt_image_emb.shape + prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1) + prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_image_emb + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = 
None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + + # IP adapter + ip_adapter_scale=None, + + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 
+ clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. set ip_adapter_scale + if ip_adapter_scale is not None: + self.set_ip_adapter_scale(ip_adapter_scale) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode image prompt + prompt_image_emb = self._encode_prompt_image_emb(image_embeds, + device, + num_images_per_prompt, + self.unet.dtype, + self.do_classifier_free_guidance) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
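+                    # Under guess mode with classifier-free guidance, ControlNet only sees the
+                    # conditional half of the batch: the un-duplicated `latents` are rescaled here
+                    # and the text/added conditioning is taken from the conditional chunk; zero
+                    # residuals are concatenated for the unconditional half after the ControlNet call.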
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=prompt_image_emb, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + 
image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/pipelines/lora_pipeline.py b/src/pipelines/lora_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b24b1e64e6696897f365761c6037126ffd225395 --- /dev/null +++ b/src/pipelines/lora_pipeline.py @@ -0,0 +1,681 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers import StableDiffusionXLControlNetPipeline +from PIL import Image +from torchvision.transforms.functional import to_tensor +from einops import rearrange +from torch import einsum +import math +from torchvision.utils import save_image + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RegionControlNet_AttnProcessor: + def __init__(self, attention_op=None, controller=None, place_in_unet=None): + self.attention_op = attention_op + self.controller = controller + self.place_in_unet = place_in_unet + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + **cross_attention_kwargs + ) -> torch.Tensor: + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + 
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + is_cross = True + if encoder_hidden_states is None: + is_cross = False + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet) + hidden_states = torch.bmm(attention_probs, value) + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +def revise_regionally_controlnet_forward(unet, controller): + def change_forward(unet, count, place_in_unet): + for name, layer in unet.named_children(): + if layer.__class__.__name__ == 'Attention': + layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet)) + if 'attn2' in name: + count += 1 + else: + count = change_forward(layer, count, place_in_unet) + return count + + # use this to ensure the order + cross_attention_idx = change_forward(unet.down_blocks, 0, "down") + cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up") + cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid") + print(f'Number of attention layer registered {cross_attention_idx}') + controller.num_att_layers = cross_attention_idx*2 + +class LoraMultiConceptPipeline(StableDiffusionXLControlNetPipeline): + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None + ): + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + 
controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + controller=None, + concept_models=None, + stage=None, + region_masks=None, + lora_list=None, + styleL=None, + **kwargs, + ): + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * 
[control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + batch_size = 2 + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + global_prompt = prompt[0] + global_negative_prompt = negative_prompt + region_prompts = [pt[0] for pt in prompt[1]] + region_negative_prompts = [pt[1] for pt in prompt[1]] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + global_prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + global_negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + region_prompt_embeds_list = [] + region_add_text_embeds_list = [] + for lora_param, region_prompt, region_negative_prompt in zip(lora_list, region_prompts, region_negative_prompts): + if styleL: + concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5]) + else: + concept_models.set_adapters(lora_param) + region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds = concept_models.encode_prompt( + prompt=region_prompt, device=concept_models._execution_device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=region_negative_prompt, lora_scale=text_encoder_lora_scale + ) + region_prompt_embeds_list.append(torch.concat([region_negative_prompt_embeds, region_prompt_embeds], dim=0).to(concept_models._execution_device)) + region_add_text_embeds_list.append(torch.concat([region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds], dim=0).to(concept_models._execution_device)) + + if stage==2: + mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks] + + # 4. 
Prepare image + if isinstance(controlnet, ControlNetModel) and image is not None: + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel) and image is not None: + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size//2 * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 repeat latent + latents = torch.cat([latents, latents.clone()]) + + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + add_time_ids_list = [] + for _ in lora_list: + region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim) + add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device)) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + # hyper-parameters + scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps)) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
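+                    # Same conditional-batch handling as in the single-concept pipeline. Note that in
+                    # this multi-concept pipeline `image` may be None (no spatial condition), in which
+                    # case the ControlNet branch below is skipped and the residuals stay None.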
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + if image is not None: + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + else: + down_block_res_samples = None + mid_block_res_sample = None + + + + # predict the noise residual + if image is not None: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if i > 15 and stage == 2: + region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3]) + edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0) + new_noise_pred = torch.zeros_like(edit_noise) + new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0] + replace_ratio = 1.0 + new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0] + + for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, lora_param in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, lora_list): + if concept_mask is not None: + concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0), + size=(noise_pred.shape[2], noise_pred.shape[3]), + mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device) + + + region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device) + + region_latent_model_input = torch.cat([region_latent_model_input] * 2) + region_added_cond_kwargs = {"text_embeds": region_add_text_embeds, + "time_ids": region_add_time_ids} + if styleL: + 
concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5]) + else: + concept_models.set_adapters(lora_param) + region_noise_pred = concept_models.unet( + region_latent_model_input, + t, + encoder_hidden_states=region_prompt_embeds, + cross_attention_kwargs={'scale': 0.8}, + added_cond_kwargs=region_added_cond_kwargs, + return_dict=False, + )[0] + + new_noise_pred = new_noise_pred.to(concept_models._execution_device) + new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device))) + + + new_noise_pred = new_noise_pred.to(noise_pred.device) + noise_pred[1, :, :, :] = new_noise_pred[0] + noise_pred[3, :, :, :] = new_noise_pred[1] + + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if stage==2: + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + def check_image(self, image, prompt, prompt_embeds): + pass + + def get_region_mask(self, mask_list, feat_height, feat_width): + exclusive_mask = torch.zeros((feat_height, feat_width)) + for mask in mask_list: + if mask is not None: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width), + mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device) + exclusive_mask = ((mask == 1) | 
(exclusive_mask == 1)).to(dtype=mask.dtype) + return exclusive_mask diff --git a/src/prompt_attention/p2p_attention.py b/src/prompt_attention/p2p_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..afb92d2ed66d8f8b16eb9b70ace242bb47198693 --- /dev/null +++ b/src/prompt_attention/p2p_attention.py @@ -0,0 +1,148 @@ +from typing import Optional, Union, Tuple, List, Callable, Dict +import torch +import torch.nn.functional as nnf +import numpy as np +import abc +import src.prompt_attention.p2p_utils as p2p_utils +import src.prompt_attention.seq_aligner as seq_aligner + + + +class AttentionControl(abc.ABC): + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + # return self.num_att_layers if self.low_resource else 0 + return 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + if self.low_resource: + attn = self.forward(attn, is_cross, place_in_unet) + else: + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self, low_resource=False, width=None, height=None): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.low_resource = low_resource + self.width = width + self.height = height + +class AttentionStore(AttentionControl): + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + # if attn.shape[1] <= att_size * 64: + return attn + + def between_steps(self): + if self.save_global_store: + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + else: + self.attention_store = self.step_store + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in + self.attention_store} + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, width, height, low_resolution=False, save_global_store=False): + super(AttentionStore, self).__init__(low_resolution, width, height) + self.step_store = self.get_empty_store() + self.attention_store = {} + self.save_global_store = save_global_store + +class AttentionControlEdit(AttentionStore, abc.ABC): + def __init__(self, prompts, num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend=None, width=None, height=None, tokenizer=None, device=None): + super(AttentionControlEdit, self).__init__(width, height) + self.batch_size = len(prompts) + self.cross_replace_alpha = 
p2p_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, + tokenizer).to(device) + if type(self_replace_steps) is float: + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend + + def step_callback(self, x_t): + print("step_callback") + if self.local_blend is not None: + x_t = self.local_blend(x_t, self.attention_store) + return x_t + + def replace_self_attention(self, attn_base, att_replace): + if att_replace.shape[2] <= self.width * self.height: + return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + else: + return att_replace + + @abc.abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + h = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) + attn_base, attn_repalce = attn[0], attn[1:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + ( + 1 - alpha_words) * attn_repalce + attn[1:] = attn_repalce_new + else: + attn[1:] = self.replace_self_attention(attn_base, attn_repalce) + attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) + return attn + +class AttentionReplace(AttentionControlEdit): + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, width, height, + local_blend = None, tokenizer=None, device=None, dtype=None): + super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, width, height, tokenizer=tokenizer, device=device) + self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(dtype=dtype, device=device) + + def replace_cross_attention(self, attn_base, att_replace): + return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) + diff --git a/src/prompt_attention/p2p_utils.py b/src/prompt_attention/p2p_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13a6dd4a1443378689e89a03de49981a131571b3 --- /dev/null +++ b/src/prompt_attention/p2p_utils.py @@ -0,0 +1,74 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
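+
+# Utilities for prompt-to-prompt attention control:
+#   * get_word_inds maps a word in a prompt (given as a string or word position) to the
+#     indices of the tokens the tokenizer produces for it (shifted by 1 for the BOS token).
+#   * update_alpha_time_word / get_time_words_attention_alpha build the per-step, per-token
+#     alpha schedule of shape (num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) that
+#     AttentionControlEdit uses to decide when cross-attention maps are replaced.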
+
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+import cv2
+from typing import Optional, Union, Tuple, List, Callable, Dict
+
+
+
+def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
+                           word_inds: Optional[torch.Tensor] = None):
+    if type(bounds) is float:
+        bounds = 0, bounds
+    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+    if word_inds is None:
+        word_inds = torch.arange(alpha.shape[2])
+    alpha[: start, prompt_ind, word_inds] = 0
+    alpha[start: end, prompt_ind, word_inds] = 1
+    alpha[end:, prompt_ind, word_inds] = 0
+    return alpha
+
+def get_word_inds(text: str, word_place: int, tokenizer):
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
+def get_time_words_attention_alpha(prompts, num_steps,
+                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+                                   tokenizer, max_num_words=77):
+    if type(cross_replace_steps) is not dict:
+        cross_replace_steps = {"default_": cross_replace_steps}
+    if "default_" not in cross_replace_steps:
+        cross_replace_steps["default_"] = (0., 1.)
+    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+    for i in range(len(prompts) - 1):
+        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                  i)
+    for key, item in cross_replace_steps.items():
+        if key != "default_":
+            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+            for i, ind in enumerate(inds):
+                if len(ind) > 0:
+                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+    return alpha_time_words
+
diff --git a/src/prompt_attention/seq_aligner.py b/src/prompt_attention/seq_aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4530d05595a25f395693ef814e49a49b1581ce12
--- /dev/null
+++ b/src/prompt_attention/seq_aligner.py
@@ -0,0 +1,66 @@
+import torch
+import numpy as np
+
+
+def get_word_inds(text: str, word_place: int, tokenizer):
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+    words_x = x.split(' ')
+    words_y = y.split(' ')
+    if len(words_x) != len(words_y):
+        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
+                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
+    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+    mapper = np.zeros((max_len, max_len))
+    i = j = 0
+    cur_inds = 0
+    while i < max_len and j < max_len:
+        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+            if len(inds_source_) == len(inds_target_):
+                mapper[inds_source_, inds_target_] = 1
+            else:
+                ratio = 1 / len(inds_target_)
+                for i_t in inds_target_:
+                    mapper[inds_source_, i_t] = ratio
+            cur_inds += 1
+            i += len(inds_source_)
+            j += len(inds_target_)
+        elif cur_inds < len(inds_source):
+            mapper[i, j] = 1
+            i += 1
+            j += 1
+        else:
+            mapper[j, j] = 1
+            i += 1
+            j += 1
+
+    return torch.from_numpy(mapper).float()
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+    x_seq = prompts[0]
+    mappers = []
+    for i in range(1, len(prompts)):
+        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+    return torch.stack(mappers)
\ No newline at end of file
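
For orientation, here is a minimal sketch (not part of this diff) of how an `AttentionReplace` controller could be constructed with the signatures added above. The tokenizer choice, prompt pair, step counts, and width/height values are illustrative assumptions; the only hard requirement visible in `seq_aligner.py` is that the prompts contain the same number of words.

```python
import torch
from transformers import CLIPTokenizer  # assumed tokenizer; anything with encode()/decode() works

from src.prompt_attention.p2p_attention import AttentionReplace

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# prompts must have equal word counts, otherwise get_replacement_mapper_ raises ValueError
prompts = ["a photo of a man and a woman",
           "a photo of a king and a queen"]

controller = AttentionReplace(
    prompts,
    num_steps=50,             # should match the sampler's number of inference steps
    cross_replace_steps=0.8,  # replace cross-attention during the first 80% of steps
    self_replace_steps=0.4,   # replace self-attention during the first 40% of steps
    width=64, height=64,      # resolution threshold used by replace_self_attention (illustrative)
    tokenizer=tokenizer,
    device=device,
    dtype=torch.float32,
)

# The diffusion pipeline is then expected to call the controller on every attention map,
#     attn = controller(attn, is_cross=is_cross, place_in_unet="up")
# after setting controller.num_att_layers (initialised to -1 above) to the number of
# attention layers it registers.
```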