diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..69d497db4880f8e7fe27313b35b571d2f5952c30 Binary files /dev/null and b/.DS_Store differ diff --git a/ImageConductor_app.py b/ImageConductor_app.py new file mode 100644 index 0000000000000000000000000000000000000000..83bc9f07e920ffa8a3ccd6d2f23512daca0c7376 --- /dev/null +++ b/ImageConductor_app.py @@ -0,0 +1,586 @@ +import os +import gradio as gr +import numpy as np +import cv2 +import uuid +import torch +import torchvision +import json + +from PIL import Image +from omegaconf import OmegaConf +from einops import rearrange, repeat +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL, DDIMScheduler + +from pipelines.pipeline_imagecoductor import ImageConductorPipeline +from modules.unet import UNet3DConditionFlowModel +from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil, image2arr +from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian +from utils.lora_utils import add_LoRA_to_controlnet +from utils.visualizer import Visualizer, vis_flow_to_video +#### Description #### +title = r"""

Image Conductor: Precision Control for Interactive Video Synthesis

""" + +head = r""" +
+

Image Conductor: Precision Control for Interactive Video Synthesis

+
+ + Project Page + + + + +
+
+
+""" + + + +descriptions = r""" +Official Gradio Demo for Image Conductor: Precision Control for Interactive Video Synthesis.
+🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.
+""" + + +instructions = r""" + - ⭐️ step1: Upload or select one image from Example. + - ⭐️ step2: Click 'Add Drag' to draw some drags. + - ⭐️ step3: Input text prompt that complements the image (Necessary). + - ⭐️ step4: Select 'Drag Mode' to specify the control of camera transition or object movement. + - ⭐️ step5: Click 'Run' button to generate video assets. + - ⭐️ others: Click 'Delete last drag' to delete the whole lastest path. Click 'Delete last step' to delete the lastest clicked control point. + """ + +citation = r""" +If Image Conductor is helpful, please help to ⭐ the Github Repo. Thanks! +[![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor) +--- + +📝 **Citation** +
+If our work is useful for your research, please consider citing: +```bibtex +@misc{li2024imageconductor, + title={Image Conductor: Precision Control for Interactive Video Synthesis}, + author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying}, + year={2024}, + eprint={2406.15339}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +📧 **Contact** +
+If you have any questions, please feel free to reach me out at ywl@stu.pku.edu.cn. + +# """ + +os.makedirs("models/personalized") +os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/flow_controlnet.ckpt -P models/') +os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/image_controlnet.ckpt -P models/') +os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/unet.ckpt -P models/') +os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/helloobjects_V12c.safetensors -P models/personalized') +os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/TUSUN.safetensors -P models/personalized') + + + + +# - - - - - examples - - - - - # +image_examples = [ + ["__asset__/images/object/turtle-1.jpg", + "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.", + "object", + 11318446767408804497, + "", + json.load(open("__asset__/trajs/object/turtle-1.json")), + "__asset__/images/object/turtle-1.jpg", + ], + + ["__asset__/images/object/rose-1.jpg", + "a red rose engulfed in flames.", + "object", + 6854275249656120509, + "", + json.load(open("__asset__/trajs/object/rose-1.json")), + "__asset__/images/object/rose-1.jpg", + ], + + ["__asset__/images/object/jellyfish-1.jpg", + "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.", + "object", + 17966188172968903484, + "HelloObject", + json.load(open("__asset__/trajs/object/jellyfish-1.json")), + "__asset__/images/object/jellyfish-1.jpg", + ], + + + ["__asset__/images/camera/lush-1.jpg", + "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.", + "camera", + 7970487946960948963, + "HelloObject", + json.load(open("__asset__/trajs/camera/lush-1.json")), + "__asset__/images/camera/lush-1.jpg", + ], + + ["__asset__/images/camera/tusun-1.jpg", + "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.", + "camera", + 996953226890228361, + "TUSUN", + json.load(open("__asset__/trajs/camera/tusun-1.json")), + "__asset__/images/camera/tusun-1.jpg", + ], + + ["__asset__/images/camera/painting-1.jpg", + "A oil painting.", + "camera", + 16867854766769816385, + "", + json.load(open("__asset__/trajs/camera/painting-1.json")), + "__asset__/images/camera/painting-1.jpg", + ], + +] + + +DREAM_BOOTH = { + 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors', +} + +LORA = { + 'TUSUN': 'models/personalized/TUSUN.safetensors', +} + +LORA_ALPHA = { + 'TUSUN': 0.6, +} + +NPROMPT = { + "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)' +} + +output_dir = "outputs" 
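+# All demo artifacts are written under this directory: the preprocessed first frames
+# (first_frame_*.jpg), the clicked trajectories (points-*.json), the visualized control
+# flows (control_flows/*.mp4), and the generated videos (output_*.mp4).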
+ensure_dirname(output_dir) + +def points_to_flows(track_points, model_length, height, width): + input_drag = np.zeros((model_length - 1, height, width, 2)) + for splited_track in track_points: + if len(splited_track) == 1: # stationary point + displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1]) + splited_track = tuple([splited_track[0], displacement_point]) + # interpolate the track + splited_track = interpolate_trajectory(splited_track, model_length) + splited_track = splited_track[:model_length] + if len(splited_track) < model_length: + splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track)) + for i in range(model_length - 1): + start_point = splited_track[i] + end_point = splited_track[i+1] + input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0] + input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1] + return input_drag + +class ImageConductor: + def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64): + self.device = device + tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").cuda() + vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").cuda() + inference_config = OmegaConf.load("configs/inference/inference.yaml") + unet = UNet3DConditionFlowModel.from_pretrained_2d("runwayml/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + + self.vae = vae + + ### >>> Initialize UNet module >>> ### + load_model(unet, unet_path) + + ### >>> Initialize image controlnet module >>> ### + image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet) + load_model(image_controlnet, image_controlnet_path) + ### >>> Initialize flow controlnet module >>> ### + flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet) + add_LoRA_to_controlnet(lora_rank, flow_controlnet) + load_model(flow_controlnet, flow_controlnet_path) + + unet.eval().to(device) + image_controlnet.eval().to(device) + flow_controlnet.eval().to(device) + + self.pipeline = ImageConductorPipeline( + unet=unet, + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + image_controlnet=image_controlnet, + flow_controlnet=flow_controlnet, + ).to(device) + + + self.height = height + self.width = width + # _, model_step, _ = split_filename(model_path) + # self.ouput_prefix = f'{model_step}_{width}X{height}' + self.model_length = model_length + + blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True) + + self.blur_kernel = blur_kernel + + @torch.no_grad() + def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized): + + + original_width, original_height=384, 256 + if isinstance(tracking_points, list): + input_all_points = tracking_points + else: + input_all_points = tracking_points.constructor_args['value'] + + + resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] + + dir, 
base, ext = split_filename(first_frame_path) + id = base.split('_')[-1] + + + with open(f'{output_dir}/points-{id}.json', 'w') as f: + json.dump(input_all_points, f) + + + visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length) + + ## image condition + image_transforms = transforms.Compose([ + transforms.RandomResizedCrop( + (self.height, self.width), (1.0, 1.0), + ratio=(self.width/self.height, self.width/self.height) + ), + transforms.ToTensor(), + ]) + + image_norm = lambda x: x + image_paths = [first_frame_path] + controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths] + controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda() + controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") + num_controlnet_images = controlnet_images.shape[2] + controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") + controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215 + controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) + + # flow condition + controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width) + for i in range(0, self.model_length-1): + controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel) + controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow + os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True) + trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3 + torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'}) + controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...] 
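+        # At this point controlnet_flows has shape (1, model_length, H, W, 2); the rearrange
+        # below moves the two flow channels forward to (1, 2, model_length, H, W), i.e. the
+        # "b c f h w" layout the flow controlnet consumes (cf. the [1, 2, 16, 256, 384] note below).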
+ controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w") + + dreambooth_model_path = DREAM_BOOTH.get(personalized, '') + lora_model_path = LORA.get(personalized, '') + lora_alpha = LORA_ALPHA.get(personalized, 0.6) + self.pipeline = load_weights( + self.pipeline, + dreambooth_model_path = dreambooth_model_path, + lora_model_path = lora_model_path, + lora_alpha = lora_alpha, + ).to(device) + + if NPROMPT.get(personalized, '') != '': + negative_prompt = NPROMPT.get(personalized) + + if randomize_seed: + random_seed = torch.seed() + else: + seed = int(seed) + random_seed = seed + torch.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + print(f"current seed: {torch.initial_seed()}") + sample = self.pipeline( + prompt, + negative_prompt = negative_prompt, + num_inference_steps = num_inference_steps, + guidance_scale = guidance_scale, + width = self.width, + height = self.height, + video_length = self.model_length, + controlnet_images = controlnet_images, # 1 4 1 32 48 + controlnet_image_index = [0], + controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384] + control_mode = drag_mode, + eval_mode = True, + ).videos + + outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4') + vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255) + torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'}) + + return visualized_drag, outputs_path + + +def reset_states(first_frame_path, tracking_points): + first_frame_path = gr.State() + tracking_points = gr.State([]) + return None, first_frame_path, tracking_points + + +def preprocess_image(image): + image_pil = image2pil(image.name) + raw_w, raw_h = image_pil.size + resize_ratio = max(384/raw_w, 256/raw_h) + image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR) + image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB')) + id = str(uuid.uuid4())[:4] + first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg") + image_pil.save(first_frame_path, quality=95) + return first_frame_path, first_frame_path, gr.State([]) + + +def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + + + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + tracking_points.constructor_args['value'][-1].append(evt.index) + print(tracking_points.constructor_args) + + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +def add_drag(tracking_points): + 
tracking_points.constructor_args['value'].append([]) + print(tracking_points.constructor_args) + return tracking_points + + +def delete_last_drag(tracking_points, first_frame_path, drag_mode): + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + tracking_points.constructor_args['value'].pop() + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +def delete_last_step(tracking_points, first_frame_path, drag_mode): + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + tracking_points.constructor_args['value'][-1].pop() + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +block = gr.Blocks( + theme=gr.themes.Soft( + radius_size=gr.themes.sizes.radius_none, + text_size=gr.themes.sizes.text_md + ) + ).queue() +with block as demo: + with gr.Row(): + with gr.Column(): + gr.HTML(head) + + gr.Markdown(descriptions) + + with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"): + with gr.Row(equal_height=True): + gr.Markdown(instructions) + + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + unet_path = 'models/unet.ckpt' + image_controlnet_path = 'models/image_controlnet.ckpt' + flow_controlnet_path = 'models/flow_controlnet.ckpt' + ImageConductor_net = ImageConductor(device=device, + unet_path=unet_path, + image_controlnet_path=image_controlnet_path, + flow_controlnet_path=flow_controlnet_path, + height=256, + width=384, + model_length=16 + ) + first_frame_path = gr.State() + tracking_points = gr.State([]) + + + with gr.Row(): + with gr.Column(scale=1): + image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"]) + add_drag_button = gr.Button(value="Add Drag") + reset_button = 
gr.Button(value="Reset") + delete_last_drag_button = gr.Button(value="Delete last drag") + delete_last_step_button = gr.Button(value="Delete last step") + + + + with gr.Column(scale=7): + with gr.Row(): + with gr.Column(scale=6): + input_image = gr.Image(label=None, + interactive=True, + height=256, + width=384,) + with gr.Column(scale=6): + output_image = gr.Image(label="Motion Path", + interactive=False, + height=256, + width=384,) + with gr.Row(): + with gr.Column(scale=1): + prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True) + negative_prompt = gr.Text( + label="Negative Prompt", + max_lines=5, + placeholder="Please input your negative prompt", + value='worst quality, low quality, letterboxed',lines=1 + ) + drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2) + run_button = gr.Button(value="Run") + + with gr.Accordion("More input params", open=False, elem_id="accordion1"): + with gr.Group(): + seed = gr.Textbox( + label="Seed: ", value=561793204, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=False) + + with gr.Group(): + with gr.Row(): + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1, + maximum=12, + step=0.1, + value=8.5, + ) + num_inference_steps = gr.Slider( + label="Number of inference steps", + minimum=1, + maximum=50, + step=1, + value=25, + ) + + with gr.Group(): + personalized = gr.Dropdown(label="Personalized template", choices=['HelloObject', 'TUSUN'], value="") + + with gr.Column(scale=7): + output_video = gr.Video(value=None, + label="Output Video", + width=384, + height=256) + + + with gr.Row(): + def process_example(input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path): + + return input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path + + example = gr.Examples( + label="Input Example", + examples=image_examples, + inputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path], + outputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path], + fn=process_example, + run_on_click=True, + examples_per_page=10 + ) + + with gr.Row(): + gr.Markdown(citation) + + + image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points]) + + add_drag_button.click(add_drag, tracking_points, tracking_points) + + delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points]) + + input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode, + negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized], + [output_image, output_video]) + +demo.launch(server_name="0.0.0.0", debug=True, server_port=12345) diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1509de5a28f48a3b551ceb3dbc1fffbc5c2e59 --- /dev/null +++ b/app.py @@ -0,0 +1,577 @@ +import os +import gradio as gr +import numpy as np +import cv2 +import uuid +import torch +import 
torchvision +import json + +from PIL import Image +from omegaconf import OmegaConf +from einops import rearrange, repeat +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL, DDIMScheduler + +from pipelines.pipeline_imagecoductor import ImageConductorPipeline +from modules.unet import UNet3DConditionFlowModel +from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil, image2arr +from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian +from utils.lora_utils import add_LoRA_to_controlnet +from utils.visualizer import Visualizer, vis_flow_to_video +#### Description #### +title = r"""

Image Conductor: Precision Control for Interactive Video Synthesis

""" + +head = r""" +
+

Image Conductor: Precision Control for Interactive Video Synthesis

+
+ + Project Page + + + + +
+
+
+""" + + + +descriptions = r""" +Official Gradio Demo for Image Conductor: Precision Control for Interactive Video Synthesis.
+🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.
+""" + + +instructions = r""" + - ⭐️ step1: Upload or select one image from Example. + - ⭐️ step2: Click 'Add Drag' to draw some drags. + - ⭐️ step3: Input text prompt that complements the image (Necessary). + - ⭐️ step4: Select 'Drag Mode' to specify the control of camera transition or object movement. + - ⭐️ step5: Click 'Run' button to generate video assets. + - ⭐️ others: Click 'Delete last drag' to delete the whole lastest path. Click 'Delete last step' to delete the lastest clicked control point. + """ + +citation = r""" +If Image Conductor is helpful, please help to ⭐ the Github Repo. Thanks! +[![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor) +--- + +📝 **Citation** +
+If our work is useful for your research, please consider citing: +```bibtex +@misc{li2024imageconductor, + title={Image Conductor: Precision Control for Interactive Video Synthesis}, + author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying}, + year={2024}, + eprint={2406.15339}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +📧 **Contact** +
+If you have any questions, please feel free to reach me out at ywl@stu.pku.edu.cn. + +# """ + + +# - - - - - examples - - - - - # +image_examples = [ + ["__asset__/images/object/turtle-1.jpg", + "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.", + "object", + 11318446767408804497, + "", + json.load(open("__asset__/trajs/object/turtle-1.json")), + "__asset__/images/object/turtle-1.jpg", + ], + + ["__asset__/images/object/rose-1.jpg", + "a red rose engulfed in flames.", + "object", + 6854275249656120509, + "", + json.load(open("__asset__/trajs/object/rose-1.json")), + "__asset__/images/object/rose-1.jpg", + ], + + ["__asset__/images/object/jellyfish-1.jpg", + "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.", + "object", + 17966188172968903484, + "HelloObject", + json.load(open("__asset__/trajs/object/jellyfish-1.json")), + "__asset__/images/object/jellyfish-1.jpg", + ], + + + ["__asset__/images/camera/lush-1.jpg", + "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.", + "camera", + 7970487946960948963, + "HelloObject", + json.load(open("__asset__/trajs/camera/lush-1.json")), + "__asset__/images/camera/lush-1.jpg", + ], + + ["__asset__/images/camera/tusun-1.jpg", + "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.", + "camera", + 996953226890228361, + "TUSUN", + json.load(open("__asset__/trajs/camera/tusun-1.json")), + "__asset__/images/camera/tusun-1.jpg", + ], + + ["__asset__/images/camera/painting-1.jpg", + "A oil painting.", + "camera", + 16867854766769816385, + "", + json.load(open("__asset__/trajs/camera/painting-1.json")), + "__asset__/images/camera/painting-1.jpg", + ], + +] + + +DREAM_BOOTH = { + 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors', +} + +LORA = { + 'TUSUN': 'models/personalized/TUSUN.safetensors', +} + +LORA_ALPHA = { + 'TUSUN': 0.6, +} + +NPROMPT = { + "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)' +} + +output_dir = "outputs" +ensure_dirname(output_dir) + +def points_to_flows(track_points, model_length, height, width): + input_drag = np.zeros((model_length - 1, height, width, 2)) + for splited_track in track_points: + if len(splited_track) == 1: # stationary point + displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1]) + splited_track = tuple([splited_track[0], displacement_point]) + # interpolate the track + splited_track = interpolate_trajectory(splited_track, model_length) + splited_track = splited_track[:model_length] + if len(splited_track) < model_length: + splited_track = splited_track + [splited_track[-1]] * 
(model_length -len(splited_track)) + for i in range(model_length - 1): + start_point = splited_track[i] + end_point = splited_track[i+1] + input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0] + input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1] + return input_drag + +class ImageConductor: + def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64): + self.device = device + tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").cuda() + vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").cuda() + inference_config = OmegaConf.load("configs/inference/inference.yaml") + unet = UNet3DConditionFlowModel.from_pretrained_2d("runwayml/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + + self.vae = vae + + ### >>> Initialize UNet module >>> ### + load_model(unet, unet_path) + + ### >>> Initialize image controlnet module >>> ### + image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet) + load_model(image_controlnet, image_controlnet_path) + ### >>> Initialize flow controlnet module >>> ### + flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet) + add_LoRA_to_controlnet(lora_rank, flow_controlnet) + load_model(flow_controlnet, flow_controlnet_path) + + unet.eval().to(device) + image_controlnet.eval().to(device) + flow_controlnet.eval().to(device) + + self.pipeline = ImageConductorPipeline( + unet=unet, + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + image_controlnet=image_controlnet, + flow_controlnet=flow_controlnet, + ).to(device) + + + self.height = height + self.width = width + # _, model_step, _ = split_filename(model_path) + # self.ouput_prefix = f'{model_step}_{width}X{height}' + self.model_length = model_length + + blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True) + + self.blur_kernel = blur_kernel + + @torch.no_grad() + def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized): + + + original_width, original_height=384, 256 + if isinstance(tracking_points, list): + input_all_points = tracking_points + else: + input_all_points = tracking_points.constructor_args['value'] + + + resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] + + dir, base, ext = split_filename(first_frame_path) + id = base.split('_')[-1] + + + with open(f'{output_dir}/points-{id}.json', 'w') as f: + json.dump(input_all_points, f) + + + visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length) + + ## image condition + image_transforms = transforms.Compose([ + transforms.RandomResizedCrop( + (self.height, self.width), (1.0, 1.0), + ratio=(self.width/self.height, self.width/self.height) + ), + transforms.ToTensor(), + ]) + + image_norm = lambda x: x + image_paths = [first_frame_path] + controlnet_images = 
[image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths] + controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda() + controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") + num_controlnet_images = controlnet_images.shape[2] + controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") + controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215 + controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) + + # flow condition + controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width) + for i in range(0, self.model_length-1): + controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel) + controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow + os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True) + trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3 + torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'}) + controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...] + controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w") + + dreambooth_model_path = DREAM_BOOTH.get(personalized, '') + lora_model_path = LORA.get(personalized, '') + lora_alpha = LORA_ALPHA.get(personalized, 0.6) + self.pipeline = load_weights( + self.pipeline, + dreambooth_model_path = dreambooth_model_path, + lora_model_path = lora_model_path, + lora_alpha = lora_alpha, + ).to(device) + + if NPROMPT.get(personalized, '') != '': + negative_prompt = NPROMPT.get(personalized) + + if randomize_seed: + random_seed = torch.seed() + else: + seed = int(seed) + random_seed = seed + torch.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + print(f"current seed: {torch.initial_seed()}") + sample = self.pipeline( + prompt, + negative_prompt = negative_prompt, + num_inference_steps = num_inference_steps, + guidance_scale = guidance_scale, + width = self.width, + height = self.height, + video_length = self.model_length, + controlnet_images = controlnet_images, # 1 4 1 32 48 + controlnet_image_index = [0], + controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384] + control_mode = drag_mode, + eval_mode = True, + ).videos + + outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4') + vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255) + torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'}) + + return visualized_drag, outputs_path + + +def reset_states(first_frame_path, tracking_points): + first_frame_path = gr.State() + tracking_points = gr.State([]) + return None, first_frame_path, tracking_points + + +def preprocess_image(image): + image_pil = image2pil(image.name) + raw_w, raw_h = image_pil.size + resize_ratio = max(384/raw_w, 256/raw_h) + image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR) + image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB')) + id = str(uuid.uuid4())[:4] + first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg") + image_pil.save(first_frame_path, quality=95) + return first_frame_path, first_frame_path, 
gr.State([]) + + +def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + + + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + tracking_points.constructor_args['value'][-1].append(evt.index) + print(tracking_points.constructor_args) + + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +def add_drag(tracking_points): + tracking_points.constructor_args['value'].append([]) + print(tracking_points.constructor_args) + return tracking_points + + +def delete_last_drag(tracking_points, first_frame_path, drag_mode): + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + tracking_points.constructor_args['value'].pop() + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +def delete_last_step(tracking_points, first_frame_path, drag_mode): + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + tracking_points.constructor_args['value'][-1].pop() + transparent_background = Image.open(first_frame_path).convert('RGBA') + w, h = transparent_background.size + transparent_layer = np.zeros((h, w, 4)) + for track in tracking_points.constructor_args['value']: + if len(track) > 1: + for i in range(len(track)-1): + start_point = track[i] + end_point = track[i+1] + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(track)-2: + cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 
2,) + else: + cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return tracking_points, trajectory_map + + +block = gr.Blocks( + theme=gr.themes.Soft( + radius_size=gr.themes.sizes.radius_none, + text_size=gr.themes.sizes.text_md + ) + ).queue() +with block as demo: + with gr.Row(): + with gr.Column(): + gr.HTML(head) + + gr.Markdown(descriptions) + + with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"): + with gr.Row(equal_height=True): + gr.Markdown(instructions) + + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + unet_path = 'models/unet.ckpt' + image_controlnet_path = 'models/image_controlnet.ckpt' + flow_controlnet_path = 'models/flow_controlnet.ckpt' + ImageConductor_net = ImageConductor(device=device, + unet_path=unet_path, + image_controlnet_path=image_controlnet_path, + flow_controlnet_path=flow_controlnet_path, + height=256, + width=384, + model_length=16 + ) + first_frame_path = gr.State() + tracking_points = gr.State([]) + + + with gr.Row(): + with gr.Column(scale=1): + image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"]) + add_drag_button = gr.Button(value="Add Drag") + reset_button = gr.Button(value="Reset") + delete_last_drag_button = gr.Button(value="Delete last drag") + delete_last_step_button = gr.Button(value="Delete last step") + + + + with gr.Column(scale=7): + with gr.Row(): + with gr.Column(scale=6): + input_image = gr.Image(label=None, + interactive=True, + height=256, + width=384,) + with gr.Column(scale=6): + output_image = gr.Image(label="Motion Path", + interactive=False, + height=256, + width=384,) + with gr.Row(): + with gr.Column(scale=1): + prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True) + negative_prompt = gr.Text( + label="Negative Prompt", + max_lines=5, + placeholder="Please input your negative prompt", + value='worst quality, low quality, letterboxed',lines=1 + ) + drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2) + run_button = gr.Button(value="Run") + + with gr.Accordion("More input params", open=False, elem_id="accordion1"): + with gr.Group(): + seed = gr.Textbox( + label="Seed: ", value=561793204, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=False) + + with gr.Group(): + with gr.Row(): + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1, + maximum=12, + step=0.1, + value=8.5, + ) + num_inference_steps = gr.Slider( + label="Number of inference steps", + minimum=1, + maximum=50, + step=1, + value=25, + ) + + with gr.Group(): + personalized = gr.Dropdown(label="Personalized template", choices=['HelloObject', 'TUSUN'], value="") + + with gr.Column(scale=7): + output_video = gr.Video(value=None, + label="Output Video", + width=384, + height=256) + + + with gr.Row(): + def process_example(input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path): + + return input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path + + example = gr.Examples( + label="Input Example", + examples=image_examples, + inputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path], + outputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path], + 
fn=process_example, + run_on_click=True, + examples_per_page=10 + ) + + with gr.Row(): + gr.Markdown(citation) + + + image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points]) + + add_drag_button.click(add_drag, tracking_points, tracking_points) + + delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points]) + + input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image]) + + run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode, + negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized], + [output_image, output_video]) + +demo.launch(server_name="0.0.0.0", debug=True, server_port=12345) diff --git a/configs/.DS_Store b/configs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/.DS_Store differ diff --git a/configs/inference/flow_condition.yaml b/configs/inference/flow_condition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c7c229b8a0356771009e019b4d2a8107bb2d1d5 --- /dev/null +++ b/configs/inference/flow_condition.yaml @@ -0,0 +1,18 @@ +controlnet_additional_kwargs: + set_noisy_sample_input_to_zero: true + use_simplified_condition_embedding: true + conditioning_channels: 2 + concate_conditioning_mask: false + + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 diff --git a/configs/inference/image_condition.yaml b/configs/inference/image_condition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..033d02fabcdfb1421854cd35598f0a4e43270bdf --- /dev/null +++ b/configs/inference/image_condition.yaml @@ -0,0 +1,18 @@ +controlnet_additional_kwargs: + set_noisy_sample_input_to_zero: true + use_simplified_condition_embedding: true + conditioning_channels: 4 + concate_conditioning_mask: true + + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: "Vanilla" + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 diff --git a/configs/inference/inference.yaml b/configs/inference/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd9a34d7839374221548c5a39cfb1c744f96b3db --- /dev/null +++ b/configs/inference/inference.yaml @@ -0,0 +1,22 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: [1,2,4,8] + motion_module_mid_block: false + motion_module_type: Vanilla + + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + 
temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: False diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..71a30893823876826ad4987d58e9ce4b41291cea Binary files /dev/null and b/models/.DS_Store differ diff --git a/modules/__pycache__/attention.cpython-310.pyc b/modules/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b780a82fdc5cdeafc7e567c3a86d4a43efe2ab Binary files /dev/null and b/modules/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/__pycache__/flow_controlnet.cpython-310.pyc b/modules/__pycache__/flow_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35d216cd0d7baef8196beb2c2446afa2eb20e2c Binary files /dev/null and b/modules/__pycache__/flow_controlnet.cpython-310.pyc differ diff --git a/modules/__pycache__/image_controlnet.cpython-310.pyc b/modules/__pycache__/image_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dd72a42170af585e9828058717f46e2abd0c5ce Binary files /dev/null and b/modules/__pycache__/image_controlnet.cpython-310.pyc differ diff --git a/modules/__pycache__/motion_module.cpython-310.pyc b/modules/__pycache__/motion_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d724ed29e82bf76b28920ea94f9327127dd4e2 Binary files /dev/null and b/modules/__pycache__/motion_module.cpython-310.pyc differ diff --git a/modules/__pycache__/resnet.cpython-310.pyc b/modules/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b875ec156c21364e2018dd36c2bf05af2909104 Binary files /dev/null and b/modules/__pycache__/resnet.cpython-310.pyc differ diff --git a/modules/__pycache__/unet.cpython-310.pyc b/modules/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d777c33c68db0a708e791a9f7024e9994148cc59 Binary files /dev/null and b/modules/__pycache__/unet.cpython-310.pyc differ diff --git a/modules/__pycache__/unet_blocks.cpython-310.pyc b/modules/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2abde7953e12c1b9ffbed93ba9d69a90f256bd3 Binary files /dev/null and b/modules/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/modules/attention.py b/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..34353798c721fa35660600892503614dbe6a0170 --- /dev/null +++ b/modules/attention.py @@ -0,0 +1,396 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import Tensor, nn + +logger = logging.getLogger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +@maybe_allow_in_graph 
+class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + # validate input dim + if hidden_states.dim() != 5: + raise ValueError( + f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + ) + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # shenanigans for motion module + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat( + encoder_hidden_states, "b n c -> (b f) n c", f=video_length + ) + + # 1. Input + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + encoder_attention_mask=encoder_attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 3. Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + unet_use_cross_frame_attention: bool = False, + unet_use_temporal_attention: bool = False, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # Define 3 blocks. Each block has its own normalization layer. 
+ # Self-Attn / SC-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if unet_use_cross_frame_attention: + # this isn't actually implemented anywhere in the AnimateDiff codebase or in Diffusers... + raise NotImplementedError("SC-Attn is not implemented yet.") + else: + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if only_cross_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. Temporal Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + video_length=None, + ): + # SparseCausal-Attention + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + if self.unet_use_cross_frame_attention: + cross_attention_kwargs["video_length"] = video_length + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. 
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+        # 4. Temporal-Attention
+        if self.unet_use_temporal_attention:
+            d = hidden_states.shape[1]
+            hidden_states = rearrange(
+                hidden_states, "(b f) d c -> (b d) f c", f=video_length
+            )
+            norm_hidden_states = (
+                self.norm_temp(hidden_states, timestep)
+                if self.use_ada_layer_norm
+                else self.norm_temp(hidden_states)
+            )
+            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+        return hidden_states
diff --git a/modules/flow_controlnet.py b/modules/flow_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de0fcc3c8cf04e71e54ac295c5527bf75dfb2cf
--- /dev/null
+++ b/modules/flow_controlnet.py
@@ -0,0 +1,591 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Changes were made to this source code by Yuwei Guo.
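Note on the block above: `Transformer3DModel.forward` in `modules/attention.py` converts binary attention masks (1 = keep, 0 = discard) into additive biases, so that masked positions pick up a -10000.0 score before the softmax inside the transformer blocks. A minimal, self-contained sketch of that conversion follows; the dummy tensors are illustrative only and not taken from this repo.

import torch

mask = torch.tensor([[1, 1, 0, 0]])  # (batch, key_len); 1 = keep, 0 = discard

# same conversion as in Transformer3DModel.forward: keep -> 0.0, discard -> -10000.0
bias = (1 - mask.to(torch.float32)) * -10000.0
bias = bias.unsqueeze(1)  # add a query dimension so the bias broadcasts over all queries

scores = torch.zeros(1, 4, 4)  # dummy attention scores of shape (batch, query_len, key_len)
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])  # approximately [0.5, 0.5, 0.0, 0.0]: masked keys get near-zero weight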
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin +from diffusers.utils import BaseOutput, logging +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F + +from .resnet import InflatedConv3d +from .unet_blocks import ( + CrossAttnDownBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + get_down_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class FlowControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class FlowControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + + use_motion_module = True, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, 
+ motion_module_type = "Vanilla", + motion_module_kwargs = { + "num_attention_heads": 8, + "num_transformer_block": 1, + "attention_block_types": ["Temporal_Self"], + "temporal_position_encoding": True, + "temporal_position_encoding_max_len": 32, + "temporal_attention_dim_div": 1, + "causal_temporal_attention": False, + }, + + concate_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = False, + + set_noisy_sample_input_to_zero: bool = False, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero + + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + conditioning_channels = conditioning_channels * 8 * 8 + if concate_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + self.concate_conditioning_mask = concate_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) + ) + else: + self.controlnet_cond_embedding = FlowControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=True, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + use_inflated_groupnorm=True, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + + controlnet_additional_kwargs: dict = 
{}, + ): + controlnet = cls( + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + + **controlnet_additional_kwargs, + ) + controlnet.unshuffle = nn.PixelUnshuffle(8) + + if load_weights_from_unet: + m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) + assert len(u) == 0 + + if controlnet.class_embedding: + m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) + assert len(u) == 0 + m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) + assert len(u) == 0 + + + return controlnet + + @staticmethod + def image_layer_filter(state_dict): + new_state_dict = {} + for name, param in state_dict.items(): + if "motion_modules." in name or "lora" in name: continue + new_state_dict[name] = param + return new_state_dict + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + + controlnet_cond: torch.FloatTensor, + conditioning_mask: Optional[torch.FloatTensor] = None, + + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[FlowControlNetOutput, Tuple]: + # set input noise to zero + if self.set_noisy_sample_input_to_zero: + sample = torch.zeros_like(sample).to(sample.device) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) + encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + + if self.concate_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + controlnet_cond = self.unshuffle(controlnet_cond.permute(0,2,1,3,4)) + controlnet_cond = controlnet_cond.contiguous().permute(0,2,1,3,4) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. controlnet blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return FlowControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/modules/image_controlnet.py b/modules/image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..65d358db885008b25d22ff2639e7ca533b1b0e46 --- /dev/null +++ b/modules/image_controlnet.py @@ -0,0 +1,721 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Changes were made to this source code by Yuwei Guo. 
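One note on the guess-mode branch of `FlowControlNetModel.forward` above: when `guess_mode` is enabled and `global_pool_conditions` is off, the residuals are scaled by a geometric ramp from 0.1 to 1.0 built with `torch.logspace(-1, 0, ...)`, so shallow down-block features are attenuated the most while the mid-block residual keeps full strength. A small sketch of that ramp is below; the count of 13 residuals (12 down-block plus 1 mid-block) is an assumption based on the usual SD-1.5 block layout, not a value read from this file.

import torch

num_res = 12 + 1  # assumed: 12 down-block residuals plus 1 mid-block residual
scales = torch.logspace(-1, 0, num_res)  # geometric ramp from 10**-1 to 10**0
print(round(scales[0].item(), 3), round(scales[-1].item(), 3))  # 0.1 1.0; the last value scales the mid block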
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.utils import BaseOutput, logging +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F + +from .resnet import InflatedConv3d +from .unet_blocks import ( + CrossAttnDownBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + get_down_block, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ImageControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ImageControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = InflatedConv3d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) + ) + self.blocks.append( + InflatedConv3d( + channel_in, channel_out, kernel_size=3, padding=1, stride=2 + ) + ) + + self.conv_out = zero_module( + InflatedConv3d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ImageControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + use_motion_module=True, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_type="Vanilla", + motion_module_kwargs={ + "num_attention_heads": 8, + 
"num_transformer_block": 1, + "attention_block_types": ["Temporal_Self"], + "temporal_position_encoding": True, + "temporal_position_encoding_max_len": 32, + "temporal_attention_dim_div": 1, + "causal_temporal_attention": False, + }, + concate_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = False, + set_noisy_sample_input_to_zero: bool = False, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero + + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = InflatedConv3d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + + if concate_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + self.concate_conditioning_mask = concate_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + InflatedConv3d( + conditioning_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + ) + else: + self.controlnet_cond_embedding = ImageControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + res = 2**i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=( + attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel + ), + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=True, + use_motion_module=use_motion_module + and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = InflatedConv3d( + output_channel, output_channel, kernel_size=1 + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = InflatedConv3d( + output_channel, output_channel, kernel_size=1 + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = InflatedConv3d( + mid_block_channel, mid_block_channel, kernel_size=1 + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_inflated_groupnorm=True, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + 
controlnet_additional_kwargs: dict = {}, + ): + controlnet = cls( + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + **controlnet_additional_kwargs, + ) + + if load_weights_from_unet: + m, u = controlnet.conv_in.load_state_dict( + cls.image_layer_filter(unet.conv_in.state_dict()), strict=False + ) + assert len(u) == 0 + m, u = controlnet.time_proj.load_state_dict( + cls.image_layer_filter(unet.time_proj.state_dict()), strict=False + ) + assert len(u) == 0 + m, u = controlnet.time_embedding.load_state_dict( + cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False + ) + assert len(u) == 0 + + if controlnet.class_embedding: + m, u = controlnet.class_embedding.load_state_dict( + cls.image_layer_filter(unet.class_embedding.state_dict()), + strict=False, + ) + assert len(u) == 0 + m, u = controlnet.down_blocks.load_state_dict( + cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False + ) + assert len(u) == 0 + m, u = controlnet.mid_block.load_state_dict( + cls.image_layer_filter(unet.mid_block.state_dict()), strict=False + ) + assert len(u) == 0 + + return controlnet + + @staticmethod + def image_layer_filter(state_dict): + new_state_dict = {} + for name, param in state_dict.items(): + if "motion_modules." in name or "lora" in name: + continue + new_state_dict[name] = param + return new_state_dict + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = ( + num_sliceable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_mask: Optional[torch.FloatTensor] = None, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ImageControlNetOutput, Tuple]: + + # set input noise to zero + if self.set_noisy_sample_input_to_zero: + sample = torch.zeros_like(sample).to(sample.device) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) + encoder_hidden_states = encoder_hidden_states.repeat( + sample.shape[0] // encoder_hidden_states.shape[0], 1, 1 + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + if self.concate_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. controlnet blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip( + down_block_res_samples, self.controlnet_down_blocks + ): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace( + -1, 0, len(down_block_res_samples) + 1, device=sample.device + ) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [ + sample * scale for sample, scale in zip(down_block_res_samples, scales) + ] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [ + sample * conditioning_scale for sample in down_block_res_samples + ] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) + for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean( + mid_block_res_sample, dim=(2, 3), keepdim=True + ) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ImageControlNetOutput( + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + if "temporal_transformer" not in sub_name: + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + if "temporal_transformer" not in name: + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + if "temporal_transformer" not in sub_name: + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + if "temporal_transformer" not in name: + fn_recursive_attn_processor(name, module, processor) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + diff --git a/modules/motion_module.py b/modules/motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b70cc665b047556b625ec61e132c7d7144d1aaf9 --- /dev/null +++ b/modules/motion_module.py @@ -0,0 +1,355 @@ +import math +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention, FeedForward +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import Tensor, nn + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): + if motion_module_type == "Vanilla": + return VanillaTemporalModule( + in_channels=in_channels, + **motion_module_kwargs, + ) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_Self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels + // num_attention_heads + // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module( + self.temporal_transformer.proj_out + ) + self.skip_temporal_layers = False # Whether to skip temporal layer + + def forward( + self, + input_tensor, + temb, + encoder_hidden_states, + attention_mask=None, + anchor_frame_idx=None, + ): + if self.skip_temporal_layers is True: + return input_tensor + + hidden_states = input_tensor + hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states, attention_mask + ) + + output = hidden_states + return output + + +@maybe_allow_in_graph +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + 
attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ): + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +@maybe_allow_in_graph +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups: int = 32, + cross_attention_dim: int = 768, + activation_fn: str = "geglu", + attention_bias: bool = False, + upcast_attention: bool = False, + cross_frame_attention_mode=None, + temporal_position_encoding: bool = False, + temporal_position_encoding_max_len: int = 24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=( + cross_attention_dim if block_name.endswith("_Cross") else None + ), + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = 
nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + ): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states + if attention_block.is_cross_attention + else None + ), + video_length=video_length, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout: float = 0.0, max_len: int = 24): + super().__init__() + self.dropout: nn.Module = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe: Tensor = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor): + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +@maybe_allow_in_graph +class VersatileAttention(Attention): + def __init__( + self, + attention_mode: str = None, + cross_frame_attention_mode: Optional[str] = None, + temporal_position_encoding: bool = False, + temporal_position_encoding_max_len: int = 24, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + if attention_mode.lower() != "temporal": + raise ValueError(f"Attention mode {attention_mode} is not supported.") + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = ( + PositionalEncoding( + kwargs["query_dim"], + dropout=0.0, + max_len=temporal_position_encoding_max_len, + ) + if (temporal_position_encoding and attention_mode == "Temporal") + else None + ) + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + ): + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = ( + repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) + if encoder_hidden_states is not None + else encoder_hidden_states + ) + else: + raise NotImplementedError + + # attention processor makes this easy so that's nice + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask + ) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: bool, + attention_op: Optional[Callable] = None, + ): + return None diff --git a/modules/resnet.py b/modules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ddcdb91dc524c8110d15eca888233a12fb2a642e --- /dev/null +++ b/modules/resnet.py @@ -0,0 +1,261 @@ +# Adapted from 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py + +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x: Tensor) -> Tensor: + ori_dim = x.ndim + if ori_dim == 5: + frames = x.shape[2] + x = rearrange(x, "b c f h w -> (b f) c h w") + x = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + if ori_dim == 5: + x = rearrange(x, "(b f) c h w -> b c f h w", f=frames) + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states: Tensor, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
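`InflatedConv3d` and `InflatedGroupNorm` above reuse 2D layers on video tensors by folding the frame axis into the batch. A small sketch of the same pattern with a plain `nn.Conv2d` (shapes are assumptions, not values from the repository):

```python
import torch
import torch.nn as nn
from einops import rearrange

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)   # any pretrained 2D layer

video = torch.randn(2, 4, 16, 32, 32)                # b c f h w
x = rearrange(video, "b c f h w -> (b f) c h w")     # frames become batch entries
x = conv2d(x)                                        # plain 2D convolution per frame
x = rearrange(x, "(b f) c h w -> b c f h w", f=16)   # back to the video layout
print(x.shape)                                       # torch.Size([2, 8, 16, 32, 32])
```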
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + assert use_inflated_groupnorm != None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + else: + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = InflatedConv3d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + + self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + else: + self.norm2 = nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + + self.dropout = nn.Dropout(dropout) + self.conv2 = InflatedConv3d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + 
self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels != self.out_channels + if use_in_shortcut is None + else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) diff --git a/modules/unet.py b/modules/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..163c6f2a06733023c33f099034af82b2015b6262 --- /dev/null +++ b/modules/unet.py @@ -0,0 +1,591 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +import json +import os +from dataclasses import dataclass +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin +from diffusers.models import ModelMixin +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging +from safetensors.torch import load_file +from torch import Tensor, nn + +from .resnet import InflatedConv3d, InflatedGroupNorm +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionFlowModelOutput(BaseOutput): + sample: torch.FloatTensor + + + +class UNet3DConditionFlowModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + 
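A toy illustration (hypothetical shapes, not the repository's module) of the `scale_shift` branch in `ResnetBlock3D.forward`: the projected time embedding is split into a scale and a shift that modulate the features, FiLM-style.

```python
import torch

b, c, f, h, w = 1, 8, 4, 16, 16
hidden = torch.randn(b, c, f, h, w)
temb = torch.randn(b, 2 * c)                  # time_emb_proj output for "scale_shift"

temb = temb[:, :, None, None, None]           # add frame/spatial axes for broadcasting
scale, shift = torch.chunk(temb, 2, dim=1)    # each is (b, c, 1, 1, 1)
hidden = hidden * (1 + scale) + shift
print(hidden.shape)                           # torch.Size([1, 8, 4, 16, 16])
```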
"CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + + use_inflated_groupnorm=False, + + # Additional + use_motion_module = False, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type = None, + motion_module_kwargs = {}, + unet_use_cross_frame_attention = False, + unet_use_temporal_attention = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if 
mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. 
+ + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
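A self-contained toy version of the two-pass pattern used by `set_attention_slice`: one recursion collects every sliceable head dimension, a second hands each module its slice size. `DummyAttention` is invented purely for illustration.

```python
import torch.nn as nn

def collect_sliceable_dims(module: nn.Module, dims: list) -> None:
    if hasattr(module, "set_attention_slice"):
        dims.append(module.sliceable_head_dim)
    for child in module.children():
        collect_sliceable_dims(child, dims)

def distribute_slice_sizes(module: nn.Module, sizes: list) -> None:
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(sizes.pop())
    for child in module.children():
        distribute_slice_sizes(child, sizes)

class DummyAttention(nn.Module):
    sliceable_head_dim = 8
    def set_attention_slice(self, size):
        self.slice_size = size

model = nn.Sequential(DummyAttention(), nn.Sequential(DummyAttention()))
dims = []
collect_sliceable_dims(model, dims)                 # [8, 8]
sizes = [d // 2 for d in dims]                      # the "auto" policy: half of each head dim
distribute_slice_sizes(model, list(reversed(sizes)))
```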
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + + def get_image_controlnet(self, controlnet_noisy_latents, timesteps, + encoder_hidden_states=None, + controlnet_cond=None, + conditioning_mask=None, + conditioning_scale=None, + guess_mode=False, + return_dict=False,): + down_block_additional_residuals, mid_block_additional_residual = self.image_controlnet( + controlnet_noisy_latents, timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, + conditioning_mask=conditioning_mask, + conditioning_scale=conditioning_scale, + guess_mode=guess_mode, + return_dict=return_dict, + ) + return down_block_additional_residuals, mid_block_additional_residual + + + def get_flow_controlnet(self, controlnet_noisy_latents, timesteps, + encoder_hidden_states=None, + controlnet_cond=None, + conditioning_mask=None, + conditioning_scale=None, + guess_mode=False, + return_dict=False,): + down_block_additional_residuals, mid_block_additional_residual = self.omcm_controlnet( + controlnet_noisy_latents, timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, + conditioning_mask=conditioning_mask, + conditioning_scale=conditioning_scale, + guess_mode=guess_mode, + return_dict=return_dict, + ) + return down_block_additional_residuals, mid_block_additional_residual + + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + + # support image controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + + # support flow controlnet + flow_down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + flow_mid_block_additional_residual: Optional[torch.Tensor] = None, + + return_dict: bool = True, + ) -> Union[UNet3DConditionFlowModelOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 
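The divisibility comment above is worth making concrete: with three upsampling stages, the latent height and width should be multiples of 2**3 = 8, otherwise the model has to forward an explicit upsample size. A quick sketch of that check with hypothetical numbers:

```python
num_upsamplers = 3                                   # up blocks that add an upsampler
default_overall_up_factor = 2 ** num_upsamplers      # 8

for height, width in [(64, 64), (60, 64)]:
    forward_upsample_size = any(
        s % default_overall_up_factor != 0 for s in (height, width)
    )
    print((height, width), "-> forward upsample size:", forward_upsample_size)
# (64, 64) -> forward upsample size: False
# (60, 64) -> forward upsample size: True
```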
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + # support controlnet + # image controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # boardcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # flow controlnet + if flow_down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(flow_down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # boardcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + + # mid + sample = self.mid_block( + sample, emb, 
encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # support controlnet + # image controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # boardcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # flow controlnet + if flow_mid_block_additional_residual is not None: + if flow_mid_block_additional_residual.dim() == 4: # boardcast + flow_mid_block_additional_residual = flow_mid_block_additional_residual.unsqueeze(2) + sample = sample + flow_mid_block_additional_residual + + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionFlowModelOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + + motion_params = [p.numel() if "motion_modules." in n else 0 for n,p in model.named_parameters()] + motion_name = [n for n in model.state_dict().keys() if "motion_modules." in n] + + print(f"### Motion Module Parameters: {sum(motion_params) / 1e6} M") + print(f"### Motion Module keys: {len(motion_name)}") + + unnorlmal = [] + for n in m: + if n not in motion_name: + unnorlmal.append(n) + + return model + + 'motion_modules.' 
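A small sketch (made-up shapes) of how the ControlNet residuals above are folded into the 5D video features: a 4D residual gets a frame axis via `unsqueeze(2)` and broadcasting repeats it across frames.

```python
import torch

sample = torch.randn(2, 320, 16, 32, 32)       # b c f h w
residual = torch.randn(2, 320, 32, 32)         # b c h w, e.g. from a 2D ControlNet

if residual.dim() == 4:
    residual = residual.unsqueeze(2)           # b c 1 h w
sample = sample + residual                     # broadcast over the 16 frames
print(sample.shape)                            # torch.Size([2, 320, 16, 32, 32])
```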
in 'up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe' \ No newline at end of file diff --git a/modules/unet_blocks.py b/modules/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba7a70c4a647f83eadc3f461c926ae658772318 --- /dev/null +++ b/modules/unet_blocks.py @@ -0,0 +1,866 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .motion_module import get_motion_module +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D" + ) + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + 
resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D" + ) + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, 
+ in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip( + self.attentions, self.resnets[1:], self.motion_modules + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + 
use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + num_attention_heads=attn_num_head_channels, + attention_head_dim=out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + output_states = () + + for resnet, attn, motion_module in zip( + self.resnets, self.attentions, self.motion_modules + ): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + # add motion module + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + 
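A minimal, standalone sketch of the gradient-checkpointing pattern used throughout these blocks: a module call is wrapped in a closure so `torch.utils.checkpoint` can recompute it during the backward pass instead of storing activations. The `nn.Linear` stand-in is an assumption chosen only for brevity.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

layer = nn.Linear(64, 64)
x = torch.randn(4, 64, requires_grad=True)

y = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), x, use_reentrant=False
)
y.sum().backward()
print(x.grad.shape)                            # torch.Size([4, 64])
```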
resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + if motion_module: + hidden_states = motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == 
num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn, motion_module in zip( + self.resnets, self.attentions, self.motion_modules + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # add motion module + if motion_module: + hidden_states = motion_module( + hidden_states, temb, 
encoder_hidden_states=encoder_hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + encoder_hidden_states=None, + ): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + if motion_module: + hidden_states = motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/peft/__init__.py b/peft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c934d842b0b5f802d2a53a317dcf22c8c32c5979 --- /dev/null +++ b/peft/__init__.py @@ -0,0 +1,98 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. 
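A toy illustration of how the up blocks above consume skip connections: the most recently pushed down-block activation is popped and concatenated on the channel axis before each resnet (shapes are made up).

```python
import torch

hidden_states = torch.randn(2, 640, 16, 32, 32)          # current up-path features
res_hidden_states_tuple = (
    torch.randn(2, 320, 16, 32, 32),
    torch.randn(2, 640, 16, 32, 32),
)

res_hidden_states = res_hidden_states_tuple[-1]           # last pushed, first popped
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
print(hidden_states.shape)                                # torch.Size([2, 1280, 16, 32, 32])
```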
+ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.11.1" + +from .auto import ( + AutoPeftModel, + AutoPeftModelForCausalLM, + AutoPeftModelForSequenceClassification, + AutoPeftModelForSeq2SeqLM, + AutoPeftModelForTokenClassification, + AutoPeftModelForQuestionAnswering, + AutoPeftModelForFeatureExtraction, +) +from .mapping import ( + MODEL_TYPE_TO_PEFT_MODEL_MAPPING, + PEFT_TYPE_TO_CONFIG_MAPPING, + get_peft_config, + get_peft_model, + inject_adapter_in_model, +) +from .mixed_model import PeftMixedModel +from .peft_model import ( + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, + PeftModelForTokenClassification, + PeftModelForQuestionAnswering, + PeftModelForFeatureExtraction, + get_layer_status, + get_model_status, +) +from .tuners import ( + AdaptionPromptConfig, + AdaptionPromptModel, + LoraConfig, + LoftQConfig, + LoraModel, + LoHaConfig, + LoHaModel, + LoKrConfig, + LoKrModel, + IA3Config, + IA3Model, + AdaLoraConfig, + AdaLoraModel, + BOFTConfig, + BOFTModel, + PrefixEncoder, + PrefixTuningConfig, + PromptEmbedding, + PromptEncoder, + PromptEncoderConfig, + PromptEncoderReparameterizationType, + PromptTuningConfig, + PromptTuningInit, + MultitaskPromptTuningConfig, + MultitaskPromptTuningInit, + OFTConfig, + OFTModel, + PolyConfig, + PolyModel, + LNTuningConfig, + LNTuningModel, + VeraConfig, + VeraModel, +) +from .utils import ( + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + PeftType, + TaskType, + bloom_model_postprocess_past_key_value, + get_peft_model_state_dict, + prepare_model_for_kbit_training, + replace_lora_weights_loftq, + set_peft_model_state_dict, + shift_tokens_right, + load_peft_weights, + cast_mixed_precision_params, +) +from .config import PeftConfig, PromptLearningConfig diff --git a/peft/__pycache__/__init__.cpython-310.pyc b/peft/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7519f40da950b0df55e6d8ab2df9a18221f646d1 Binary files /dev/null and b/peft/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/__pycache__/auto.cpython-310.pyc b/peft/__pycache__/auto.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a48eb8ac284e9b03f0ae25bafb43fc79f89bb62c Binary files /dev/null and b/peft/__pycache__/auto.cpython-310.pyc differ diff --git a/peft/__pycache__/config.cpython-310.pyc b/peft/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e99f10e8e6c7b24206a42261d9ec24fd2491118d Binary files /dev/null and b/peft/__pycache__/config.cpython-310.pyc differ diff --git a/peft/__pycache__/import_utils.cpython-310.pyc b/peft/__pycache__/import_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e87072c428137738955a011d3a9837b637ac69 Binary files /dev/null and b/peft/__pycache__/import_utils.cpython-310.pyc differ diff --git 
a/peft/__pycache__/mapping.cpython-310.pyc b/peft/__pycache__/mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb5ce2a074f407bccad18aab353b0455dab5c09 Binary files /dev/null and b/peft/__pycache__/mapping.cpython-310.pyc differ diff --git a/peft/__pycache__/mixed_model.cpython-310.pyc b/peft/__pycache__/mixed_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70be307653d0608f4d50ed009e1421c006257ec5 Binary files /dev/null and b/peft/__pycache__/mixed_model.cpython-310.pyc differ diff --git a/peft/__pycache__/peft_model.cpython-310.pyc b/peft/__pycache__/peft_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..973f687480ed9ffb5e31ba346b787e6ff25b27ac Binary files /dev/null and b/peft/__pycache__/peft_model.cpython-310.pyc differ diff --git a/peft/auto.py b/peft/auto.py new file mode 100644 index 0000000000000000000000000000000000000000..353c9e2f84c48bd61194102da6d6e83dfdcd42db --- /dev/null +++ b/peft/auto.py @@ -0,0 +1,170 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import os +from typing import Optional + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoTokenizer, +) + +from .config import PeftConfig +from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING +from .peft_model import ( + PeftModel, + PeftModelForCausalLM, + PeftModelForFeatureExtraction, + PeftModelForQuestionAnswering, + PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, + PeftModelForTokenClassification, +) +from .utils.constants import TOKENIZER_CONFIG_NAME +from .utils.other import check_file_exists_on_hf_hub + + +class _BaseAutoPeftModel: + _target_class = None + _target_peft_class = None + + def __init__(self, *args, **kwargs): + # For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400 + raise EnvironmentError( # noqa: UP024 + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + **kwargs, + ): + r""" + A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs + are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and + the config object init. 
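A hedged usage sketch for the Auto classes defined in this file; the adapter path below is a placeholder, not a real checkpoint, and the call simply mirrors the `from_pretrained` signature shown here.

```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "my-org/my-lora-adapter",    # placeholder: directory or Hub repo with adapter_config.json
    is_trainable=False,          # load the adapter frozen, for inference
)
```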
+ """ + peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + base_model_path = peft_config.base_model_name_or_path + + task_type = getattr(peft_config, "task_type", None) + + if cls._target_class is not None: + target_class = cls._target_class + elif cls._target_class is None and task_type is not None: + # this is only in the case where we use `AutoPeftModel` + raise ValueError( + "Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)" + ) + + if task_type is not None: + expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type] + if cls._target_peft_class.__name__ != expected_target_class.__name__: + raise ValueError( + f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }" + " make sure that you are loading the correct model for your task type." + ) + elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None: + auto_mapping = getattr(peft_config, "auto_mapping", None) + base_model_class = auto_mapping["base_model_class"] + parent_library_name = auto_mapping["parent_library"] + + parent_library = importlib.import_module(parent_library_name) + target_class = getattr(parent_library, base_model_class) + else: + raise ValueError( + "Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type." + ) + + base_model = target_class.from_pretrained(base_model_path, **kwargs) + + tokenizer_exists = False + if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)): + tokenizer_exists = True + else: + token = kwargs.get("token", None) + if token is None: + token = kwargs.get("use_auth_token", None) + + tokenizer_exists = check_file_exists_on_hf_hub( + repo_id=pretrained_model_name_or_path, + filename=TOKENIZER_CONFIG_NAME, + revision=kwargs.get("revision", None), + repo_type=kwargs.get("repo_type", None), + token=token, + ) + + if tokenizer_exists: + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + base_model.resize_token_embeddings(len(tokenizer)) + + return cls._target_peft_class.from_pretrained( + base_model, + pretrained_model_name_or_path, + adapter_name=adapter_name, + is_trainable=is_trainable, + config=config, + **kwargs, + ) + + +class AutoPeftModel(_BaseAutoPeftModel): + _target_class = None + _target_peft_class = PeftModel + + +class AutoPeftModelForCausalLM(_BaseAutoPeftModel): + _target_class = AutoModelForCausalLM + _target_peft_class = PeftModelForCausalLM + + +class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel): + _target_class = AutoModelForSeq2SeqLM + _target_peft_class = PeftModelForSeq2SeqLM + + +class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel): + _target_class = AutoModelForSequenceClassification + _target_peft_class = PeftModelForSequenceClassification + + +class AutoPeftModelForTokenClassification(_BaseAutoPeftModel): + _target_class = AutoModelForTokenClassification + _target_peft_class = PeftModelForTokenClassification + + +class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel): + _target_class = AutoModelForQuestionAnswering + _target_peft_class = PeftModelForQuestionAnswering + + +class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel): + _target_class = AutoModel + _target_peft_class = PeftModelForFeatureExtraction diff --git a/peft/config.py 
b/peft/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d38da2b82d5c75971ca506b3ceb99fd30e3dac --- /dev/null +++ b/peft/config.py @@ -0,0 +1,270 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import json +import os +from dataclasses import asdict, dataclass, field +from typing import Dict, Optional, Union + +from huggingface_hub import hf_hub_download +from transformers.utils import PushToHubMixin + +from .utils import CONFIG_NAME, PeftType, TaskType + + +@dataclass +class PeftConfigMixin(PushToHubMixin): + r""" + This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all + PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to + push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a + directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + """ + + peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) + auto_mapping: Optional[dict] = field( + default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."} + ) + + def to_dict(self) -> Dict: + r""" + Returns the configuration for your adapter model as a dictionary. + """ + return asdict(self) + + def save_pretrained(self, save_directory: str, **kwargs) -> None: + r""" + This method saves the configuration of your adapter model in a directory. + + Args: + save_directory (`str`): + The directory where the configuration will be saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] + method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + auto_mapping_dict = kwargs.pop("auto_mapping_dict", None) + + output_dict = asdict(self) + # converting set type to list + for key, value in output_dict.items(): + if isinstance(value, set): + output_dict[key] = list(value) + + output_path = os.path.join(save_directory, CONFIG_NAME) + + # Add auto mapping details for custom models. + if auto_mapping_dict is not None: + output_dict["auto_mapping"] = auto_mapping_dict + + # save it + with open(output_path, "w") as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_peft_type(cls, **kwargs): + r""" + This method loads the configuration of your adapter model from a set of kwargs. + + The appropriate configuration type is determined by the `peft_type` argument. If `peft_type` is not provided, + the calling class type is instantiated. 
+ + Args: + kwargs (configuration keyword arguments): + Keyword arguments passed along to the configuration initialization. + """ + # Avoid circular dependency .. TODO: fix this with a larger refactor + from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING + + # TODO: this hack is needed to fix the following issue (on commit 702f937): + # if someone saves a default config and loads it back with `PeftConfig` class it yields to + # not loading the correct config class. + + # from peft import AdaLoraConfig, PeftConfig + # peft_config = AdaLoraConfig() + # print(peft_config) + # >>> AdaLoraConfig(peft_type=, auto_mapping=None, base_model_name_or_path=None, + # revision=None, task_type=None, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, ... + # + # peft_config.save_pretrained("./test_config") + # peft_config = PeftConfig.from_pretrained("./test_config") + # print(peft_config) + # >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False) + + if "peft_type" in kwargs: + peft_type = kwargs["peft_type"] + config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] + else: + config_cls = cls + + return config_cls(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. + + Args: + pretrained_model_name_or_path (`str`): + The directory or the Hub repository id where the configuration is saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the child class initialization. + """ + path = ( + os.path.join(pretrained_model_name_or_path, subfolder) + if subfolder is not None + else pretrained_model_name_or_path + ) + + hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs) + + if os.path.isfile(os.path.join(path, CONFIG_NAME)): + config_file = os.path.join(path, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs + ) + except Exception as exc: + raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from exc + + loaded_attributes = cls.from_json_file(config_file) + kwargs = {**class_kwargs, **loaded_attributes} + return cls.from_peft_type(**kwargs) + + @classmethod + def from_json_file(cls, path_json_file: str, **kwargs): + r""" + Loads a configuration file from a json file. + + Args: + path_json_file (`str`): + The path to the json file. 
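+
+        Example (a minimal sketch; ``adapter_config.json`` is assumed to be a config file previously
+        written by `save_pretrained`, it is not bundled with this code):
+
+        ```py
+        >>> config_dict = PeftConfig.from_json_file("adapter_config.json")  # returns a plain dict
+        ```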
+ """ + with open(path_json_file) as file: + json_object = json.load(file) + + return json_object + + @classmethod + def _split_kwargs(cls, kwargs): + hf_hub_download_kwargs = {} + class_kwargs = {} + other_kwargs = {} + + for key, value in kwargs.items(): + if key in inspect.signature(hf_hub_download).parameters: + hf_hub_download_kwargs[key] = value + elif key in list(cls.__annotations__): + class_kwargs[key] = value + else: + other_kwargs[key] = value + + return hf_hub_download_kwargs, class_kwargs, other_kwargs + + @classmethod + def _get_peft_type( + cls, + model_id: str, + **hf_hub_download_kwargs, + ): + subfolder = hf_hub_download_kwargs.get("subfolder", None) + + path = os.path.join(model_id, subfolder) if subfolder is not None else model_id + + if os.path.isfile(os.path.join(path, CONFIG_NAME)): + config_file = os.path.join(path, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + model_id, + CONFIG_NAME, + **hf_hub_download_kwargs, + ) + except Exception: + raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'") + + loaded_attributes = cls.from_json_file(config_file) + return loaded_attributes["peft_type"] + + @property + def is_prompt_learning(self) -> bool: + r""" + Utility method to check if the configuration is for prompt learning. + """ + return False + + @property + def is_adaption_prompt(self) -> bool: + """Return True if this is an adaption prompt config.""" + return False + + +@dataclass +class PeftConfig(PeftConfigMixin): + """ + This is the base configuration class to store the configuration of a [`PeftModel`]. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. + inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. + """ + + base_model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The name of the base model to use."} + ) + revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."}) + peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"}) + task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"}) + inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) + + +@dataclass +class PromptLearningConfig(PeftConfig): + """ + This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or + [`PromptTuning`]. + + Args: + num_virtual_tokens (`int`): The number of virtual tokens to use. + token_dim (`int`): The hidden embedding dimension of the base transformer model. + num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. + num_attention_heads (`int`): The number of attention heads in the base transformer model. + num_layers (`int`): The number of layers in the base transformer model. 
+ """ + + num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) + token_dim: int = field( + default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} + ) + num_transformer_submodules: Optional[int] = field( + default=None, metadata={"help": "Number of transformer submodules"} + ) + num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) + num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) + + @property + def is_prompt_learning(self) -> bool: + r""" + Utility method to check if the configuration is for prompt learning. + """ + return True diff --git a/peft/helpers.py b/peft/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..63742a7725caddc582c9160b3ce70b5e930f3b82 --- /dev/null +++ b/peft/helpers.py @@ -0,0 +1,148 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from copy import deepcopy +from functools import update_wrapper +from types import MethodType + +from .peft_model import PeftConfig, PeftModel + + +def update_forward_signature(model: PeftModel) -> None: + """ + Updates the forward signature of the PeftModel to include parents class signature + model (`PeftModel`): Peft model to update the forward signature + + Example: + + ```python + >>> from transformers import WhisperForConditionalGeneration + >>> from peft import get_peft_model, LoraConfig, update_forward_signature + + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"]) + + >>> peft_model = get_peft_model(model, peft_config) + >>> update_forward_signature(peft_model) + ``` + """ + + # Only update signature when the current forward signature only has *args and **kwargs + current_signature = inspect.signature(model.forward) + if ( + len(current_signature.parameters) == 2 + and "args" in current_signature.parameters + and "kwargs" in current_signature.parameters + ): + forward = deepcopy(model.forward.__func__) + update_wrapper( + forward, type(model.get_base_model()).forward, assigned=("__doc__", "__name__", "__annotations__") + ) + model.forward = MethodType(forward, model) + + +def update_generate_signature(model: PeftModel) -> None: + """ + Updates the generate signature of a PeftModel with overriding generate to include parents class signature + model (`PeftModel`): Peft model to update the generate signature + + Example: + + ```python + >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + >>> from peft import get_peft_model, LoraConfig, TaskType, update_generate_signature + + >>> model_name_or_path = "bigscience/mt0-large" + >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + + >>> peft_config = LoraConfig( + ... 
task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1 + ... ) + >>> peft_model = get_peft_model(model, peft_config) + >>> update_generate_signature(peft_model) + >>> help(peft_model.generate) + ``` + """ + if not hasattr(model, "generate"): + return + current_signature = inspect.signature(model.generate) + if ( + len(current_signature.parameters) == 2 + and "args" in current_signature.parameters + and "kwargs" in current_signature.parameters + ) or (len(current_signature.parameters) == 1 and "kwargs" in current_signature.parameters): + generate = deepcopy(model.generate.__func__) + update_wrapper( + generate, + type(model.get_base_model()).generate, + assigned=("__doc__", "__name__", "__annotations__"), + ) + model.generate = MethodType(generate, model) + + +def update_signature(model: PeftModel, method: str = "all") -> None: + """ + Updates the signature of a PeftModel include parents class signature for forward or generate method + model (`PeftModel`): Peft model to update generate or forward signature method (`str`): method to update + signature choose one of "forward", "generate", "all" + + Example: + ```python + >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + >>> from peft import get_peft_model, LoraConfig, TaskType, update_signature + + >>> model_name_or_path = "bigscience/mt0-large" + >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + + >>> peft_config = LoraConfig( + ... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1 + ... ) + >>> peft_model = get_peft_model(model, peft_config) + >>> update_signature(peft_model) + >>> help(peft_model.generate) + ``` + """ + if method == "forward": + update_forward_signature(model) + elif method == "generate": + update_generate_signature(model) + elif method == "all": + update_forward_signature(model) + update_generate_signature(model) + else: + raise ValueError(f"method {method} is not supported please choose one of ['forward', 'generate', 'all']") + + +def check_if_peft_model(model_name_or_path: str) -> bool: + """ + Check if the model is a PEFT model. + + Args: + model_name_or_path (`str`): + Model id to check, can be local or on the Hugging Face Hub. + + Returns: + `bool`: True if the model is a PEFT model, False otherwise. + """ + is_peft_model = True + try: + PeftConfig.from_pretrained(model_name_or_path) + except Exception: + # allow broad exceptions so that this works even if new exceptions are added on HF Hub side + is_peft_model = False + + return is_peft_model diff --git a/peft/import_utils.py b/peft/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58c65f9c8022c33702d978d99ae76ca5d260a518 --- /dev/null +++ b/peft/import_utils.py @@ -0,0 +1,89 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib +import importlib.metadata as importlib_metadata +from functools import lru_cache + +import packaging.version + + +@lru_cache +def is_bnb_available() -> bool: + return importlib.util.find_spec("bitsandbytes") is not None + + +@lru_cache +def is_bnb_4bit_available() -> bool: + if not is_bnb_available(): + return False + + import bitsandbytes as bnb + + return hasattr(bnb.nn, "Linear4bit") + + +@lru_cache +def is_auto_gptq_available(): + if importlib.util.find_spec("auto_gptq") is not None: + AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0") + version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq")) + if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq: + return True + else: + raise ImportError( + f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, " + f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported" + ) + + +@lru_cache +def is_optimum_available() -> bool: + return importlib.util.find_spec("optimum") is not None + + +@lru_cache +def is_torch_tpu_available(check_device=True): + "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm + + _ = xm.xla_device() + return True + except RuntimeError: + return False + return True + return False + + +@lru_cache +def is_aqlm_available(): + return importlib.util.find_spec("aqlm") is not None + + +@lru_cache +def is_auto_awq_available(): + return importlib.util.find_spec("awq") is not None + + +@lru_cache +def is_eetq_available(): + return importlib.util.find_spec("eetq") is not None + + +@lru_cache +def is_hqq_available(): + return importlib.util.find_spec("hqq") is not None diff --git a/peft/mapping.py b/peft/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ae669f010cdb73bbc76d42f7c3ead746ab0d28 --- /dev/null +++ b/peft/mapping.py @@ -0,0 +1,181 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch + +from .config import PeftConfig +from .mixed_model import PeftMixedModel +from .peft_model import ( + PeftModel, + PeftModelForCausalLM, + PeftModelForFeatureExtraction, + PeftModelForQuestionAnswering, + PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, + PeftModelForTokenClassification, +) +from .tuners import ( + AdaLoraConfig, + AdaLoraModel, + AdaptionPromptConfig, + BOFTConfig, + BOFTModel, + IA3Config, + IA3Model, + LNTuningConfig, + LNTuningModel, + LoHaConfig, + LoHaModel, + LoKrConfig, + LoKrModel, + LoraConfig, + LoraModel, + MultitaskPromptTuningConfig, + OFTConfig, + OFTModel, + PolyConfig, + PolyModel, + PrefixTuningConfig, + PromptEncoderConfig, + PromptTuningConfig, + VeraConfig, + VeraModel, +) +from .tuners.tuners_utils import BaseTuner as _BaseTuner +from .utils import _prepare_prompt_learning_config + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { + "SEQ_CLS": PeftModelForSequenceClassification, + "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, + "CAUSAL_LM": PeftModelForCausalLM, + "TOKEN_CLS": PeftModelForTokenClassification, + "QUESTION_ANS": PeftModelForQuestionAnswering, + "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, +} + +PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = { + "ADAPTION_PROMPT": AdaptionPromptConfig, + "PROMPT_TUNING": PromptTuningConfig, + "PREFIX_TUNING": PrefixTuningConfig, + "P_TUNING": PromptEncoderConfig, + "LORA": LoraConfig, + "LOHA": LoHaConfig, + "LOKR": LoKrConfig, + "ADALORA": AdaLoraConfig, + "BOFT": BOFTConfig, + "IA3": IA3Config, + "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, + "OFT": OFTConfig, + "POLY": PolyConfig, + "LN_TUNING": LNTuningConfig, + "VERA": VeraConfig, +} + +PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = { + "LORA": LoraModel, + "LOHA": LoHaModel, + "LOKR": LoKrModel, + "ADALORA": AdaLoraModel, + "BOFT": BOFTModel, + "IA3": IA3Model, + "OFT": OFTModel, + "POLY": PolyModel, + "LN_TUNING": LNTuningModel, + "VERA": VeraModel, +} + + +def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: + """ + Returns a Peft config object from a dictionary. + + Args: + config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. + """ + + return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) + + +def get_peft_model( + model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False +) -> PeftModel | PeftMixedModel: + """ + Returns a Peft model object from a model and a config. + + Args: + model ([`transformers.PreTrainedModel`]): + Model to be wrapped. + peft_config ([`PeftConfig`]): + Configuration object containing the parameters of the Peft model. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + mixed (`bool`, `optional`, defaults to `False`): + Whether to allow mixing different (compatible) adapter types. 
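+
+    Example (a minimal sketch; ``"gpt2"`` is just an illustrative base model and the LoRA
+    hyperparameters are arbitrary, not values recommended by this module):
+
+    ```py
+    >>> from transformers import AutoModelForCausalLM
+    >>> from peft import LoraConfig, TaskType, get_peft_model
+
+    >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.05)
+    >>> peft_model = get_peft_model(base_model, peft_config)
+    >>> peft_model.print_trainable_parameters()
+    ```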
+ """ + model_config = getattr(model, "config", {"model_type": "custom"}) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + + peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) + + if mixed: + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: + return PeftModel(model, peft_config, adapter_name=adapter_name) + + if peft_config.is_prompt_learning: + peft_config = _prepare_prompt_learning_config(peft_config, model_config) + return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) + + +def inject_adapter_in_model( + peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default" +) -> torch.nn.Module: + r""" + A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning + methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API + calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods. + + Args: + peft_config (`PeftConfig`): + Configuration object containing the parameters of the Peft model. + model (`torch.nn.Module`): + The input model where the adapter will be injected. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + """ + if peft_config.is_prompt_learning or peft_config.is_adaption_prompt: + raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.") + + if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys(): + raise ValueError( + f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`." + ) + + tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] + + # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules. + peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name) + + return peft_model.model diff --git a/peft/mixed_model.py b/peft/mixed_model.py new file mode 100644 index 0000000000000000000000000000000000000000..be640e01e7b29ba1cb9e5350e46dc20d91d58ce1 --- /dev/null +++ b/peft/mixed_model.py @@ -0,0 +1,415 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import Any, Optional, Union + +import torch +from accelerate.hooks import remove_hook_from_submodules +from torch import nn +from transformers.utils import PushToHubMixin + +from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES + +from .config import PeftConfig +from .peft_model import PeftModel +from .tuners import ( + AdaLoraModel, + IA3Model, + LoHaModel, + LoKrModel, + LoraModel, + MixedModel, + OFTModel, +) +from .utils import PeftType, _set_adapter, _set_trainable + + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.ADALORA: AdaLoraModel, + PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, +} + + +def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None: + r""" + Prepares the model for gradient checkpointing if necessary + """ + # Note: same as PeftModel._prepare_model_for_gradient_checkpointing + if not getattr(model, "is_gradient_checkpointing", True): + return model + + if not ( + getattr(model, "is_loaded_in_8bit", False) + or getattr(model, "is_loaded_in_4bit", False) + or getattr(model, "is_quantized", False) + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + elif hasattr(model, "get_input_embeddings"): + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + +def _check_config_compatible(peft_config: PeftConfig) -> None: + if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES: + raise ValueError( + f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. " + f"Compatible types are: {COMPATIBLE_TUNER_TYPES}" + ) + + +class PeftMixedModel(PushToHubMixin, torch.nn.Module): + """ + PeftMixedModel for loading mixing different types of adapters for inference. + + This class does not support loading/saving, and it shouldn't usually be initialized directly. Instead, use + `get_peft_model` with the argument `mixed=True`. + + + + Read the [Mixed adapter types](https://huggingface.co/docs/peft/en/developer_guides/mixed_models) guide to learn + more about using different adapter types. + + + + Example: + + ```py + >>> from peft import get_peft_model + + >>> base_model = ... # load the base model, e.g. from transformers + >>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval() + >>> peft_model.load_adapter(path_to_adapter2, "adapter2") + >>> peft_model.set_adapter(["adapter1", "adapter2"]) # activate both adapters + >>> peft_model(data) # forward pass using both adapters + ``` + + Args: + model (`torch.nn.Module`): + The model to be tuned. + config (`PeftConfig`): + The config of the model to be tuned. The adapter type must be compatible. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the first adapter. 
+ """ + + def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__() + _check_config_compatible(peft_config) + _prepare_model_for_gradient_checkpointing(model) + self.modules_to_save = None + self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name) + self.set_modules_to_save(peft_config, adapter_name) + + self.config = getattr(model, "config", {"model_type": "custom"}) + + # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid + # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected + # behavior we disable that in this line. + if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"): + self.base_model.config.pretraining_tp = 1 + + @property + def peft_config(self) -> dict[str, PeftConfig]: + return self.base_model.peft_config + + @property + def active_adapter(self) -> str: + return self.base_model.active_adapter + + @property + def active_adapters(self) -> list[str]: + return self.base_model.active_adapters + + def get_nb_trainable_parameters(self): + r""" + Returns the number of trainable parameters and number of all parameters in the model. + """ + # note: same as PeftModel.get_nb_trainable_parameters + trainable_params = 0 + all_param = 0 + for _, param in self.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes + # one needs to multiply the number of parameters by 2 to get + # the correct number of parameters + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + def print_trainable_parameters(self): + """ + Prints the number of trainable parameters in the model. + + Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from + num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns + (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model. + For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for + prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number + of trainable parameters of the backbone transformer model which can be different. + """ + # note: same as PeftModel.print_trainable_parameters + trainable_params, all_param = self.get_nb_trainable_parameters() + + print( + f"trainable params: {trainable_params:,d} || " + f"all params: {all_param:,d} || " + f"trainable%: {100 * trainable_params / all_param:.4f}" + ) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.base_model, name) + + def forward(self, *args: Any, **kwargs: Any): + """ + Forward pass of the model. + """ + return self.base_model(*args, **kwargs) + + def generate(self, *args: Any, **kwargs: Any): + """ + Generate output. 
+ """ + return self.base_model.generate(*args, **kwargs) + + @contextmanager + def disable_adapter(self): + """ + Disables the adapter module. + """ + try: + self.base_model.disable_adapter_layers() + yield + finally: + self.base_model.enable_adapter_layers() + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig): + _check_config_compatible(peft_config) + + try: + self.peft_config[adapter_name] = peft_config + self.base_model.inject_adapter(self, adapter_name) + except Exception: # something went wrong, roll back + if adapter_name in self.peft_config: + del self.peft_config[adapter_name] + raise + + self.set_modules_to_save(peft_config, adapter_name) + + def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None: + if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None: + return + + if self.modules_to_save is None: + self.modules_to_save = set(modules_to_save) + else: + self.modules_to_save.update(modules_to_save) + _set_trainable(self, adapter_name) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + """ + Sets the active adapter(s) for the model. + + Note that the order in which the adapters are applied during the forward pass may not be the same as the order + in which they are passed to this function. Instead, the order during the forward pass is determined by the + order in which the adapters were loaded into the model. The active adapters only determine which adapters are + active during the forward pass, but not the order in which they are applied. + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `List[str]`): + The name of the adapter(s) to be activated. + """ + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.set_adapter(adapter_name) + _set_adapter(self, adapter_name) + + def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.delete_adapter(adapter_name) + + def merge_and_unload(self, *args: Any, **kwargs: Any): + r""" + This method merges the adapter layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self.base_model.merge_and_unload(*args, **kwargs) + + def unload(self, *args: Any, **kwargs: Any): + """ + Gets back the base model by removing all the adapter modules without merging. 
This gives back the original base + model. + """ + return self.base_model.unload(*args, **kwargs) + + def get_layer_status(self): + raise TypeError(f"get_layer_status is not supported for {self.__class__.__name__}.") + + def get_model_status(self): + raise TypeError(f"get_model_status is not supported for {self.__class__.__name__}.") + + @classmethod + def _split_kwargs(cls, kwargs: dict[str, Any]): + return PeftModel._split_kwargs(kwargs) + + def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any): + output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs) + # TODO: not quite clear why this is necessary but tests fail without it + self.set_adapter(self.active_adapters) + return output + + def create_or_update_model_card(self, output_dir: str): + raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).") + + def save_pretrained( + self, + save_directory: str, + safe_serialization: bool = False, + selected_adapters: Optional[list[str]] = None, + **kwargs: Any, + ): + raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).") + + @classmethod + def from_pretrained( + cls, + model: nn.Module, + model_id: str | os.PathLike, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + **kwargs: Any, + ): + r""" + Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights. + + Note that the passed `model` may be modified inplace. + + Args: + model (`nn.Module`): + The model to be adapted. + model_id (`str` or `os.PathLike`): + The name of the PEFT configuration to use. Can be either: + - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face + Hub. + - A path to a directory containing a PEFT configuration file saved using the `save_pretrained` + method (`./my_peft_config_directory/`). + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to be loaded. This is useful for loading multiple adapters. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for + inference + config ([`~peft.PeftConfig`], *optional*): + The configuration object to use instead of an automatically loaded configuration. This configuration + object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already + loaded before calling `from_pretrained`. + kwargs: (`optional`): + Additional keyword arguments passed along to the specific PEFT configuration class. 
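+
+        Example (a minimal sketch; the adapter paths are placeholders for directories produced by
+        `save_pretrained`, and the adapter types must be among those supported for mixed models):
+
+        ```py
+        >>> base_model = ...  # load the base model, e.g. from transformers
+        >>> peft_model = PeftMixedModel.from_pretrained(base_model, "path/to/lora_adapter", adapter_name="lora_1")
+        >>> peft_model.load_adapter("path/to/loha_adapter", adapter_name="loha_1")
+        >>> peft_model.set_adapter(["lora_1", "loha_1"])  # activate both adapters for the forward pass
+        ```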
+ """ + # note: adapted from PeftModel.from_pretrained + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + + # load the config + if config is None: + config = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_id, + subfolder=kwargs.get("subfolder", None), + revision=kwargs.get("revision", None), + cache_dir=kwargs.get("cache_dir", None), + use_auth_token=kwargs.get("use_auth_token", None), + ) + ].from_pretrained(model_id, **kwargs) + elif isinstance(config, PeftConfig): + config.inference_mode = not is_trainable + else: + raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") + + # note: this is different from PeftModel.from_pretrained + if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING: + raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.") + + if (getattr(model, "hf_device_map", None) is not None) and len( + set(model.hf_device_map.values()).intersection({"cpu", "disk"}) + ) > 0: + remove_hook_from_submodules(model) + + if config.is_prompt_learning and is_trainable: + # note: should not be possible to reach, but just in case + raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") + else: + config.inference_mode = not is_trainable + + # note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel + model = cls(model, config, adapter_name) + model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) + return model diff --git a/peft/peft_model.py b/peft/peft_model.py new file mode 100644 index 0000000000000000000000000000000000000000..adaa23cd6e85134de67118a3df1d40d85ef7e057 --- /dev/null +++ b/peft/peft_model.py @@ -0,0 +1,2613 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +import inspect +import os +import warnings +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Literal, Optional, Union + +import packaging.version +import torch +import transformers +from accelerate import dispatch_model, infer_auto_device_map +from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules +from accelerate.utils import get_balanced_memory, named_module_tensors +from huggingface_hub import ModelCard, ModelCardData, hf_hub_download +from safetensors import safe_open +from safetensors.torch import save_file as safe_save_file +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import PreTrainedModel +from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput +from transformers.utils import PushToHubMixin + +from . 
import __version__ +from .config import PeftConfig +from .tuners import ( + AdaLoraModel, + AdaptionPromptModel, + BOFTModel, + IA3Model, + LNTuningModel, + LoHaModel, + LoKrModel, + LoraModel, + MultitaskPromptEmbedding, + OFTModel, + PolyModel, + PrefixEncoder, + PromptEmbedding, + PromptEncoder, + VeraModel, +) +from .tuners.tuners_utils import BaseTuner, BaseTunerLayer +from .utils import ( + SAFETENSORS_WEIGHTS_NAME, + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + WEIGHTS_NAME, + PeftType, + TaskType, + _get_batch_size, + _prepare_prompt_learning_config, + _set_adapter, + _set_trainable, + get_peft_model_state_dict, + id_tensor_storage, + infer_device, + load_peft_weights, + set_peft_model_state_dict, + shift_tokens_right, +) + + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.PROMPT_TUNING: PromptEmbedding, + PeftType.P_TUNING: PromptEncoder, + PeftType.PREFIX_TUNING: PrefixEncoder, + PeftType.ADALORA: AdaLoraModel, + PeftType.BOFT: BOFTModel, + PeftType.ADAPTION_PROMPT: AdaptionPromptModel, + PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, + PeftType.POLY: PolyModel, + PeftType.LN_TUNING: LNTuningModel, + PeftType.VERA: VeraModel, +} + + +class PeftModel(PushToHubMixin, torch.nn.Module): + """ + Base model encompassing various Peft methods. + + Args: + model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft. + peft_config ([`PeftConfig`]): The configuration of the Peft model. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. + + **Attributes**: + - **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft. + - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model. + - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when + saving the model. + - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if + using [`PromptLearningConfig`]. + - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if + using [`PromptLearningConfig`]. + - **transformer_backbone_name** (`str`) -- The name of the transformer + backbone in the base model if using [`PromptLearningConfig`]. + - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone + in the base model if using [`PromptLearningConfig`]. + """ + + def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__() + self.modules_to_save = None + self.active_adapter = adapter_name + self.peft_type = peft_config.peft_type + # These args are special PEFT arguments that users can pass. They need to be removed before passing them to + # forward. 
+ self.special_peft_forward_args = {"adapter_names"} + + self._is_prompt_learning = peft_config.is_prompt_learning + if self._is_prompt_learning: + self._peft_config = {adapter_name: peft_config} + self.base_model = model + self.add_adapter(adapter_name, peft_config) + else: + self._peft_config = None + cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] + self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) + self.set_additional_trainable_modules(peft_config, adapter_name) + + if getattr(model, "is_gradient_checkpointing", True): + model = self._prepare_model_for_gradient_checkpointing(model) + + # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid + # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected + # behavior we disable that in this line. + if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"): + self.base_model.config.pretraining_tp = 1 + + @property + def peft_config(self) -> dict[str, PeftConfig]: + if self._is_prompt_learning: + return self._peft_config + return self.base_model.peft_config + + @property + def active_adapters(self) -> list[str]: + try: + adapters = self.base_model.active_adapters + except AttributeError: + adapters = self.active_adapter + if isinstance(adapters, str): + adapters = [adapters] + return adapters + + @peft_config.setter + def peft_config(self, value: dict[str, PeftConfig]): + if self._is_prompt_learning: + self._peft_config = value + else: + self.base_model.peft_config = value + + def save_pretrained( + self, + save_directory: str, + safe_serialization: bool = True, + selected_adapters: Optional[list[str]] = None, + save_embedding_layers: Union[str, bool] = "auto", + is_main_process: bool = True, + convert_pissa_to_lora: Optional[str] = None, + **kwargs: Any, + ) -> None: + r""" + This function saves the adapter model and the adapter configuration files to a directory, so that it can be + reloaded using the [`PeftModel.from_pretrained`] class method, and also used by the [`PeftModel.push_to_hub`] + method. + + Args: + save_directory (`str`): + Directory where the adapter model and configuration files will be saved (will be created if it does not + exist). + safe_serialization (`bool`, *optional*): + Whether to save the adapter files in safetensors format, defaults to `True`. + selected_adapters (`List[str]`, *optional*): + A list of adapters to be saved. If `None`, will default to all adapters. + save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `"auto"`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common + embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. + and automatically sets the boolean flag. This only works for 🤗 transformers models. + is_main_process (`bool`, *optional*): + Whether the process calling this is the main process or not. Will default to `True`. Will not save the + checkpoint if not on the main process, which is important for multi device setups (e.g. DDP). + convert_pissa_to_lora (`str`): + The path to the initialized PiSSA adapter, which is obtained after initializing the model with PiSSA + and before performing any training. When `convert_pissa_to_lora` is not None, the difference in PISSA + before and after fine-tuning is calculated. This difference can be represented as the parameters of a + of a standard LoRA adapter. 
Using this converted adapter does not require changes to the base model, + thus conveniently allowing the use of multiple PISSA and LoRA adapters, and the activation or + deactivation of any adapters. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the `push_to_hub` method. + """ + if os.path.isfile(save_directory): + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") + + if selected_adapters is None: + selected_adapters = list(self.peft_config.keys()) + else: + if any( + selected_adapter_name not in list(self.peft_config.keys()) + for selected_adapter_name in selected_adapters + ): + raise ValueError( + f"You passed an invalid `selected_adapters` arguments, current supported adapter names are" + f" {list(self.peft_config.keys())} - got {selected_adapters}." + ) + + def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kwargs): + if not str(peft_config.init_lora_weights).startswith("pissa"): + warnings.warn("`convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter") + initial_adapter = os.path.basename(convert_pissa_to_lora) + self.load_adapter( + os.path.dirname(convert_pissa_to_lora), subfolder=initial_adapter, adapter_name=initial_adapter + ) + if str(self.peft_config[initial_adapter].init_lora_weights).startswith("pissa"): + raise ValueError( + "The `init_lora_weights` parameter of the initial PiSSA adapter should be set to `True`. " + "Otherwise, `self.load_adapter` will subtract the principal singular value and vector again based on the residual model." + ) + output_state_dict = self.base_model.subtract_pissa_init(output_state_dict, initial_adapter, kwargs) + self.delete_adapter(adapter_name) + return output_state_dict + + if is_main_process: + os.makedirs(save_directory, exist_ok=True) + self.create_or_update_model_card(save_directory) + + for adapter_name in selected_adapters: + peft_config = self.peft_config[adapter_name] + # save only the trainable weights + output_state_dict = get_peft_model_state_dict( + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + save_embedding_layers=save_embedding_layers, + ) + output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory + os.makedirs(output_dir, exist_ok=True) + + if is_main_process and safe_serialization: + # Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134 + # Safetensors does not allow tensor aliasing. + # We're going to remove aliases before saving + ptrs = collections.defaultdict(list) + for name, tensor in output_state_dict.items(): + # Sometimes in the state_dict we have non-tensor objects. + # e.g. in bitsandbytes we have some `str` objects in the state_dict + if isinstance(tensor, torch.Tensor): + ptrs[id_tensor_storage(tensor)].append(name) + else: + # In the non-tensor case, fall back to the pointer of the object itself + ptrs[id(tensor)].append(name) + + # These are all the pointers of shared tensors. + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + + for _, names in shared_ptrs.items(): + # Here we just clone the shared tensors to avoid tensor aliasing which is + # not supported in safetensors. 
+ for shared_tensor_name in names[1:]: + output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone() + if convert_pissa_to_lora is not None: + output_state_dict = save_pissa_as_lora( + peft_config, convert_pissa_to_lora, output_state_dict, kwargs + ) + safe_save_file( + output_state_dict, + os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), + metadata={"format": "pt"}, + ) + elif is_main_process: + if convert_pissa_to_lora is not None: + output_state_dict = save_pissa_as_lora( + peft_config, convert_pissa_to_lora, output_state_dict, kwargs + ) + torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + + # save the config and change the inference mode to `True` + if peft_config.base_model_name_or_path is None: + peft_config.base_model_name_or_path = ( + self.base_model.__dict__.get("name_or_path", None) + if peft_config.is_prompt_learning + else self.base_model.model.__dict__.get("name_or_path", None) + ) + inference_mode = peft_config.inference_mode + peft_config.inference_mode = True + + if peft_config.task_type is None: + # deal with auto mapping + base_model_class = self._get_base_model_class( + is_prompt_tuning=peft_config.is_prompt_learning, + ) + parent_library = base_model_class.__module__ + + auto_mapping_dict = { + "base_model_class": base_model_class.__name__, + "parent_library": parent_library, + } + else: + auto_mapping_dict = None + + if is_main_process: + if convert_pissa_to_lora is not None: + peft_config.init_lora_weights = True + peft_config.r *= 2 + peft_config.lora_alpha *= 2 + peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) + peft_config.inference_mode = inference_mode + + @classmethod + def from_pretrained( + cls, + model: torch.nn.Module, + model_id: Union[str, os.PathLike], + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + **kwargs: Any, + ) -> PeftModel: + r""" + Instantiate a PEFT model from a pretrained model and loaded PEFT weights. + + Note that the passed `model` may be modified inplace. + + Args: + model ([`torch.nn.Module`]): + The model to be adapted. For 🤗 Transformers models, the model should be initialized with the + [`~transformers.PreTrainedModel.from_pretrained`]. + model_id (`str` or `os.PathLike`): + The name of the PEFT configuration to use. Can be either: + - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face + Hub. + - A path to a directory containing a PEFT configuration file saved using the `save_pretrained` + method (`./my_peft_config_directory/`). + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to be loaded. This is useful for loading multiple adapters. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. + config ([`~peft.PeftConfig`], *optional*): + The configuration object to use instead of an automatically loaded configuration. This configuration + object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already + loaded before calling `from_pretrained`. + kwargs: (`optional`): + Additional keyword arguments passed along to the specific PEFT configuration class. 
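+
+        Example (a minimal sketch; ``"bigscience/mt0-large"`` mirrors the base model used elsewhere in
+        this package's docstrings, and ``"./my_peft_adapter"`` is a placeholder for a directory created
+        by `save_pretrained`):
+
+        ```py
+        >>> from transformers import AutoModelForSeq2SeqLM
+        >>> from peft import PeftModel
+
+        >>> base_model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-large")
+        >>> peft_model = PeftModel.from_pretrained(base_model, "./my_peft_adapter", is_trainable=False)
+        ```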
+ """ + from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING + + # load the config + if config is None: + config = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_id, + subfolder=kwargs.get("subfolder", None), + revision=kwargs.get("revision", None), + cache_dir=kwargs.get("cache_dir", None), + use_auth_token=kwargs.get("use_auth_token", None), + token=kwargs.get("token", None), + ) + ].from_pretrained(model_id, **kwargs) + elif isinstance(config, PeftConfig): + config.inference_mode = not is_trainable + else: + raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") + + if hasattr(model, "hf_device_map"): + weight_map = dict(named_module_tensors(model, recurse=True)) + + # recreate the offload_index for disk-offloaded modules: we need to know the location in storage of each weight + # before the offload hook is removed from the model + disk_modules = set() + index = None + for name, module in model.named_modules(): + if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "original_devices"): + if hasattr(module._hf_hook.weights_map, "dataset"): + index = module._hf_hook.weights_map.dataset.index + for key in module._hf_hook.original_devices.keys(): + if module._hf_hook.original_devices[key] == torch.device("meta"): + disk_modules.add(str(name) + "." + str(key)) + + if disk_modules and not kwargs.get("use_safetensors", True): + raise ValueError("Disk offloading currently only supported for safetensors") + + if index: + offload_index = { + p: { + "safetensors_file": index[p]["safetensors_file"], + "weight_name": p, + "dtype": str(weight_map[p].dtype).replace("torch.", ""), + } + for p in weight_map.keys() + if p in disk_modules + } + kwargs["offload_index"] = offload_index + + if (getattr(model, "hf_device_map", None) is not None) and len( + set(model.hf_device_map.values()).intersection({"cpu", "disk"}) + ) > 0: + remove_hook_from_submodules(model) + + if config.is_prompt_learning and is_trainable: + raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") + else: + config.inference_mode = not is_trainable + + if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): + model = cls(model, config, adapter_name) + else: + model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name) + model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) + return model + + def _setup_prompt_encoder(self, adapter_name: str): + config = self.peft_config[adapter_name] + if not hasattr(self, "prompt_encoder"): + self.prompt_encoder = torch.nn.ModuleDict({}) + self.prompt_tokens = {} + transformer_backbone = None + for name, module in self.base_model.named_children(): + for param in module.parameters(): + param.requires_grad = False + if isinstance(module, PreTrainedModel): + # Make sure to freeze Tranformers model + if transformer_backbone is None: + transformer_backbone = module + self.transformer_backbone_name = name + if transformer_backbone is None: + transformer_backbone = self.base_model + + if config.num_transformer_submodules is None: + config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1 + + for named_param, value in list(transformer_backbone.named_parameters()): + # for ZeRO-3, the tensor is sharded across accelerators and deepspeed modifies it to a tensor with shape [0] + # the actual unsharded shape is stored in "ds_shape" attribute + # special handling is needed in case the model 
is initialized in deepspeed.zero.Init() context or HfDeepSpeedConfig + # has been called before + # For reference refer to issue: https://github.com/huggingface/peft/issues/996 + deepspeed_distributed_tensor_shape = getattr(value, "ds_shape", None) + + if value.shape[0] == self.base_model.config.vocab_size or ( + deepspeed_distributed_tensor_shape is not None + and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size + ): + self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", "")) + break + + if config.peft_type == PeftType.PROMPT_TUNING: + prompt_encoder = PromptEmbedding(config, self.word_embeddings) + elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings) + elif config.peft_type == PeftType.P_TUNING: + prompt_encoder = PromptEncoder(config) + elif config.peft_type == PeftType.PREFIX_TUNING: + prompt_encoder = PrefixEncoder(config) + else: + raise ValueError("Not supported") + + prompt_encoder = prompt_encoder.to(self.device) + self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder})) + self.prompt_tokens[adapter_name] = torch.arange( + config.num_virtual_tokens * config.num_transformer_submodules + ).long() + + def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel): + r""" + Prepares the model for gradient checkpointing if necessary + """ + if not ( + getattr(model, "is_loaded_in_8bit", False) + or getattr(model, "is_loaded_in_4bit", False) + or getattr(model, "is_quantized", False) + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + elif hasattr(model, "get_input_embeddings"): + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + return model + + def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor: + """ + Returns the prompt embedding to save when saving the model. Only applicable when using a prompt learning + method. + """ + prompt_encoder = self.prompt_encoder[adapter_name] + prompt_tokens = ( + self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device) + ) + if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING: + prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens] + + if self.peft_config[adapter_name].peft_type == PeftType.MULTITASK_PROMPT_TUNING: + prompt_embeddings = super(MultitaskPromptEmbedding, prompt_encoder).forward(prompt_tokens) + else: + prompt_embeddings = prompt_encoder(prompt_tokens) + + return prompt_embeddings[0].detach().cpu() + + def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method. 
+ """ + peft_config = self.active_peft_config + prompt_encoder = self.prompt_encoder[self.active_adapter] + prompt_tokens = ( + self.prompt_tokens[self.active_adapter] + .unsqueeze(0) + .expand(batch_size, -1) + .to(prompt_encoder.embedding.weight.device) + ) + if peft_config.peft_type == PeftType.PREFIX_TUNING: + prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens] + if peft_config.inference_mode: + past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) + else: + past_key_values = prompt_encoder(prompt_tokens) + if self.base_model_torch_dtype is not None: + past_key_values = past_key_values.to(self.base_model_torch_dtype) + past_key_values = past_key_values.view( + batch_size, + peft_config.num_virtual_tokens, + peft_config.num_layers * 2, + peft_config.num_attention_heads, + peft_config.token_dim // peft_config.num_attention_heads, + ) + if peft_config.num_transformer_submodules == 2: + past_key_values = torch.cat([past_key_values, past_key_values], dim=2) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split( + peft_config.num_transformer_submodules * 2 + ) + if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None: + post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type] + past_key_values = post_process_fn(past_key_values) + return past_key_values + else: + if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + prompts = prompt_encoder(prompt_tokens, task_ids) + else: + if peft_config.inference_mode: + prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) + else: + prompts = prompt_encoder(prompt_tokens) + return prompts + + def get_nb_trainable_parameters(self) -> tuple[int, int]: + r""" + Returns the number of trainable parameters and the number of all parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in self.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes + # one needs to multiply the number of parameters by 2 to get + # the correct number of parameters + if param.__class__.__name__ == "Params4bit": + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif not hasattr(param, "quant_storage"): + num_bytes = 1 + else: + num_bytes = param.quant_storage.itemsize + num_params = num_params * 2 * num_bytes + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + def print_trainable_parameters(self) -> None: + """ + Prints the number of trainable parameters in the model. + + Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from + num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns + (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model. + For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for + prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number + of trainable parameters of the backbone transformer model which can be different. 
+ """ + trainable_params, all_param = self.get_nb_trainable_parameters() + + print( + f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}" + ) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.base_model, name) + + @contextmanager + def _enable_peft_forward_hooks(self, *args, **kwargs): + # If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this + # runs without any changes + if hasattr(self.base_model, "_enable_peft_forward_hooks"): + with self.base_model._enable_peft_forward_hooks(*args, **kwargs): + yield + return + else: + # nothing to enable + yield + return + + def forward(self, *args: Any, **kwargs: Any): + """ + Forward pass of the model. + """ + with self._enable_peft_forward_hooks(*args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.get_base_model()(*args, **kwargs) + + def generate(self, *args, **kwargs): + with self._enable_peft_forward_hooks(*args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.get_base_model().generate(*args, **kwargs) + + def _get_base_model_class(self, is_prompt_tuning=False): + """ + Returns the base model class. + """ + if not is_prompt_tuning: + return self.base_model.model.__class__ + return self.base_model.__class__ + + @contextmanager + def disable_adapter(self): + """ + Context manager that disables the adapter module. Use this to run inference on the base model. + + Example: + + ```py + >>> with model.disable_adapter(): + ... model(inputs) + ``` + """ + if self.peft_config[self.active_adapter].is_prompt_learning: + try: + # TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and + # letting the underlying methods deal with it, same as how LoRA does it. + old_forward = self.forward + self.forward = self.base_model.forward + old_prepare_inputs_for_generation = self.prepare_inputs_for_generation + self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation + yield + finally: + self.forward = old_forward + self.prepare_inputs_for_generation = old_prepare_inputs_for_generation + + elif self.peft_config[self.active_adapter].is_adaption_prompt: + try: + self.base_model.disable_adapter_layers() + yield + finally: + self.base_model.enable_adapter_layers() + + else: # LoRA, LoHa, etc. + model_status = self.get_model_status() + if model_status.enabled == "irregular": + warnings.warn( + "The model contains some adapter layers that are enabled and others that are disabled. " + "This is most likely unintentional. After exiting the disable_adapter context, all adapters " + "will be enabled" + ) + try: + self.base_model.disable_adapter_layers() + yield + finally: + if model_status.enabled is not False: + # model_status.enabled is `True` or `"irregular"` + self.base_model.enable_adapter_layers() + + def get_base_model(self) -> torch.nn.Module: + """ + Returns the base model. + """ + return ( + self.base_model + if (self.active_peft_config.is_prompt_learning or self.peft_type == PeftType.POLY) + else self.base_model.model + ) + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + """ + Add an adapter to the model based on the passed configuration. 
+ + This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`]. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + """ + if peft_config.peft_type != self.peft_type: + raise ValueError( + f"Cannot combine adapters with different peft types. " + f"Found {self.peft_type} and {peft_config.peft_type}." + ) + + try: + if peft_config.is_prompt_learning: + self.peft_config[adapter_name] = peft_config + if hasattr(self.config, "to_dict"): + dict_config = self.config.to_dict() + else: + dict_config = self.config + + peft_config = _prepare_prompt_learning_config(peft_config, dict_config) + self._setup_prompt_encoder(adapter_name) + elif peft_config.is_adaption_prompt: + self.base_model.add_adapter(adapter_name, peft_config) + else: + self.peft_config[adapter_name] = peft_config + self.base_model.inject_adapter(self.base_model.model, adapter_name) + except Exception: # something went wrong, roll back + if adapter_name in self.peft_config: + del self.peft_config[adapter_name] + raise + + self.set_additional_trainable_modules(peft_config, adapter_name) + + def set_additional_trainable_modules(self, peft_config, adapter_name): + if getattr(peft_config, "modules_to_save", None) is not None: + if self.modules_to_save is None: + self.modules_to_save = set(peft_config.modules_to_save) + else: + self.modules_to_save.update(peft_config.modules_to_save) + _set_trainable(self, adapter_name) # this may add a new ModulesToSaveWrapper + + def get_layer_status(self) -> list[TunerLayerStatus]: + """Get the status of each adapter layer in the model. + + This method returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following + attributes: + + - `name` (`str`): + The name of the adapter layer, e.g. `model.encoder.block.0.layer.0.SelfAttention.q`. + - `module_type` (`str`): + The type of the adapter layer, e.g. `lora.Linear`. + - `enabled` (`bool`): + Whether the adapter layer is enabled. + - `active_adapters` (`list[str]`): + The names of the active adapters, if any, e.g. `["default"]`. + - `merged_adapters` (`list[str]`): + The names of the merged adapters, if any, e.g. `["default"]`. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([`~PeftModel`]): + The model to get the adapter layer status from. + + Returns: + list[`peft.peft_model.TunerLayerStatus`]: + A list of dataclasses, each containing the status of the corresponding adapter layer. + + """ + return get_layer_status(self) + + def get_model_status(self) -> TunerModelStatus: + """Get the status of tuners of the model. + + This method returns a `TunerModelStatus` dataclass instance, which contains the following attributes: + + - `base_model_type` (`str`): + The type of the base model, e.g. `T5Model`. + - `adapter_model_type` (`str`): + The type of the adapter model, e.g. `LoraModel`. + - `peft_types` (`dict[str, str]`): + The mapping of adapter name to adapter type, e.g. `{"default": "LORA"}`. + - `trainable_params` (`int`): + The number of trainable parameters in the model. + - `total_params` (`int`): + The total number of parameters in the model. + - `num_adapter_layers` (`int`): + The number of adapter layers in the model. 
+ - `enabled` (`bool`, `Literal["irregular"]`): + Whether all adapter layers are enabled. If some are enabled and some are not, this will be `"irregular"`. + This means that your model is in an inconsistent state and might not work as expected. + - `active_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the active adapters. If the active adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `merged_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the merged adapters. If the merged adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([`~PeftModel`]): + The model to get the adapter layer status from. + + Returns: + `peft.peft_model.TunerModelStatus`: + A dataclass containing the status of the model. + + """ + return get_model_status(self) + + @classmethod + def _split_kwargs(cls, kwargs: dict[str, Any]): + _kwargs_not_in_hf_hub_download_signature = ("use_auth_token",) + hf_hub_download_kwargs = {} + other_kwargs = {} + + for key, value in kwargs.items(): + if key in inspect.signature(hf_hub_download).parameters or key in _kwargs_not_in_hf_hub_download_signature: + hf_hub_download_kwargs[key] = value + else: + other_kwargs[key] = value + + return hf_hub_download_kwargs, other_kwargs + + def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_weights: dict[str, torch.tensor]): + """ + Update the offload_index and safetensors files for loading and mergine PeftModels with disk-offloaded modules. + + Args: + offload_index (Dict[str: str]): + Dictionary of disk-offloaded modules with their metadata and safetensors filenames + adapters_weights (Dict[str: torch.tensor]): + Dictionary of Peft adapter module names and weights + """ + + if not offload_index: + return offload_index + + prefix = "base_model.model." + # rename offload index weight and model names + adapter_names = list(self.peft_config.keys()) + for adapter_name in adapter_names: + keys = list(offload_index.keys()) + block_id = keys[0].split(".")[0] + "." 
# for writing safetensors key, + + # replace original offload index keys with PeftModel keys + for key in keys: + suffix_pos = key.rfind(".") + extended_prefix = prefix + key[:suffix_pos] + module = dict(self.named_modules())[extended_prefix] + if isinstance(module, BaseTunerLayer): + new_key = prefix + key[:suffix_pos] + ".base_layer" + key[suffix_pos:] + else: + new_key = prefix + key + offload_index[key]["weight_name"] = new_key + offload_index[new_key] = offload_index[key] + del offload_index[key] + + files_seen = set() + # rename safetensors for dispatch + for new_key in list(offload_index.keys()): + fname = offload_index[new_key]["safetensors_file"] + + # make a new file name + new_fname_list = list(fname.split(os.sep)) + for i, name in enumerate(new_fname_list): + if "--" in name: + new_fname_list[i] += "-peft" + break + new_fname = os.path.join(*new_fname_list) + + if fname in files_seen: + continue + safe_dict = {} + with safe_open(fname, framework="pt") as f: + for safe_key in f.keys(): + safe_tensor = f.get_tensor(safe_key) + metadata = f.metadata() + suffix_pos = safe_key.rfind(".") + extended_prefix = prefix + block_id + safe_key[:suffix_pos] + safe_module = dict(self.named_modules())[extended_prefix] + if isinstance(safe_module, BaseTunerLayer): + final_key = extended_prefix + ".base_layer" + safe_key[suffix_pos:] + lora_dict = {key: val for key, val in adapters_weights.items() if extended_prefix in key} + + # add LoRA keys and values to disk offload + for lora_key, lora_val in lora_dict.items(): + divide = lora_key.rfind(".") + new_key = lora_key[:divide] + f".{adapter_name}" + lora_key[divide:] + safe_dict[new_key] = lora_val + else: + final_key = prefix + block_id + safe_key + safe_dict[final_key] = safe_tensor + files_seen.add(new_fname) + + # avoid overwriting original safetensors + for key in safe_dict.keys(): + offload_index[key] = {"safetensors_file": new_fname, "weight_name": key} + + base_name = os.path.dirname(new_fname) + if not os.path.exists(base_name): + os.makedirs(base_name) + safe_save_file(safe_dict, new_fname, metadata=metadata) + + def load_adapter( + self, + model_id: str, + adapter_name: str, + is_trainable: bool = False, + torch_device: Optional[str] = None, + **kwargs: Any, + ): + """ + Load a trained adapter into the model. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. + torch_device (`str`, *optional*, defaults to None): + The device to load the adapter on. If `None`, the device will be inferred. + kwargs: (`optional`): + Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. 
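+
+ Example (a minimal sketch of loading a second adapter next to an existing one; the adapter paths are placeholders for local directories or Hub repositories containing trained adapters):
+
+ ```py
+ >>> from transformers import AutoModelForCausalLM
+ >>> from peft import PeftModel
+
+ >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2-large")
+ >>> peft_model = PeftModel.from_pretrained(base_model, "path/to/adapter_1", adapter_name="adapter_1")
+ >>> # load a second, frozen adapter next to the first one
+ >>> peft_model.load_adapter("path/to/adapter_2", adapter_name="adapter_2")
+ >>> # the new adapter is not active until it is explicitly selected
+ >>> peft_model.set_adapter("adapter_2")
+ ```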
+ """ + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + + hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs) + if torch_device is None: + torch_device = infer_device() + + if adapter_name not in self.peft_config: + # load the config + peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_id, + **hf_hub_download_kwargs, + ) + ].from_pretrained( + model_id, + **hf_hub_download_kwargs, + ) + if peft_config.is_prompt_learning and is_trainable: + raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") + else: + peft_config.inference_mode = not is_trainable + self.add_adapter(adapter_name, peft_config) + + adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs) + + # load the weights into the model + ignore_mismatched_sizes = kwargs.get("ignore_mismatched_sizes", False) + load_result = set_peft_model_state_dict( + self, adapters_weights, adapter_name=adapter_name, ignore_mismatched_sizes=ignore_mismatched_sizes + ) + if ( + (getattr(self, "hf_device_map", None) is not None) + and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) + and len(self.peft_config) == 1 + ): + device_map = kwargs.get("device_map", "auto") + max_memory = kwargs.get("max_memory", None) + offload_dir = kwargs.get("offload_folder", None) + offload_index = kwargs.get("offload_index", None) + + dispatch_model_kwargs = {} + # Safety checker for previous `accelerate` versions + # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/ + if "offload_index" in inspect.signature(dispatch_model).parameters: + dispatch_model_kwargs["offload_index"] = offload_index + + no_split_module_classes = self._no_split_modules + + if device_map != "sequential": + max_memory = get_balanced_memory( + self, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + low_zero=(device_map == "balanced_low_0"), + ) + + if isinstance(device_map, str): + device_map = infer_auto_device_map( + self, max_memory=max_memory, no_split_module_classes=no_split_module_classes + ) + + self._update_offload(offload_index, adapters_weights) + dispatch_model_kwargs["offload_index"] = offload_index + + dispatch_model( + self, + device_map=device_map, + offload_dir=offload_dir, + **dispatch_model_kwargs, + ) + + hook = AlignDevicesHook(io_same_device=True) + if self.peft_config[adapter_name].is_prompt_learning: + remove_hook_from_submodules(self.prompt_encoder) + add_hook_to_module(self.get_base_model(), hook) + + # Set model in evaluation mode to deactivate Dropout modules by default + if not is_trainable: + self.eval() + return load_result + + def set_adapter(self, adapter_name: str) -> None: + """ + Sets the active adapter. + + Only one adapter can be active at a time. + + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str`): + The name of the adapter to be set as active. The adapter must be loaded first. 
+ """ + if adapter_name not in self.peft_config: + raise ValueError(f"Adapter {adapter_name} not found.") + self.active_adapter = adapter_name + if not self.peft_config[adapter_name].is_prompt_learning: + self.base_model.set_adapter(adapter_name) + _set_adapter(self, adapter_name) + + @property + def base_model_torch_dtype(self): + return getattr(self.base_model, "dtype", None) + + @property + def active_peft_config(self): + return self.peft_config[self.active_adapter] + + def create_or_update_model_card(self, output_dir: str): + """ + Updates or create model card to include information about peft: + 1. Adds `peft` library tag + 2. Adds peft version + 3. Adds base model info + 4. Adds quantization information if it was used + """ + + filename = os.path.join(output_dir, "README.md") + + card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData()) + + card.data["library_name"] = "peft" + + model_config = getattr(self, "config", None) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + if model_config is not None and "_name_or_path" in model_config: + card.data["base_model"] = model_config["_name_or_path"] + + lines = card.text.splitlines() + + quantization_config = None + if hasattr(model_config, "quantization_config"): + quantization_config = self.config.quantization_config.to_dict() + training_config_text = "" + quantization_prefix = "The following `bitsandbytes` quantization config was used during training:" + # Adds quantization information if it was used + if quantization_config is not None: + training_config_text += f"\n{quantization_prefix}\n" + training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()]) + training_config_text += "\n" + + training_procedure_heading = "## Training procedure" + if quantization_prefix not in lines and bool(training_config_text): + if training_procedure_heading in lines: + lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) + else: + lines.append(f"{training_procedure_heading}\n{training_config_text}") + + # Adds peft version + framework_block_heading = "### Framework versions" + if f"- PEFT {__version__}" not in lines: + if framework_block_heading in lines: + lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}") + else: + lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}") + + card.text = "\n".join(lines) + card.save(filename) + + +class PeftModelForSequenceClassification(PeftModel): + """ + Peft model for sequence classification tasks. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + **Attributes**: + - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. + - **cls_layer_name** (`str`) -- The name of the classification layer. + + Example: + + ```py + >>> from transformers import AutoModelForSequenceClassification + >>> from peft import PeftModelForSequenceClassification, get_peft_config + + >>> config = { + ... "peft_type": "PREFIX_TUNING", + ... "task_type": "SEQ_CLS", + ... "inference_mode": False, + ... "num_virtual_tokens": 20, + ... "token_dim": 768, + ... "num_transformer_submodules": 1, + ... "num_attention_heads": 12, + ... "num_layers": 12, + ... "encoder_hidden_size": 768, + ... "prefix_projection": False, + ... "postprocess_past_key_value_function": None, + ... 
} + + >>> peft_config = get_peft_config(config) + >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased") + >>> peft_model = PeftModelForSequenceClassification(model, peft_config) + >>> peft_model.print_trainable_parameters() + trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117 + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + + classifier_module_names = ["classifier", "score"] + if self.modules_to_save is None: + self.modules_to_save = set(classifier_module_names) + else: + self.modules_to_save.update(classifier_module_names) + + if hasattr(peft_config, "modules_to_save"): + if peft_config.modules_to_save is None: + peft_config.modules_to_save = classifier_module_names[:] + else: + peft_config.modules_to_save.extend(classifier_module_names) + + for name, _ in self.base_model.named_children(): + if any(module_name in name for module_name in self.modules_to_save): + self.cls_layer_name = name + break + + # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper + _set_trainable(self, adapter_name) + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + """ + Add an adapter to the model based on the passed configuration. + + This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`]. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + """ + # ensure that additional adapters also add the classifier layer to modules_to_save + if hasattr(peft_config, "modules_to_save"): + classifier_module_names = ["classifier", "score"] + if peft_config.modules_to_save is None: + peft_config.modules_to_save = classifier_module_names[:] + else: + peft_config.modules_to_save.extend(classifier_module_names) + + return super().add_adapter(adapter_name, peft_config) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) + else: + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = torch.cat( + ( + torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), + kwargs["token_type_ids"], + ), + dim=1, + ).long() + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _prefix_tuning_forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + batch_size = _get_batch_size(input_ids, inputs_embeds) + past_key_values = self.get_prompt(batch_size) + fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) + kwargs.update( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "past_key_values": past_key_values, + } + ) + if "past_key_values" in fwd_params: + return self.base_model(labels=labels, **kwargs) + else: + transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) + fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) + if "past_key_values" not in fwd_params: + raise ValueError("Model does not support past key values which are required for prefix tuning.") + outputs = transformer_backbone_name(**kwargs) + pooled_output = outputs[1] if len(outputs) > 1 else outputs[0] + if "dropout" in [name for name, _ in list(self.base_model.named_children())]: + pooled_output = self.base_model.dropout(pooled_output) + logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.base_model.num_labels == 1: + self.config.problem_type = "regression" + elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.base_model.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + 
+class PeftModelForCausalLM(PeftModel): + """ + Peft model for causal language modeling. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModelForCausalLM, get_peft_config + + >>> config = { + ... "peft_type": "PREFIX_TUNING", + ... "task_type": "CAUSAL_LM", + ... "inference_mode": False, + ... "num_virtual_tokens": 20, + ... "token_dim": 1280, + ... "num_transformer_submodules": 1, + ... "num_attention_heads": 20, + ... "num_layers": 36, + ... "encoder_hidden_size": 1280, + ... "prefix_projection": False, + ... "postprocess_past_key_value_function": None, + ... } + + >>> peft_config = get_peft_config(config) + >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large") + >>> peft_model = PeftModelForCausalLM(model, peft_config) + >>> peft_model.print_trainable_parameters() + trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544 + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + if self.base_model.config.model_type == "mpt": + if inputs_embeds is not None: + raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds") + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") + kwargs["position_ids"] = None + if kwargs.get("token_type_ids", None) is not None: + warnings.warn("Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids") + kwargs["token_type_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + past_key_values = self.get_prompt(batch_size) + return self.base_model( + input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs + ) + else: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + # concat prompt labels + if labels is not None: + prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device) + kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def generate(self, *args, **kwargs): + peft_config = self.active_peft_config + self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation + if hasattr(self.base_model, "model"): + self.base_model.model.generation_config = self.generation_config + else: + self.base_model.generation_config = self.generation_config + try: + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(*args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + outputs = self.base_model.generate(*args, **kwargs) + else: + outputs = self.base_model.generate(**kwargs) + except: + self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation + raise + else: + self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation + return outputs + + def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs): + peft_config = self.active_peft_config + model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) + + # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format + # for some architectures which requires a special fix for prompt tuning etc. + # TODO: starting with transformers 4.38, all architectures should support caching. + uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0") + uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") + transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"] + uses_cache = uses_transformers_4_38 or ( + uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs + ) + + if peft_config.peft_type == PeftType.POLY: + model_kwargs["task_ids"] = task_ids + if peft_config.is_prompt_learning: + if uses_cache and (model_kwargs["past_key_values"] is not None): + # change in the logic of `prepare_inputs_for_generation` makes the below code necessary + # In prompt learning methods, past key values are longer when compared to the `input_ids`. + # As such only consider the last input ids in the autogressive generation phase. 
+ if model_kwargs["past_key_values"][0][0].shape[-2] >= model_kwargs["input_ids"].shape[1]: + model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:] + + if model_kwargs.get("attention_mask", None) is not None: + size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens + prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device) + model_kwargs["attention_mask"] = torch.cat( + (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1 + ) + + if model_kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") + model_kwargs["position_ids"] = None + + if kwargs.get("token_type_ids", None) is not None: + warnings.warn( + "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids" + ) + kwargs["token_type_ids"] = None + + if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING: + past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0]) + model_kwargs["past_key_values"] = past_key_values + else: + if model_kwargs["past_key_values"] is None: + inputs_embeds = self.word_embeddings(model_kwargs["input_ids"]) + prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1) + model_kwargs["input_ids"] = None + + # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is + # passed in the forward pass to keep track of the position ids of the cache. We have to + # pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed + # `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956 + _ = model_kwargs.pop("cache_position", None) + + return model_kwargs + + +class PeftModelForSeq2SeqLM(PeftModel): + """ + Peft model for sequence-to-sequence language modeling. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + + Example: + + ```py + >>> from transformers import AutoModelForSeq2SeqLM + >>> from peft import PeftModelForSeq2SeqLM, get_peft_config + + >>> config = { + ... "peft_type": "LORA", + ... "task_type": "SEQ_2_SEQ_LM", + ... "inference_mode": False, + ... "r": 8, + ... "target_modules": ["q", "v"], + ... "lora_alpha": 32, + ... "lora_dropout": 0.1, + ... "fan_in_fan_out": False, + ... "enable_lora": None, + ... "bias": "none", + ... 
} + + >>> peft_config = get_peft_config(config) + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config) + >>> peft_model.print_trainable_parameters() + trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566 + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation + self.base_model_prepare_encoder_decoder_kwargs_for_generation = ( + self.base_model._prepare_encoder_decoder_kwargs_for_generation + ) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + decoder_input_ids=None, + decoder_attention_mask=None, + decoder_inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if decoder_attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( + decoder_attention_mask.device + ) + if peft_config.peft_type not in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]: + decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1) + + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") + kwargs["position_ids"] = None + if kwargs.get("token_type_ids", None) is not None: + warnings.warn("Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids") + kwargs["token_type_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + past_key_values = self.get_prompt(batch_size) + return self.base_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + **kwargs, + ) + elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( + attention_mask.device + ) + kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + prompts = self.get_prompt(batch_size=batch_size) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1) + + return self.base_model( + inputs_embeds=inputs_embeds, + decoder_input_ids=decoder_input_ids, + decoder_inputs_embeds=decoder_inputs_embeds, + **kwargs, + ) + else: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if decoder_inputs_embeds is None and decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + decoder_inputs_embeds = self.word_embeddings(decoder_input_ids) + + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( + attention_mask.device + ) + kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1) + # concat prompt labels + if labels is not None: + if peft_config.num_transformer_submodules == 1: + kwargs["labels"] = labels + elif peft_config.num_transformer_submodules == 2: + prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device) + kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1) + if peft_config.num_transformer_submodules == 1: + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + elif peft_config.num_transformer_submodules == 2: + decoder_inputs_embeds = torch.cat( + (prompts[:, peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1 + ) + return self.base_model( + inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs + ) + + def generate(self, **kwargs): + peft_config = self.active_peft_config + self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation + self.base_model._prepare_encoder_decoder_kwargs_for_generation = ( + self._prepare_encoder_decoder_kwargs_for_generation + ) + try: + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + outputs = self.base_model.generate(**kwargs) + else: + if "input_ids" not in kwargs: + raise ValueError("input_ids must be provided for 
Peft model generation") + if kwargs.get("position_ids", None) is not None: + warnings.warn( + "Position ids are not supported for parameter efficient tuning. Ignoring position ids." + ) + kwargs["position_ids"] = None + if kwargs.get("token_type_ids", None) is not None: + warnings.warn( + "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids" + ) + kwargs["token_type_ids"] = None + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + outputs = self.base_model.generate(**kwargs) + elif peft_config.peft_type in [ + PeftType.PROMPT_TUNING, + PeftType.P_TUNING, + PeftType.MULTITASK_PROMPT_TUNING, + ]: + kwargs = deepcopy(kwargs) + + if "encoder_outputs" in kwargs: + del kwargs["encoder_outputs"] + warnings.warn( + "`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it." + ) + + input_ids = kwargs.pop("input_ids") + inputs_embeds = self.word_embeddings(input_ids) + batch_size = inputs_embeds.shape[0] + prompts = self.get_prompt(batch_size=batch_size, task_ids=kwargs.pop("task_ids", None)) + prompts = prompts.to(inputs_embeds.dtype) + + inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1) + kwargs["inputs_embeds"] = inputs_embeds + + if "attention_mask" in kwargs: + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( + kwargs["attention_mask"].device + ) + kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1) + + return self.base_model.generate(**kwargs) + else: + raise NotImplementedError + except: + self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation + self.base_model._prepare_encoder_decoder_kwargs_for_generation = ( + self.base_model_prepare_encoder_decoder_kwargs_for_generation + ) + raise + else: + self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation + self.base_model._prepare_encoder_decoder_kwargs_for_generation = ( + self.base_model_prepare_encoder_decoder_kwargs_for_generation + ) + return outputs + + def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs): + peft_config = self.active_peft_config + model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) + if peft_config.peft_type == PeftType.POLY: + model_kwargs["task_ids"] = task_ids + if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING: + batch_size = model_kwargs["decoder_input_ids"].shape[0] + past_key_values = self.get_prompt(batch_size) + model_kwargs["past_key_values"] = past_key_values + + return model_kwargs + + +class PeftModelForTokenClassification(PeftModel): + """ + Peft model for token classification tasks. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + **Attributes**: + - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. + - **cls_layer_name** (`str`) -- The name of the classification layer. + + Example: + + ```py + >>> from transformers import AutoModelForSequenceClassification + >>> from peft import PeftModelForTokenClassification, get_peft_config + + >>> config = { + ... "peft_type": "PREFIX_TUNING", + ... "task_type": "TOKEN_CLS", + ... "inference_mode": False, + ... "num_virtual_tokens": 20, + ... "token_dim": 768, + ... "num_transformer_submodules": 1, + ... "num_attention_heads": 12, + ... "num_layers": 12, + ... 
"encoder_hidden_size": 768, + ... "prefix_projection": False, + ... "postprocess_past_key_value_function": None, + ... } + + >>> peft_config = get_peft_config(config) + >>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased") + >>> peft_model = PeftModelForTokenClassification(model, peft_config) + >>> peft_model.print_trainable_parameters() + trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117 + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + + classifier_module_names = ["classifier", "score"] + if self.modules_to_save is None: + self.modules_to_save = set(classifier_module_names) + else: + self.modules_to_save.update(classifier_module_names) + + if hasattr(peft_config, "modules_to_save"): + if peft_config.modules_to_save is None: + peft_config.modules_to_save = classifier_module_names[:] + else: + peft_config.modules_to_save.extend(classifier_module_names) + + for name, _ in self.base_model.named_children(): + if any(module_name in name for module_name in self.modules_to_save): + self.cls_layer_name = name + break + + # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper + _set_trainable(self, adapter_name) + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + """ + Add an adapter to the model based on the passed configuration. + + This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`]. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. 
+ """ + # ensure that additional adapters also add the classifier layer to modules_to_save + if hasattr(peft_config, "modules_to_save"): + classifier_module_names = ["classifier", "score"] + if peft_config.modules_to_save is None: + peft_config.modules_to_save = classifier_module_names[:] + else: + peft_config.modules_to_save.extend(classifier_module_names) + + return super().add_adapter(adapter_name, peft_config) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + peft_config = self.active_peft_config + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) + else: + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = torch.cat( + ( + torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), + kwargs["token_type_ids"], + ), + dim=1, + ).long() + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _prefix_tuning_forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + batch_size = _get_batch_size(input_ids, inputs_embeds) + past_key_values = self.get_prompt(batch_size) + fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) + kwargs.update( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "past_key_values": past_key_values, + } + ) + if "past_key_values" in fwd_params: + return self.base_model(labels=labels, **kwargs) + else: + transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) + fwd_params = 
list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) + if "past_key_values" not in fwd_params: + raise ValueError("Model does not support past key values which are required for prefix tuning.") + outputs = transformer_backbone_name(**kwargs) + sequence_output = outputs[0] + if "dropout" in [name for name, _ in list(self.base_model.named_children())]: + sequence_output = self.base_model.dropout(sequence_output) + logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class PeftModelForQuestionAnswering(PeftModel): + """ + Peft model for extractive question answering. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + **Attributes**: + - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. + - **cls_layer_name** (`str`) -- The name of the classification layer. + + Example: + + ```py + >>> from transformers import AutoModelForQuestionAnswering + >>> from peft import PeftModelForQuestionAnswering, get_peft_config + + >>> config = { + ... "peft_type": "LORA", + ... "task_type": "QUESTION_ANS", + ... "inference_mode": False, + ... "r": 16, + ... "target_modules": ["query", "value"], + ... "lora_alpha": 32, + ... "lora_dropout": 0.05, + ... "fan_in_fan_out": False, + ... "bias": "none", + ... } + + >>> peft_config = get_peft_config(config) + >>> model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased") + >>> peft_model = PeftModelForQuestionAnswering(model, peft_config) + >>> peft_model.print_trainable_parameters() + trainable params: 592900 || all params: 108312580 || trainable%: 0.5473971721475013 + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + + qa_module_names = ["qa_outputs"] + if self.modules_to_save is None: + self.modules_to_save = set(qa_module_names) + else: + self.modules_to_save.update(qa_module_names) + + if hasattr(peft_config, "modules_to_save"): + if peft_config.modules_to_save is None: + peft_config.modules_to_save = qa_module_names[:] + else: + peft_config.modules_to_save.extend(qa_module_names) + + for name, _ in self.base_model.named_children(): + if any(module_name in name for module_name in self.modules_to_save): + self.cls_layer_name = name + break + + # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper + _set_trainable(self, adapter_name) + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + """ + Add an adapter to the model based on the passed configuration. + + This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`]. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. 
+ """ + # ensure that additional adapters also add the classifier layer to modules_to_save + if hasattr(peft_config, "modules_to_save"): + qa_module_names = ["qa_outputs"] + if peft_config.modules_to_save is None: + peft_config.modules_to_save = qa_module_names[:] + else: + peft_config.modules_to_save.extend(qa_module_names) + + return super().add_adapter(adapter_name, peft_config) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + peft_config = self.active_peft_config + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if not peft_config.is_prompt_learning: + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + start_positions=start_positions, + end_positions=end_positions, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "start_positions": start_positions, + "end_positions": end_positions, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) + else: + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = torch.cat( + ( + torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), + kwargs["token_type_ids"], + ), + dim=1, + ).long() + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _prefix_tuning_forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + batch_size = _get_batch_size(input_ids, inputs_embeds) + past_key_values = self.get_prompt(batch_size) + fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) + kwargs.update( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "past_key_values": past_key_values, + } + ) + if "past_key_values" in fwd_params: + return self.base_model(start_positions=start_positions, end_positions=end_positions, **kwargs) + else: + transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) + fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) + if "past_key_values" not in fwd_params: + raise ValueError("Model does not support past key values which are required for prefix tuning.") + outputs = transformer_backbone_name(**kwargs) + sequence_output = outputs[0] + if "dropout" in [name for name, _ in list(self.base_model.named_children())]: + sequence_output = self.base_model.dropout(sequence_output) + logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return 
QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class PeftModelForFeatureExtraction(PeftModel): + """ + Peft model for extracting features/embeddings from transformer models + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + **Attributes**: + - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. + + Example: + + ```py + >>> from transformers import AutoModel + >>> from peft import PeftModelForFeatureExtraction, get_peft_config + + >>> config = { + ... "peft_type": "LORA", + ... "task_type": "FEATURE_EXTRACTION", + ... "inference_mode": False, + ... "r": 16, + ... "target_modules": ["query", "value"], + ... "lora_alpha": 32, + ... "lora_dropout": 0.05, + ... "fan_in_fan_out": False, + ... "bias": "none", + ... } + >>> peft_config = get_peft_config(config) + >>> model = AutoModel.from_pretrained("bert-base-cased") + >>> peft_model = PeftModelForFeatureExtraction(model, peft_config) + >>> peft_model.print_trainable_parameters() + ``` + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default"): + super().__init__(model, peft_config, adapter_name) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") + kwargs["position_ids"] = None + if kwargs.get("token_type_ids", None) is not None: + warnings.warn("Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids") + kwargs["token_type_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + past_key_values = self.get_prompt(batch_size) + return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs) + else: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + +@dataclass +class TunerLayerStatus: + name: str + module_type: str + enabled: bool + active_adapters: list[str] + merged_adapters: list[str] + requires_grad: dict[str, bool | Literal["irregular"]] + available_adapters: list[str] + + +def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]: + """Get the status of each adapter layer in the model. + + This function returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following + attributes: + + - `name` (`str`): + The name of the adapter layer, e.g. `model.encoder.block.0.layer.0.SelfAttention.q`. + - `module_type` (`str`): + The type of the adapter layer, e.g. `lora.Linear`. + - `enabled` (`bool`): + Whether the adapter layer is enabled. + - `active_adapters` (`list[str]`): + The names of the active adapters, if any, e.g. `["default"]`. + - `merged_adapters` (`list[str]`): + The names of the merged adapters, if any, e.g. `["default"]`. + - requires_grad : dict[str, bool | Literal["irregular"]] + The requires_grad status of the parameters for each adapter module. Ideally, it should be either `True` or + `False`. If the requires_grad status is not consistent across all parameters, the value will be set to + `"irregular"`. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]): + The model to get the adapter layer status from. + + Returns: + list[`peft.peft_model.TunerLayerStatus`]: + A list of dataclasses, each containing the status of the corresponding adapter layer. + + """ + if isinstance(model, PeftModel): + base_model = model.base_model + if not isinstance(base_model, BaseTuner): + raise TypeError( + "get_layer_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not " + "supported." 
+ ) + else: + base_model = model + + layer_status: list[TunerLayerStatus] = [] + for name, module in base_model.named_modules(): + if not isinstance(module, BaseTunerLayer): + continue + + # determine if all submodules/parameters if this module require grad or not + mapping_requires_grad_list: dict[str, list[bool]] = collections.defaultdict(list) + for adapter_module_name in module.adapter_layer_names: + adapter_module = getattr(module, adapter_module_name) + if isinstance(adapter_module, torch.nn.ModuleDict): + for key, submodule in adapter_module.items(): + for param in submodule.parameters(): + mapping_requires_grad_list[key].append(param.requires_grad) + elif isinstance(adapter_module, torch.nn.ParameterDict): + for key, param in adapter_module.items(): + mapping_requires_grad_list[key].append(param.requires_grad) + else: + # strange, we don't know how to handle this, ignore for now + pass + + def check_irrgular(vals: list[bool]) -> bool | Literal["irregular"]: + if all(vals): + return True + if not any(vals): + return False + return "irregular" + + requires_grad = {key: check_irrgular(vals) for key, vals in mapping_requires_grad_list.items()} + + status = TunerLayerStatus( + name=name, + module_type=repr(module).partition("(")[0], + enabled=not module.disable_adapters, + active_adapters=module.active_adapters, + merged_adapters=module.merged_adapters, + requires_grad=requires_grad, + available_adapters=sorted(module._get_available_adapters()), + ) + layer_status.append(status) + + if not layer_status: + raise ValueError( + "No adapter layers found in the model, please ensure that it's a PEFT model or that you have PEFT adapters " + "injected in the model." + ) + + return layer_status + + +@dataclass +class TunerModelStatus: + base_model_type: str + adapter_model_type: str + peft_types: dict[str, str] + trainable_params: int + total_params: int + num_adapter_layers: int + enabled: bool | Literal["irregular"] + active_adapters: list[str] | Literal["irregular"] + merged_adapters: list[str] | Literal["irregular"] + requires_grad: dict[str, bool | Literal["irregular"]] + available_adapters: list[str] + + +def get_model_status(model: torch.nn.Module) -> TunerModelStatus: + """Get the status of tuners of the model. + + This function returns a `TunerModelStatus` dataclass instance, which contains the following attributes: + + - `base_model_type` (`str`): + The type of the base model, e.g. `T5Model`. + - `adapter_model_type` (`str`): + The type of the adapter model, e.g. `LoraModel`. + - `peft_types` (`dict[str, str]`): + The mapping of adapter name to adapter type, e.g. `{"default": "LORA"}`. + - `trainable_params` (`int`): + The number of trainable parameters in the model. + - `total_params` (`int`): + The total number of parameters in the model. + - `num_adapter_layers` (`int`): + The number of adapter layers in the model. + - `enabled` (`bool`, `Literal["irregular"]`): + Whether all adapter layers are enabled. If some are enabled and some are not, this will be `"irregular"`. This + means that your model is in an inconsistent state and might not work as expected. + - `active_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the active adapters. If the active adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `merged_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the merged adapters. 
If the merged adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `requires_grad` (`dict[str, bool | Literal["irregular"]]`): + Whether for the given adapter, all adapter layers have `requires_grad` set to `True` or `False`. If there is a + mix, this will be set to `"irregular"`, which means that your model is in an inconsistent state and might not + work as expected. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]): + The model to get the adapter layer status from. + + Returns: + `peft.peft_model.TunerModelStatus`: + A dataclass containing the status of the model. + + """ + if isinstance(model, PeftModel): + if not isinstance(model.base_model, BaseTuner): + raise TypeError( + "get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not " + "supported." + ) + base_model_type = model.get_base_model().__class__.__name__ + trainable_params, total_params = model.get_nb_trainable_parameters() + base_model = model.base_model + peft_types = {key: str(config.peft_type).partition(".")[-1] for key, config in base_model.peft_config.items()} + adapter_model_type = base_model.__class__.__name__ + elif isinstance(model, PreTrainedModel): + base_model_type = model.__class__.__name__ + trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model) + base_model = model + peft_types = {} + adapter_model_type = "None" + else: + base_model_type = "other" + trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model) + base_model = model + peft_types = {} + adapter_model_type = "None" + + layer_status = get_layer_status(model) + num_adapter_layers = len(layer_status) + + enabled_set: set[bool] = {status.enabled for status in layer_status} # must be {True}, {False}, or {True, False} + enabled: bool | Literal["irregular"] + if len(enabled_set) == 1: + enabled = enabled_set.pop() + else: + enabled = "irregular" + + available_adapters: list[str] = sorted(set().union(*(status.available_adapters for status in layer_status))) + + # ideally, active adapters should be consistent across all layers of the model, but we cannot guarantee it + all_active_adapters: set[tuple[str, ...]] = {tuple(status.active_adapters) for status in layer_status} + active_adapters: list[str] | Literal["irregular"] + if not all_active_adapters: + active_adapters = [] + elif len(all_active_adapters) == 1: + active_adapters = list(all_active_adapters.pop()) + else: + active_adapters = "irregular" + + # Here we determine what adapters are merged. This is not trivial because multiple adapters can be merged or not at + # the same time. Some layers may only have adapter A, some only adapter B, so it's not as easy as just checking + # which adapters are merged on each layer. + + # First, determine all adapters that are merged on at least on module. + merged_all: set[str] = set() + for status in layer_status: + merged_all.update(status.merged_adapters) + + # Next, check if on any layer, on of these adapters is not merged. 
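+    # i.e. if an adapter that is merged on some layer is still unmerged on another layer,
+    # the overall merged state is reported as "irregular" below.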
+ merged_adapters: list[str] | Literal["irregular"] = sorted(merged_all) + for status in layer_status: + unmerged = set(status.available_adapters) - set(status.merged_adapters) + if unmerged & merged_all: + # there is overlap between unmerged adapters and adapters that should be merged + merged_adapters = "irregular" + break + + # check status of requires_grad + # first, merge the values for all layers + requires_grad_all: dict[str, list[bool | Literal["irregular"]]] = collections.defaultdict(list) + for status in layer_status: + for key, val in status.requires_grad.items(): + requires_grad_all[key].append(val) + + # then, check if the values are consistent + def check_irrgular(vals: list[bool | Literal["irregular"]]) -> bool | Literal["irregular"]: + if all(val is True for val in vals): + return True + if all(val is False for val in vals): + return False + return "irregular" + + requires_grad = {key: check_irrgular(vals) for key, vals in requires_grad_all.items()} + + adapter_model_status = TunerModelStatus( + base_model_type=base_model_type, + adapter_model_type=adapter_model_type, + peft_types=peft_types, + trainable_params=trainable_params, + total_params=total_params, + num_adapter_layers=num_adapter_layers, + enabled=enabled, + active_adapters=active_adapters, + merged_adapters=merged_adapters, + requires_grad=requires_grad, + available_adapters=available_adapters, + ) + return adapter_model_status diff --git a/peft/py.typed b/peft/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/peft/tuners/__init__.py b/peft/tuners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b6b75837362e929f9a79f8d8caf48f45d6b3a7 --- /dev/null +++ b/peft/tuners/__init__.py @@ -0,0 +1,35 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
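Before moving on to the tuner packages, here is a minimal usage sketch for the `get_layer_status` and `get_model_status` helpers defined in `peft/peft_model.py` above. The `bert-base-cased` checkpoint and the `query`/`value` target modules are illustrative assumptions, not something this diff prescribes.

```py
# Minimal sketch (assumed checkpoint and target modules, for illustration only):
# wrap a classifier with LoRA, then inspect it with the status helpers defined above.
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model
from peft.peft_model import get_layer_status, get_model_status

base = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
peft_model = get_peft_model(
    base, LoraConfig(task_type=TaskType.SEQ_CLS, r=8, target_modules=["query", "value"])
)

for layer in get_layer_status(peft_model):  # one TunerLayerStatus per adapter layer
    print(layer.name, layer.module_type, layer.enabled, layer.active_adapters)

status = get_model_status(peft_model)       # aggregated TunerModelStatus
print(status.adapter_model_type, status.trainable_params, status.total_params)
```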
+ +from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel +from .lora import LoraConfig, LoraModel, LoftQConfig +from .loha import LoHaConfig, LoHaModel +from .lokr import LoKrConfig, LoKrModel +from .ia3 import IA3Config, IA3Model +from .adalora import AdaLoraConfig, AdaLoraModel +from .boft import BOFTConfig, BOFTModel +from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType +from .prefix_tuning import PrefixEncoder, PrefixTuningConfig +from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit +from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit +from .oft import OFTConfig, OFTModel +from .mixed import MixedModel +from .poly import PolyConfig, PolyModel +from .ln_tuning import LNTuningConfig, LNTuningModel +from .vera import VeraConfig, VeraModel diff --git a/peft/tuners/__pycache__/__init__.cpython-310.pyc b/peft/tuners/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5098f1277fe52e8efd665799040a2005cf2a006f Binary files /dev/null and b/peft/tuners/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc b/peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e41938019e5dd37b2f3254d5a8ebff0ea8aaa9e0 Binary files /dev/null and b/peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc differ diff --git a/peft/tuners/__pycache__/tuners_utils.cpython-310.pyc b/peft/tuners/__pycache__/tuners_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a6a0b9a02d7ac8b8938dbf5b3ed05bc3914b58b Binary files /dev/null and b/peft/tuners/__pycache__/tuners_utils.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__init__.py b/peft/tuners/adalora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bdb8a540bed454fd95633265e8cdceb3e792e3b --- /dev/null +++ b/peft/tuners/adalora/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
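The `peft/tuners/__init__.py` shown above simply re-exports each adapter's config/model pair, so any of them can be handed to `get_peft_model` in the same way. A hedged sketch with LoRA follows; the `gpt2` checkpoint and the `c_attn` target module are illustrative assumptions (GPT-2's `c_attn` is a `Conv1D`, hence `fan_in_fan_out=True`).

```py
# Hedged sketch: the configs re-exported above all plug into get_peft_model the same way.
# "gpt2" and "c_attn" are illustrative; c_attn is a Conv1D, so fan_in_fan_out is set to True.
from transformers import AutoModelForCausalLM
from peft import get_peft_model
from peft.tuners import LoraConfig  # re-exported by the __init__ above

base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], fan_in_fan_out=True)
peft_model = get_peft_model(base, config)
peft_model.print_trainable_parameters()
```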
+ +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + +from .config import AdaLoraConfig +from .gptq import SVDQuantLinear +from .layer import AdaLoraLayer, RankAllocator, SVDLinear +from .model import AdaLoraModel + + +__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"] + + +def __getattr__(name): + if (name == "SVDLinear8bitLt") and is_bnb_available(): + from .bnb import SVDLinear8bitLt + + return SVDLinear8bitLt + + if (name == "SVDLinear4bit") and is_bnb_4bit_available(): + from .bnb import SVDLinear4bit + + return SVDLinear4bit + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc b/peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c7d3299fbcbd7e7585de0a16fda2527ce97c073 Binary files /dev/null and b/peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc b/peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21ab08756f2efaf08bf91569fa6c84f7da6931ad Binary files /dev/null and b/peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__pycache__/config.cpython-310.pyc b/peft/tuners/adalora/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3754e90ab061cd911b5736f22b667832701ec75 Binary files /dev/null and b/peft/tuners/adalora/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc b/peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f335b6eebf80c203177a76c433e07fa90830f355 Binary files /dev/null and b/peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__pycache__/layer.cpython-310.pyc b/peft/tuners/adalora/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e29b4d5825e2f2d415b3b3c7102b77eb02d8be Binary files /dev/null and b/peft/tuners/adalora/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/adalora/__pycache__/model.cpython-310.pyc b/peft/tuners/adalora/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f2021bb5753a376680ecf6e525756e213ab2008 Binary files /dev/null and b/peft/tuners/adalora/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/adalora/bnb.py b/peft/tuners/adalora/bnb.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c32a815cef22b938b840a1b6013592a338936b --- /dev/null +++ b/peft/tuners/adalora/bnb.py @@ -0,0 +1,145 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
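The `__getattr__` at the end of `peft/tuners/adalora/__init__.py` above implements lazy, optional exports (PEP 562): the bitsandbytes-backed layers are only imported when they are actually requested. A small sketch of how that behaves, assuming nothing beyond the vendored `peft` package being importable:

```py
# Sketch of the lazy-export behaviour: SVDLinear8bitLt is resolved through the module-level
# __getattr__ above, so it only imports (and only exists) when bitsandbytes is installed.
import importlib

adalora = importlib.import_module("peft.tuners.adalora")

try:
    layer_cls = adalora.SVDLinear8bitLt  # triggers __getattr__
    print("8-bit AdaLoRA layer available:", layer_cls.__name__)
except AttributeError:
    print("bitsandbytes is not installed; 8-bit AdaLoRA layers are unavailable")
```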
+ +from typing import Any + +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + +from .layer import AdaLoraLayer + + +if is_bnb_available(): + + class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer): + # Low-rank matrix for SVD-based adaptation + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + AdaLoraLayer.__init__(self, base_layer) + # Freezing the pre-trained weight matrix + self.get_base_layer().weight.requires_grad = False + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) + result = self.base_layer(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling / ranknum + # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it + result = result + output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep + + +if is_bnb_4bit_available(): + + class SVDLinear4bit(torch.nn.Module, AdaLoraLayer): + # Low-rank matrix for SVD-based adaptation + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + AdaLoraLayer.__init__(self, base_layer) + # Freezing the pre-trained weight matrix + self.get_base_layer().weight.requires_grad = False + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) + result = self.base_layer(x, *args, **kwargs) + + if self.disable_adapters: + return result + + # As per Tim Dettmers, for 4bit, we need to defensively clone here. + # The reason is that in some cases, an error can occur that backprop + # does not work on a manipulated view. This issue may be solved with + # newer PyTorch versions but this would need extensive testing to be + # sure. 
+ result = result.clone() + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling / ranknum + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep diff --git a/peft/tuners/adalora/config.py b/peft/tuners/adalora/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7972e81dd38d36a522ca5923f12e0599174645ba --- /dev/null +++ b/peft/tuners/adalora/config.py @@ -0,0 +1,69 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from peft.tuners.lora import LoraConfig +from peft.utils import PeftType + + +@dataclass +class AdaLoraConfig(LoraConfig): + """ + This is the configuration class to store the configuration of a [`~peft.AdaLora`]. + + Args: + target_r (`int`): The target average rank of incremental matrix. + init_r (`int`): The initial rank for each incremental matrix. + tinit (`int`): The steps of initial fine-tuning warmup. + tfinal (`int`): The step of final fine-tuning. + deltaT (`int`): The time internval between two budget allocations. + beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing. + beta2 (`float`): The hyperparameter of EMA for undertainty quantification. + orth_reg_weight (`float`): The coefficient of orthogonal regularization. + total_step (`int`): The total training steps that should be specified before training. + rank_pattern (`list`): The allocated rank for each weight matrix by RankAllocator. 
+ """ + + target_r: int = field(default=8, metadata={"help": "Target Lora matrix dimension."}) + init_r: int = field(default=12, metadata={"help": "Initial Lora matrix dimension."}) + tinit: int = field(default=0, metadata={"help": "The steps of initial warmup."}) + tfinal: int = field(default=0, metadata={"help": "The steps of final warmup."}) + deltaT: int = field(default=1, metadata={"help": "Step interval of rank allocation."}) + beta1: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."}) + beta2: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."}) + orth_reg_weight: float = field(default=0.5, metadata={"help": "The orthogonal regularization coefficient."}) + total_step: Optional[int] = field(default=None, metadata={"help": "The total training steps."}) + rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."}) + + def __post_init__(self): + self.peft_type = PeftType.ADALORA + + if self.use_dora: + raise ValueError(f"{self.peft_type} does not support DoRA.") + + if self.loftq_config: + raise ValueError(f"{self.peft_type} does not support LOFTQ.") + + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") diff --git a/peft/tuners/adalora/gptq.py b/peft/tuners/adalora/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..910377c5db5908727ed4753fd15b24e68821ce00 --- /dev/null +++ b/peft/tuners/adalora/gptq.py @@ -0,0 +1,72 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from .layer import AdaLoraLayer + + +class SVDQuantLinear(torch.nn.Module, AdaLoraLayer): + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + AdaLoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() + + output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum + # TODO: here, the dtype conversion is applied on the *whole expression*, + # not the intermediate result, unlike for SVDLinear8bitLT and + # SVDLinear4bit, is that correct? + if requires_conversion: + output = output.to(expected_dtype) + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep diff --git a/peft/tuners/adalora/layer.py b/peft/tuners/adalora/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b17e3124b17df864ea5266266ee60eb62fcac11 --- /dev/null +++ b/peft/tuners/adalora/layer.py @@ -0,0 +1,361 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
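All of the AdaLoRA layer variants in this package (`SVDLinear8bitLt`, `SVDLinear4bit`, `SVDQuantLinear`, and the `SVDLinear` defined next) apply the same low-rank update, `dropout(x) @ (lora_A * lora_E).T @ lora_B.T * scaling / ranknum`. Below is a plain-PyTorch shape sketch of that expression with illustrative dimensions; it is not part of the library, just a check of the algebra.

```py
# Plain-PyTorch shape sketch of the AdaLoRA delta, with illustrative dimensions.
import torch

batch, in_features, out_features, r = 2, 16, 32, 4
x = torch.randn(batch, in_features)
lora_A = torch.randn(r, in_features)   # right singular vectors
lora_E = torch.randn(r, 1)             # singular values, broadcast across in_features
lora_B = torch.randn(out_features, r)  # left singular vectors

lora_alpha = 32
scaling = float(lora_alpha)            # update_layer: scaling = lora_alpha (or r if alpha <= 0)
ranknum = float(r) + 1e-5              # ranknum buffer holds r; forward adds 1e-5

delta = (x @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
print(delta.shape)  # torch.Size([2, 32]) -- same shape as the base layer's output
```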
+ +import warnings +from typing import Any, List, Optional + +import packaging +import torch +import transformers +from torch import nn + +from peft.tuners.lora import LoraLayer +from peft.tuners.tuners_utils import check_adapters_to_merge +from peft.utils import transpose + + +if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"): + from transformers.integrations import deepspeed_config +else: + from transformers.deepspeed import deepspeed_config + + +class AdaLoraLayer(LoraLayer): + # List all names of layers that may contain adapter weights + # Note: ranknum doesn't need to be included as it is not an nn.Module + adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B") + # other_param_names is defined in LoraLayer + + def __init__(self, base_layer: nn.Module) -> None: + super().__init__(base_layer) + self.lora_E = nn.ParameterDict({}) + self.lora_A = nn.ParameterDict({}) + self.lora_B = nn.ParameterDict({}) + self.ranknum = nn.ParameterDict({}) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + if r < 0: + # note: r == 0 is allowed for AdaLora, see #1539 + raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + # Right singular vectors + self.lora_A[adapter_name] = nn.Parameter(torch.randn(r, self.in_features)) + # Singular values + self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1)) + # Left singular vectors + self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r)) + # The current rank + self.ranknum[adapter_name] = nn.Parameter(torch.randn(1), requires_grad=False) + self.ranknum[adapter_name].data.fill_(float(r)) + self.ranknum[adapter_name].requires_grad = False + self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r) + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + + if hasattr(self.get_base_layer(), "qweight"): + # QuantLinear + self.to(self.get_base_layer().qweight.device) + else: + self.to(self.get_base_layer().weight.device) + self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name): + if adapter_name in self.lora_A.keys(): + nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.02) + nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02) + nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02) + + +class SVDLinear(nn.Module, AdaLoraLayer): + # SVD-based adaptation by a dense layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + AdaLoraLayer.__init__(self, base_layer) + # Freezing the pre-trained weight matrix + self.get_base_layer().weight.requires_grad = False + + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge 
operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + base_layer = self.get_base_layer() + if active_adapter in self.lora_A.keys(): + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> torch.Tensor: + return ( + transpose(self.lora_B[adapter] @ (self.lora_A[adapter] * self.lora_E[adapter]), self.fan_in_fan_out) + * self.scaling[adapter] + / (self.ranknum[adapter] + 1e-5) + ) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + x = x.to(lora_A.dtype) + result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "adalora." + rep + + +class RankAllocator: + """ + The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY + + Args: + config ([`AdaLoraConfig`]): The configuration of the AdaLora model. + model: the model that we apply AdaLoRA to. 
+ + """ + + def __init__(self, model, peft_config, adapter_name): + self.peft_config = peft_config + self.adapter_name = adapter_name + self.beta1 = peft_config.beta1 + self.beta2 = peft_config.beta2 + assert self.beta1 > 0 and self.beta1 < 1 + assert self.beta2 > 0 and self.beta2 < 1 + + self.reset_ipt() + self._set_budget_scheduler(model) + + def set_total_step(self, total_step): + self.peft_config.total_step = total_step + + def reset_ipt(self): + self.ipt = {} + self.exp_avg_ipt = {} + self.exp_avg_unc = {} + + def _set_budget_scheduler(self, model): + self.init_bgt = 0 + self.name_set = set() + for n, p in model.named_parameters(): + if f"lora_A.{self.adapter_name}" in n: + self.init_bgt += p.size(0) + self.name_set.add(n.replace("lora_A", "%s")) + self.name_set = sorted(self.name_set) + # The total final rank budget + self.target_bgt = self.peft_config.target_r * len(self.name_set) + + def budget_schedule(self, step: int): + tinit = self.peft_config.tinit + tfinal = self.peft_config.tfinal + total_step = self.peft_config.total_step + # Initial warmup + if step <= tinit: + budget = self.init_bgt + mask_ind = False + # Final fine-tuning + elif step > total_step - tfinal: + budget = self.target_bgt + mask_ind = True + else: + # Budget decreasing with a cubic scheduler + mul_coeff = 1 - (step - tinit) / (total_step - tfinal - tinit) + budget = int((self.init_bgt - self.target_bgt) * (mul_coeff**3) + self.target_bgt) + mask_ind = True if step % self.peft_config.deltaT == 0 else False + return budget, mask_ind + + def update_ipt(self, model): + # Update the sensitivity and uncertainty for every weight + for n, p in model.named_parameters(): + if "lora_" in n and self.adapter_name in n: + if n not in self.ipt: + self.ipt[n] = torch.zeros_like(p) + self.exp_avg_ipt[n] = torch.zeros_like(p) + self.exp_avg_unc[n] = torch.zeros_like(p) + with torch.no_grad(): + if deepspeed_config() is not None: + import deepspeed + + grad = deepspeed.utils.safe_get_full_grad(p) + self.ipt[n] = (p * grad).abs().detach() + else: + self.ipt[n] = (p * p.grad).abs().detach() + # Sensitivity smoothing + self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n] + # Uncertainty quantification + self.exp_avg_unc[n] = ( + self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs() + ) + + def _element_score(self, n): + return self.exp_avg_ipt[n] * self.exp_avg_unc[n] + + def _combine_ipt(self, ipt_E, ipt_AB): + ipt_AB = ipt_AB.sum(dim=1, keepdim=False) + sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1) + return sum_ipt + + def mask_to_budget(self, model, budget): + value_ipt = {} + vector_ipt = {} + triplet_ipt = {} + # Get the importance score for A, E, B + for n, p in model.named_parameters(): + if f"lora_A.{self.adapter_name}" in n: + entry_ipt = self._element_score(n) + comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True) + name_m = n.replace("lora_A", "%s") + if name_m not in vector_ipt: + vector_ipt[name_m] = [comb_ipt] + else: + vector_ipt[name_m].append(comb_ipt) + if f"lora_B.{self.adapter_name}" in n: + entry_ipt = self._element_score(n) + comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1) + name_m = n.replace("lora_B", "%s") + if name_m not in vector_ipt: + vector_ipt[name_m] = [comb_ipt] + else: + vector_ipt[name_m].append(comb_ipt) + if f"lora_E.{self.adapter_name}" in n: + entry_ipt = self._element_score(n) + name_m = n.replace("lora_E", "%s") + value_ipt[name_m] = entry_ipt + + all_score = [] + # Calculate the score for each 
triplet + for name_m in vector_ipt: + ipt_E = value_ipt[name_m] + ipt_AB = torch.cat(vector_ipt[name_m], dim=1) + sum_ipt = self._combine_ipt(ipt_E, ipt_AB) + name_E = name_m % "lora_E" + triplet_ipt[name_E] = sum_ipt.view(-1, 1) + all_score.append(sum_ipt.view(-1)) + + # Get the threshold by ranking ipt + mask_threshold = torch.kthvalue( + torch.cat(all_score), + k=self.init_bgt - budget, + )[0].item() + + rank_pattern = {} + # Mask the unimportant triplets + with torch.no_grad(): + for n, p in model.named_parameters(): + if f"lora_E.{self.adapter_name}" in n: + p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0) + rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist() + return rank_pattern + + def update_and_allocate(self, model, global_step, force_mask=False): + # # Update the importance score and allocate the budget + if global_step < self.peft_config.total_step - self.peft_config.tfinal: + self.update_ipt(model) + budget, mask_ind = self.budget_schedule(global_step) + # Allocate the budget according to importance scores + if mask_ind or force_mask: + rank_pattern = self.mask_to_budget(model, budget) + else: + rank_pattern = None + return budget, rank_pattern + + def mask_using_rank_pattern(self, model, rank_pattern): + # Mask the unimportant triplets + is_adapter_name_truncated = False + if self.adapter_name not in next(iter(rank_pattern.keys())): + is_adapter_name_truncated = True + + with torch.no_grad(): + for n, p in model.named_parameters(): + if f"lora_E.{self.adapter_name}" in n: + key = n if not is_adapter_name_truncated else n.replace(f".{self.adapter_name}", "") + mask = torch.Tensor(rank_pattern[key]).unsqueeze(-1).to(p.device) + p.masked_fill_(~mask.bool(), 0.0) diff --git a/peft/tuners/adalora/model.py b/peft/tuners/adalora/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1302979a86365a7350062a1c06b1c08cc9792a --- /dev/null +++ b/peft/tuners/adalora/model.py @@ -0,0 +1,355 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import torch +from transformers.pytorch_utils import Conv1D + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.lora import LoraConfig, LoraModel +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import ( + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + _freeze_adapter, + _get_submodules, + get_auto_gptq_quant_linear, + get_quantization_config, +) +from peft.utils.integrations import gather_params_ctx + +from .gptq import SVDQuantLinear +from .layer import AdaLoraLayer, RankAllocator, SVDLinear + + +class AdaLoraModel(LoraModel): + """ + Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper: + https://openreview.net/forum?id=lq62uWRJjiY + + Args: + model ([`transformers.PreTrainedModel`]): The model to be adapted. + config ([`AdaLoraConfig`]): The configuration of the AdaLora model. 
+ adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The AdaLora model. + + Example:: + + >>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import AdaLoraModel, AdaLoraConfig + >>> config = AdaLoraConfig( + peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"], + lora_dropout=0.01, + ) + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> model = AdaLoraModel(model, config, "default") + + **Attributes**: + - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model. + """ + + # Note: don't redefine prefix here, it should be inherited from LoraModel + + def __init__(self, model, config, adapter_name): + super().__init__(model, config, adapter_name) + + traininable_mode_counter = 0 + for config in self.peft_config.values(): + if not config.inference_mode: + traininable_mode_counter += 1 + + if traininable_mode_counter > 1: + raise ValueError( + "AdaLoraModel supports only 1 trainable adapter. " + "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train." + ) + + if self.peft_config[adapter_name].inference_mode: + _freeze_adapter(self.model, adapter_name) + else: + self.trainable_adapter_name = adapter_name + self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name) + + def _check_new_adapter_config(self, config: LoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + super()._check_new_adapter_config(config) + + traininable_mode_counter = 0 + for config_ in self.peft_config.values(): + if not config_.inference_mode: + traininable_mode_counter += 1 + + if traininable_mode_counter > 1: + raise ValueError( + f"{self.__class__.__name__} supports only 1 trainable adapter. " + "When using multiple adapters, set inference_mode to True for all adapters except the one " + "you want to train." + ) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + kwargs = { + "r": lora_config.init_r, + "lora_alpha": lora_config.lora_alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available(): + raise ImportError( + "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. " + "You can install it with `pip install bitsandbytes`." 
+ ) + + quantization_config = get_quantization_config(self.model, method="gptq") + if quantization_config is not None: + kwargs["gptq_quantization_config"] = quantization_config + + # If it is not an AdaLoraLayer, create a new module, else update it with new adapters + if not isinstance(target, AdaLoraLayer): + new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + lora_config.init_r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + # avoid eager bnb import + if is_bnb_available(): + import bitsandbytes as bnb + + from .bnb import SVDLinear8bitLt + if is_bnb_4bit_available(): + from .bnb import SVDLinear4bit + + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + kwargs.update( + { + "has_fp16_weights": target_base_layer.state.has_fp16_weights, + "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, + "threshold": target_base_layer.state.threshold, + "index": target_base_layer.index, + } + ) + new_module = SVDLinear8bitLt(target, adapter_name, **kwargs) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs) + elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear): + new_module = SVDQuantLinear(target, adapter_name, **kwargs) + else: + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." 
+ ) + new_module = SVDLinear(target, adapter_name, **kwargs) + + return new_module + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] + return peft_config + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def forward(self, *args, **kwargs): + outputs = self.model.forward(*args, **kwargs) + + if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor): + # Calculate the orthogonal regularization + orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight + + if orth_reg_weight <= 0: + raise ValueError("orth_reg_weight should be greater than 0. ") + + regu_loss = 0 + num_param = 0 + for n, p in self.model.named_parameters(): + if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n: + if p.shape == torch.Size([0]): + with gather_params_ctx(p, fwd_module=self): + para_cov = p @ p.T if "lora_A" in n else p.T @ p + else: + para_cov = p @ p.T if "lora_A" in n else p.T @ p + I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741 + I.requires_grad = False + num_param += 1 + regu_loss += torch.norm(para_cov - I, p="fro") + if num_param > 0: + regu_loss = regu_loss / num_param + else: + regu_loss = 0 + outputs.loss += orth_reg_weight * regu_loss + return outputs + + def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name): + lora_config = self.peft_config[adapter_name] + for name, rank_idx in rank_pattern.items(): + if isinstance(rank_idx, list): + rank = sum(rank_idx) + elif isinstance(rank_idx, torch.Tensor): + rank_idx = rank_idx.view(-1) + rank = rank_idx.sum().item() + else: + raise ValueError("Unexpected type of rank_idx") + key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) + _, target, _ = _get_submodules(self.model, key) + lora_E_weights = target.lora_E[adapter_name][rank_idx] + lora_A_weights = target.lora_A[adapter_name][rank_idx] + lora_B_weights = target.lora_B[adapter_name][:, rank_idx] + ranknum = target.ranknum[adapter_name] + target.update_layer( + adapter_name, + rank, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + with torch.no_grad(): + if rank > 0: + target.lora_E[adapter_name].copy_(lora_E_weights) + target.lora_A[adapter_name].copy_(lora_A_weights) + target.lora_B[adapter_name].copy_(lora_B_weights) + # The scaling is exactly as the previous + target.ranknum[adapter_name].copy_(ranknum) + + def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name): + for name, rank_idx in rank_pattern.items(): + rank = sum(rank_idx) + prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) + for layer in ["lora_E", "lora_A", "lora_B"]: + key = f"base_model.model.{prefix}.{layer}.{adapter_name}" + if layer != "lora_B": + state_dict[key] = ( + state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key] + ) + else: + state_dict[key] = ( + state_dict[key][:, rank_idx] if rank != 
state_dict[key].shape[1] else state_dict[key] + ) + return state_dict + + def update_and_allocate(self, global_step): + """ + This method updates Adalora budget and mask. + + This should be called in every training step after `loss.backward()` and before `zero_grad()`. + + `tinit`, `tfinal` and `deltaT` are handled with in the method. + + Args: + global_step (`int`): The current training step, it is used to calculate adalora budget. + + Example: + + ```python + >>> loss = model(**input).loss + >>> loss.backward() + >>> optimizer.step() + >>> model.base_model.update_and_allocate(i_step) + >>> optimizer.zero_grad() + ``` + """ + lora_config = self.peft_config[self.trainable_adapter_name] + # Update the importance score and allocate the budget + if global_step < lora_config.total_step - lora_config.tfinal: + _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step) + if rank_pattern: + lora_config.rank_pattern = rank_pattern + # Finalize the budget allocation + elif global_step == lora_config.total_step - lora_config.tfinal: + _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True) + # for some reason, this freezes the trainable parameters and nothing gets updates + # self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name) + lora_config.rank_pattern = rank_pattern + self.rankallocator.reset_ipt() + # Currently using inefficient way to mask the unimportant weights using the rank pattern + # due to problem mentioned above + elif global_step > lora_config.total_step - lora_config.tfinal: + self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern) + # Pass the function and do forward propagation + else: + return None + + def add_weighted_adapter(self, *args, **kwargs): + """This method is not supported for AdaLoRA, use LoRA instead.""" + raise TypeError(f"{self.__class__.__name__} does not support add_weighted_adapter method.") diff --git a/peft/tuners/adaption_prompt/__init__.py b/peft/tuners/adaption_prompt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ede9455f70e41740768abe80f3198e78397053f --- /dev/null +++ b/peft/tuners/adaption_prompt/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
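+
+# Hedged usage sketch (not from the upstream docs; values and the model variable
+# are illustrative assumptions): adaption prompts are normally enabled through the
+# top-level peft helpers rather than by instantiating AdaptionPromptModel directly.
+#
+#   from peft import AdaptionPromptConfig, get_peft_model
+#
+#   config = AdaptionPromptConfig(adapter_len=10, adapter_layers=30, task_type="CAUSAL_LM")
+#   model = get_peft_model(llama_model, config)  # `llama_model`: any Llama/Mistral causal LM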
+from .config import AdaptionPromptConfig +from .layer import AdaptedAttention +from .model import AdaptionPromptModel + + +__all__ = ["AdaptionPromptConfig", "AdaptedAttention", "AdaptionPromptModel"] diff --git a/peft/tuners/adaption_prompt/__pycache__/__init__.cpython-310.pyc b/peft/tuners/adaption_prompt/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..184e90af87162a9f665e28c4ce774bcf3879c776 Binary files /dev/null and b/peft/tuners/adaption_prompt/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/adaption_prompt/__pycache__/config.cpython-310.pyc b/peft/tuners/adaption_prompt/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5620960991c740728e640fbaba0b3ee37f715927 Binary files /dev/null and b/peft/tuners/adaption_prompt/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/adaption_prompt/__pycache__/layer.cpython-310.pyc b/peft/tuners/adaption_prompt/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe8e680a78703597eb2451a156f935c0d6c95ba9 Binary files /dev/null and b/peft/tuners/adaption_prompt/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/adaption_prompt/__pycache__/model.cpython-310.pyc b/peft/tuners/adaption_prompt/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eff23c789c1c4ec8306c200348ee2597665cc557 Binary files /dev/null and b/peft/tuners/adaption_prompt/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/adaption_prompt/__pycache__/utils.cpython-310.pyc b/peft/tuners/adaption_prompt/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22d07f40ff74cd01df24cb784a37c2ee8c6fa0b Binary files /dev/null and b/peft/tuners/adaption_prompt/__pycache__/utils.cpython-310.pyc differ diff --git a/peft/tuners/adaption_prompt/config.py b/peft/tuners/adaption_prompt/config.py new file mode 100644 index 0000000000000000000000000000000000000000..90e29841498b8821dc6b6602282b66a0d3df6750 --- /dev/null +++ b/peft/tuners/adaption_prompt/config.py @@ -0,0 +1,80 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import namedtuple +from dataclasses import dataclass, field + +from peft.config import PeftConfig +from peft.utils import PeftType + +from .utils import llama_compute_query_states + + +@dataclass +class AdaptionPromptConfig(PeftConfig): + """Stores the configuration of an [`AdaptionPromptModel`].""" + + target_modules: str = field( + default=None, metadata={"help": "Name of the attention submodules to insert adaption prompts into."} + ) + adapter_len: int = field(default=None, metadata={"help": "Number of adapter tokens to insert"}) + adapter_layers: int = field(default=None, metadata={"help": "Number of adapter layers (from the top)"}) + + def __post_init__(self): + self.peft_type = PeftType.ADAPTION_PROMPT + + @property + def is_adaption_prompt(self) -> bool: + """Return True if this is an adaption prompt config.""" + return True + + +# Contains the config that is specific to a transformers model type. +ModelTypeConfig = namedtuple( + "ModelTypeConfig", ["compute_query_states", "target_modules", "k_proj_layer", "v_proj_layer", "o_proj_layer"] +) + +# Mapping of transformers model types to their specific configuration. +TRANSFORMERS_MODEL_CONFIG = { + "llama": ModelTypeConfig( + compute_query_states=llama_compute_query_states, + target_modules="self_attn", + k_proj_layer="k_proj", + v_proj_layer="v_proj", + o_proj_layer="o_proj", + ), + "mistral": ModelTypeConfig( # same as llama, + compute_query_states=llama_compute_query_states, + target_modules="self_attn", + k_proj_layer="k_proj", + v_proj_layer="v_proj", + o_proj_layer="o_proj", + ), +} + + +def prepare_config( + peft_config: AdaptionPromptConfig, + model, +) -> AdaptionPromptConfig: + """Prepare the config based on the llama model type.""" + if model.config.model_type not in TRANSFORMERS_MODEL_CONFIG: + raise ValueError("Unsupported model type for adaption prompt: '{model.config.model_type}'.") + + model_config = TRANSFORMERS_MODEL_CONFIG[model.config.model_type] + + if peft_config.target_modules is None: + peft_config.target_modules = model_config.target_modules + + return peft_config diff --git a/peft/tuners/adaption_prompt/layer.py b/peft/tuners/adaption_prompt/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..31fb51e0de6a9842d14578c78a5a1aceba676483 --- /dev/null +++ b/peft/tuners/adaption_prompt/layer.py @@ -0,0 +1,128 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import TRANSFORMERS_MODEL_CONFIG + + +class AdaptedAttention(nn.Module): + """This module wraps a LLamaAttention module and injects adaption prompts.""" + + def __init__(self, model_type: str, adapter_len: int, model): + """ + Initialize object. + + Args: + model_type: The transformer model type. This is used to retrieve the right method to + compute query states. + adapter_len: The length of the adaption prompt to insert. 
+ model: The original transformer attention module that is being wrapped. + """ + assert not isinstance(model, AdaptedAttention) + super().__init__() + self.model_type = model_type + self.model = model + self.adapter_len = adapter_len + # Assume all parameters of the attention model we are wrapping are on the same device. + device = next(model.parameters()).device + # Don't think this was specified in the paper, but we follow the official repo which used an Embedding + # which initializes the tokens with standard normal values. + # https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234 + # (bsz, adapter_len, hidden_size) + target_dtype = ( + model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32 + ) + self.adaption_prompt = nn.Parameter( + torch.empty(1, adapter_len, self.model.hidden_size, device=device, dtype=target_dtype).normal_() + ) + # Initialize the gate to 0 as this is "zero-init". + self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype)) + + def forward(self, **kwargs): + """ + Forward pass for the adapter which wraps the original LlamaAttention module. + + "Official" paper implementation: + https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L141 + + Args: + kwargs: See the original LlamaAttention module. + """ + if kwargs.get("output_attention", False): + raise NotImplementedError("output_attention is not currently supported.") + + output, _, past_key_value = self.model(**kwargs) + bsz = output.shape[0] + q_len = output.shape[1] + embed_dim = output.shape[2] + k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer + v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer + o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer + factor = ( + self.model.k_proj.in_features // self.model.k_proj.out_features + ) # Mistral has different input and output dimension for k_proj and v_proj layers + + if k_proj_layer == v_proj_layer: + _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2) + else: + key = getattr(self.model, k_proj_layer)(self.adaption_prompt) + value = getattr(self.model, v_proj_layer)(self.adaption_prompt) + + # (bsz, num_key_value_heads, adapter_len, head_dim) + adapter_k = ( + key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) + .repeat(bsz, 1, 1, 1) + .transpose(1, 2) + ) + adapter_v = ( + value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) + .repeat(bsz, 1, 1, 1) + .transpose(1, 2) + ) + # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181 + # (bsz, num_heads, adapter_len, head_dim) + adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1) + adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1) + # Recompute query states. 
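+ # Descriptive note: the wrapped attention module does not return its query states,
+ # so they are recomputed below via the model-type-specific helper. The resulting
+ # scores are scaled by `adaption_gate`, which is zero-initialized, so the adapter
+ # starts out as a no-op on top of the original attention output.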
+ compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states + # (bsz, num_heads, q_len, head_dim) + query_states = compute_query_states(model=self.model, **kwargs) + + previous_dtype = query_states.dtype + + # (bsz, num_heads, q_len, adapter_len) + scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt( + self.model.head_dim + ) + # Upcast attention to fp32 + # (bsz, num_heads, q_len, adapter_len) + scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype) + # (bsz, q_len, num_heads * head_dim) + adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1) + + # (bsz, q_len, hidden_size) + if o_proj_layer is not None: + adapter_output = getattr(self.model, o_proj_layer)(adapter_output) + + # Add adaption prompt output to original output. + output = output + adapter_output + + # Restore original dtype. + output = output.to(previous_dtype) + return output, None, past_key_value diff --git a/peft/tuners/adaption_prompt/model.py b/peft/tuners/adaption_prompt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..08aea27f8efb51c8c2d85be91a2f95659651c701 --- /dev/null +++ b/peft/tuners/adaption_prompt/model.py @@ -0,0 +1,161 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List + +import torch.nn as nn + +from peft.utils import _freeze_adapter, _get_submodules + +from .config import AdaptionPromptConfig, prepare_config +from .layer import AdaptedAttention +from .utils import is_adaption_prompt_trainable + + +class AdaptionPromptModel(nn.Module): + """ + Implements adaption prompts as described in https://arxiv.org/pdf/2303.16199.pdf. + + The top L attention modules are replaced with AdaptedAttention modules that wrap the original ones, but insert + trainable prompts with gates (for zero init). + + Notes on the multi-adapter pattern: + - We store the states of different adapters by keeping a dictionary of AdaptedAttention modules indexed by adapter + name. + - Every time we switch adapters, we remove the modules of the currently active adapter from the model, store them + in the dictionary, and replace them with the modules of the new adapter. + - To avoid duplicated and potentially inconsistent state, the currently active adapter is always removed from the + dictionary. + - Disabling the adapter would also result in the modules being removed from the model. + """ + + def __init__(self, model, configs: Dict, adapter_name: str): + super().__init__() + self.model = model + # Store adapter configs by name. + self.peft_config: Dict[str, AdaptionPromptConfig] = {} + # Store lists of the parents of the affected attention modules by adapter name. + # We keep references to the parents so we can swap the adapters in-and-out of the model. + self._parents: Dict[str, List[nn.Module]] = {} + # Store lists of cached AdaptedAttention modules by name. 
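+ # Only the active adapter's AdaptedAttention modules live inside the model at any
+ # given time; the others are parked in this cache and swapped back in by
+ # set_adapter() / enable_adapter_layers(), per the multi-adapter pattern above.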
+ self._cached_adapters: Dict[str, List] = {} + # The name of the currently active adapter. + self._active_adapter = None + # Whether the adapter is enabled. + self._enabled = True + self.forward = self.model.forward + self.add_adapter(adapter_name, configs[adapter_name]) + self._mark_only_adaption_prompts_as_trainable(self.model) + + def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None: + """Add an adapter with the given name and config.""" + config = prepare_config(config, self.model) + if adapter_name in self.peft_config: + raise ValueError(f"Adapter with name '{adapter_name}' already exists.") + + parents = [] + for name, _ in self.model.named_modules(): + if name.endswith(config.target_modules): + par, _, _ = _get_submodules(self.model, name) + parents.append(par) + if len(parents) < config.adapter_layers: + raise ValueError( + f"Config specifies more adapter layers '{config.adapter_layers}'" + f" than the model has '{len(parents)}'." + ) + # Note that if the target modules are not in Sequential, ModuleList, or + # some other PyTorch ordered container, the behavior is undefined as we + # assume here that the order of the modules is the same as the order of + # the transformer decoder layers. + parents = parents[-config.adapter_layers :] + self._parents[adapter_name] = parents + + # It is only None during initialization. + # If it is disabled, we don't have to remove the modules. + if self._active_adapter is not None and self._enabled: + self._remove_adapted_attentions(self._active_adapter) + self._active_adapter = adapter_name + self.peft_config[adapter_name] = config + self._create_adapted_attentions(config, parents) + if not self._enabled: + self._remove_adapted_attentions(self._active_adapter) + + if config.inference_mode: + _freeze_adapter(self.model, adapter_name) + + def set_adapter(self, adapter_name: str) -> None: + """Set the model to use the adapter with the given name.""" + if self._active_adapter == adapter_name: + return + if adapter_name not in self.peft_config: + raise ValueError(f"Adapter with name '{adapter_name}' does not exist.") + + if self._enabled: + self._remove_adapted_attentions(self._active_adapter) + self._set_adapted_attentions(adapter_name) + + self._active_adapter = adapter_name + + def enable_adapter_layers(self): + """Enable adapter layers by swapping in cached AdaptedAttention modules.""" + self._enabled = True + self._set_adapted_attentions(self._active_adapter) + + def disable_adapter_layers(self): + """Disable adapter layers by swapping out AdaptedAttention modules.""" + self._enabled = False + self._remove_adapted_attentions(self._active_adapter) + + def _create_adapted_attentions(self, config: AdaptionPromptConfig, parents: List[nn.Module]) -> None: + """Wrap LlamaAttention modules with newly created AdaptedAttention modules.""" + for par in parents: + attn = AdaptedAttention( + model_type=self.model.config.model_type, + adapter_len=config.adapter_len, + model=getattr(par, config.target_modules), + ) + setattr(par, config.target_modules, attn) + + def _set_adapted_attentions(self, adapter_name: str) -> None: + """Replace LlamaAttention modules with cached AdaptedAttention modules.""" + cached = self._cached_adapters[adapter_name] + del self._cached_adapters[adapter_name] + config = self.peft_config[adapter_name] + for i, par in enumerate(self._parents[adapter_name]): + setattr(par, config.target_modules, cached[i]) + + def _remove_adapted_attentions(self, adapter_name: str) -> None: + """Remove AdaptedAttention modules 
from the model and store them in the cache.""" + config = self.peft_config[adapter_name] + adapted_attentions = [] + for par in self._parents[adapter_name]: + attn = getattr(par, config.target_modules) + adapted_attentions.append(attn) + setattr(par, config.target_modules, attn.model) + self._cached_adapters[adapter_name] = adapted_attentions + + def _mark_only_adaption_prompts_as_trainable(self, model: nn.Module) -> None: + """Freeze all parameters of the model except the adaption prompts.""" + for n, p in model.named_parameters(): + if not is_adaption_prompt_trainable(n): + p.requires_grad = False + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + # This is necessary as e.g. causal models have various methods that we + # don't want to re-implement here. + return getattr(self.model, name) diff --git a/peft/tuners/adaption_prompt/utils.py b/peft/tuners/adaption_prompt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f15d89f31aa0c92c7305897850342d3929292a6 --- /dev/null +++ b/peft/tuners/adaption_prompt/utils.py @@ -0,0 +1,121 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + +import torch +import torch.nn as nn + + +def llama_rotate_half(x: torch.Tensor) -> torch.Tensor: + """ + Rotate half the hidden dims of the input. + + This function was duplicated verbatim from: + https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L126 + + This was done to eliminate the Llama transformers implementation as a dependency of this file. Note that some other + functions were also adapted from the transformers implementation but were modified. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def llama_apply_rotary_pos_emb(q, cos, sin, position_ids): + """ + Apply rotary position embedding to query states in the Llama model. + + This function was adapted from: + https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L133 + + It was modified to remove unnecessary processing of key states. The method is compatible with transformers <= + 4.34.2 and also with the latest version (>=4.35). 
+ """ + # In previous transformers version cos/sin cached had a shape of 4D + if len(cos.shape) == 4: + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + # In the new version, it is 2D so we fall back to the new implementation + # https://github.com/huggingface/transformers/blame/eef7ea98c31a333bacdc7ae7a2372bde772be8e4/src/transformers/models/llama/modeling_llama.py#L222-L226 + else: + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (llama_rotate_half(q) * sin) + return q_embed + + +def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor: + """ + Compute query states for Llama models specifically. They need to be recomputed as the forward() method of the + original LlamaModel in the transformers library does not return them. See the related discussion in the PR: + https://github.com/huggingface/peft/pull/268 + """ + hidden_states = kwargs.get("hidden_states") + position_ids = kwargs.get("position_ids") + past_key_value = kwargs.get("past_key_value") + bsz, q_len, _ = hidden_states.size() + query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2) + + factor = model.k_proj.in_features // model.k_proj.out_features + value_states = ( + model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2) + ) + + seq_len = q_len + + if past_key_value is not None: + if isinstance(past_key_value, tuple): + # for transformers <= 4.35 + seq_len += past_key_value[0].shape[-2] + else: + # since transformers 4.36, this is a DynamicCache instance + seq_len += past_key_value.get_seq_length(model.layer_idx) + + # For transformers > 4.37.2 `position_ids` became a required arguments in the rotary embedding's forward pass. 
+ if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters: + # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that + cos, sin = model.rotary_emb(value_states, seq_len=seq_len) + return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids) + + past_seen_tokens = 0 + if position_ids is None: + # Compute position_ids, since they are required for transformers > 4.37.2 + if past_key_value is None: + new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device) + else: + past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx) + new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device) + position_ids = new_cache_positions.unsqueeze(0) + + rotary_emb_kwargs = {"position_ids": position_ids} + # The `seq_len` argument has been officially removed in transformers >= 4.39.0 + if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters: + rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens + + cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs) + + # For batched inference unsqueeze it on the correct dim + # since: https://github.com/huggingface/transformers/pull/29109 + if len(cos.shape) == 3: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + return (query_states * cos) + (llama_rotate_half(query_states) * sin) + + +def is_adaption_prompt_trainable(params: str) -> bool: + """Return True if module is trainable under adaption prompt fine-tuning.""" + return params.split(".")[-1].startswith("adaption_") diff --git a/peft/tuners/boft/__init__.py b/peft/tuners/boft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b72b7395167b90f9be562ee12973dbfef1781a9 --- /dev/null +++ b/peft/tuners/boft/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .config import BOFTConfig +from .layer import BOFTLayer +from .model import BOFTModel + + +__all__ = ["BOFTConfig", "BOFTLayer", "BOFTModel"] diff --git a/peft/tuners/boft/__pycache__/__init__.cpython-310.pyc b/peft/tuners/boft/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bed15547e805fd2ab33a4f69fe4532967c755e03 Binary files /dev/null and b/peft/tuners/boft/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/boft/__pycache__/config.cpython-310.pyc b/peft/tuners/boft/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6ec82042dfbdd8e67f0523e2bc9ea38c413d095 Binary files /dev/null and b/peft/tuners/boft/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/boft/__pycache__/layer.cpython-310.pyc b/peft/tuners/boft/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d713c6337c46706eb8bd26b8a98d7be3c27aeb Binary files /dev/null and b/peft/tuners/boft/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/boft/__pycache__/model.cpython-310.pyc b/peft/tuners/boft/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6996580a894f7debd0972e5751a797b61c43ea78 Binary files /dev/null and b/peft/tuners/boft/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/boft/config.py b/peft/tuners/boft/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ab704b5d95be5a747970bda970966bbb01b8e649 --- /dev/null +++ b/peft/tuners/boft/config.py @@ -0,0 +1,133 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class BOFTConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`BOFTModel`]. + + Args: + boft_block_size (`int`): BOFT block size across different layers. + boft_block_num (`int`): Number of BOFT blocks per injected layer. + boft_n_butterfly_factor (`int`): Number of butterfly factors across different layers. + target_modules (`Union[List[str],str]`): The names of the modules to apply the adapter to. + boft_dropout (`float`): The multiplicative dropout probability for BOFT layers. + fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). + For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set + to `True`. + bias (`str`): Bias type for BOFT. Can be 'none', 'all' or 'boft_only'. If 'all' or 'boft_only', the + corresponding biases will be updated during training. 
Be aware that this means that, even when disabling + the adapters, the model will not produce the same output as the base model would have without adaptation. + modules_to_save (`List[str]`):List of modules apart from BOFT layers to be set as trainable + and saved in the final checkpoint. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the BOFT transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the BOFT + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + boft_block_size: int = field( + default=4, + metadata={ + "help": "BOFT block size across different layers.", + "note": "You can only specify either boft_block_size or boft_block_num, but not both simultaneously, because boft_block_size x boft_block_num = layer dimension.", + }, + ) + boft_block_num: int = field( + default=0, + metadata={ + "help": "Number of BOFT blocks per injected layer.", + "note": "You can only specify either boft_block_size or boft_block_num, but not both simultaneously, because boft_block_size x boft_block_num = layer dimension.", + }, + ) + boft_n_butterfly_factor: int = field( + default=1, + metadata={ + "help": "Number of butterfly factors.", + "note": ( + "for example, boft_n_butterfly_factor=2, the effective block size of OFT becomes twice as big and the number of blocks become half.", + "note: for boft_n_butterfly_factor=1, BOFT is the same as vanilla OFT.", + ), + }, + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with BOFT.", + "example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ", + }, + ) + boft_dropout: float = field(default=0.0, metadata={"help": "BOFT multiplicative dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: str = field(default="none", metadata={"help": "Bias type for BOFT. Can be 'none', 'all' or 'boft_only'"}) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from BOFT layers to be set as trainable and saved in the final checkpoint. ", + "note": ( + "For example, in Sequence Classification or Token Classification tasks, ", + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved.", + ), + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the BOFT layers with their default initialization. Don't change ", + "this setting, except if you know exactly what you're doing.", + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." 
+ }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.BOFT + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + if self.boft_block_size == 0 and self.boft_block_num == 0: + raise ValueError("You must specify either boft_block_size or boft_block_num.") + if not (self.boft_block_size != 0) ^ (self.boft_block_num != 0): + raise ValueError( + f"You can only specify either boft_block_size ({self.boft_block_size}) or boft_block_num ({self.boft_block_num}), " + "but not both simultaneously, because boft_block_size x boft_block_num != in_features." + ) diff --git a/peft/tuners/boft/fbd/__init__.py b/peft/tuners/boft/fbd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/peft/tuners/boft/fbd/__pycache__/__init__.cpython-310.pyc b/peft/tuners/boft/fbd/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4700c2c398aeb4ff3698a416b027cbf92c93f50c Binary files /dev/null and b/peft/tuners/boft/fbd/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/boft/fbd/fbd_cuda.cpp b/peft/tuners/boft/fbd/fbd_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d63111b04083f3b1359aba0a9de9656060a4d515 --- /dev/null +++ b/peft/tuners/boft/fbd/fbd_cuda.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include + +std::vector forward_fast_block_diag_cuda( + at::Tensor input); + +std::vector forward_fast_block_diag( + at::Tensor input + ) { + return forward_fast_block_diag_cuda(input); +} + +std::vector backward_fast_block_diag_cuda( + at::Tensor grad_output, + at::Tensor input); +std::vector backward_fast_block_diag( + at::Tensor grad_output, + at::Tensor input + ) { + return backward_fast_block_diag_cuda(grad_output, input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward_fast_block_diag, "FAST BLOCK DIAG (CUDA)"); + m.def("backward", &backward_fast_block_diag, "FAST BLOCK DIAG backward (CUDA)"); +} diff --git a/peft/tuners/boft/fbd/fbd_cuda_kernel.cu b/peft/tuners/boft/fbd/fbd_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d683371e27e61847fb86073ff4ae7eab3cb6192 --- /dev/null +++ b/peft/tuners/boft/fbd/fbd_cuda_kernel.cu @@ -0,0 +1,109 @@ +// Author: Yao Feng +// Date: 2023/08 +// Description: cuda kernel for fast block diag + +#include + +#include +#include +#include + +namespace{ +template +__global__ void forward_fast_block_diag_cuda_kernel( + const scalar_t* __restrict__ input, //[z, N, b, b] + scalar_t* output, //[z, Nxb, Nxb] + int z, int N, int b + ) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= z*N*b*b) { + return; + } + const int zi = i/(N*b*b); + const int Ni = (i%(N*b*b))/(b*b); + const int x = ((i%(N*b*b))%(b*b))/b; + const int y = ((i%(N*b*b))%(b*b))%b; + + output[zi*N*b*N*b + (Ni*b+x)*N*b + Ni*b + y] = input[zi*N*b*b + Ni*b*b + x*b + y]; + +} + +template +__global__ void backward_fast_block_diag_cuda_kernel( + const scalar_t* __restrict__ grad_output, + scalar_t* grad_input, + int z, int N, int b + ) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= z*N*b*b) { + return; + } + const int zi = i/(N*b*b); + 
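+ // The flat thread index i is decomposed into (zi, Ni, x, y) over the [z, N, b, b]
+ // gradient tensor, mirroring the index math of the forward kernel above.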
const int Ni = (i%(N*b*b))/(b*b); + const int x = ((i%(N*b*b))%(b*b))/b; + const int y = ((i%(N*b*b))%(b*b))%b; + + grad_input[zi*N*b*b + Ni*b*b + x*b + y] = grad_output[zi*N*b*N*b + (Ni*b+x)*N*b + Ni*b + y]; + +} // namespace +} + +std::vector forward_fast_block_diag_cuda( + at::Tensor input + ){ + const auto z = input.size(0); + const auto N = input.size(1); + const auto b = input.size(2); + + // print(channel_size) + const int threads = 512; + const dim3 blocks_1 ((z*N*b*b - 1) / threads +1); + // initlaize output + auto output = at::zeros({z, N*b, N*b}, input.options()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "forward_fast_block_diag1", ([&] { + forward_fast_block_diag_cuda_kernel<<>>( + input.data(), + output.data(), + z, N, b); + })); + + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("Error in forward_fast_block_diag_cuda_kernel: %s\n", cudaGetErrorString(err)); + + return {output}; +} + +std::vector backward_fast_block_diag_cuda( + at::Tensor grad_output, + at::Tensor input + ){ + + const auto z = input.size(0); + const auto N = input.size(1); + const auto b = input.size(2); + + // print(channel_size) + const int threads = 512; + const dim3 blocks_1 ((z*N*b*b - 1) / threads +1); + + // initialize grad input + auto grad_input = at::zeros_like(input); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_output.type(), "backward_fast_block_diag", ([&] { + backward_fast_block_diag_cuda_kernel<<>>( + grad_output.data(), + grad_input.data(), + z, N, b); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("Error in backward_fast_block_diag_cuda_kernel: %s\n", cudaGetErrorString(err)); + + return {grad_input}; +} diff --git a/peft/tuners/boft/layer.py b/peft/tuners/boft/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..192a2350e0b09cdb6aabc0991fa7790350a66089 --- /dev/null +++ b/peft/tuners/boft/layer.py @@ -0,0 +1,978 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +from __future__ import annotations + +import math +import os +import warnings +from contextlib import contextmanager +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge + + +_FBD_CUDA = None + + +# this function is a 1:1 copy from accelerate +@contextmanager +def patch_environment(**kwargs): + """ + A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. + + Will convert the values in `kwargs` to strings and upper-case all the keys. 
+ + Example: + + ```python + >>> import os + >>> from accelerate.utils import patch_environment + + >>> with patch_environment(FOO="bar"): + ... print(os.environ["FOO"]) # prints "bar" + >>> print(os.environ["FOO"]) # raises KeyError + ``` + """ + existing_vars = {} + for key, value in kwargs.items(): + key = key.upper() + if key in os.environ: + existing_vars[key] = os.environ[key] + os.environ[key] = str(value) + + yield + + for key in kwargs: + key = key.upper() + if key in existing_vars: + # restore previous value + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) + + +def get_fbd_cuda(): + global _FBD_CUDA + + if _FBD_CUDA is not None: + return _FBD_CUDA + + curr_dir = os.path.dirname(__file__) + # need ninja to build the extension + try: + with patch_environment(CC="gcc", CXX="gcc"): + fbd_cuda = load( + name="fbd_cuda", + sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"], + verbose=True, + # build_directory='/tmp/' # for debugging + ) + # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7 + import fbd_cuda + except Exception as e: + warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.") + warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.") + fbd_cuda = None + + _FBD_CUDA = fbd_cuda + return _FBD_CUDA + + +class FastBlockDiag(Function): + """ + Implements a custom autograd Function for a fast block diagonal operation using CUDA. + + This function is optimized for 4D tensors where the last two dimensions are equal, representing block diagonal + matrices for efficient computation on CUDA devices. + """ + + @staticmethod + def forward(ctx, input): + """ + The forward method for FastBlockDiag. + + Computes the block diagonal operation on the input tensor using a CUDA-optimized function. This method assumes + that the input is a 4D tensor where the last two dimensions are equal, which represent the blocks to be + diagonalized. + + Parameters: + ctx: A context object that can be used to stash information for backward computation. + input (Tensor): The input tensor of shape (N, D, H, H), where `N` is the batch size, + `D` represents one additional dimension (In BOFT, the number of BOFT blocks), and `H` is the + size of the square blocks along the last two dimensions (In BOFT, the block size). + + Returns: + Tensor: The resulting tensor after applying the block diagonal operation, + will have the shape (N, DxH, DxH). + """ + output = get_fbd_cuda().forward(input)[0] + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + grad_input = get_fbd_cuda().backward(grad_output, input)[0] + return grad_input + + +class MultiplicativeDropoutLayer(nn.Module): + """ + Implements the multiplicative dropout layer for BOFT. + """ + + def __init__(self, p=0.0): + """ + Initializes the multiplicative dropout layer. + + Parameters: + p (float): The probability of dropping out a block. Defaults to 0.0. + """ + super().__init__() + self.p = p + + def forward(self, x): + """ + Applies multiplicative dropout to the input tensor. + + Parameters: + x (Tensor): The input tensor of shape (N, D, H, H), where `N` is the batch size, `D` represents + one additional dimension (In BOFT, the number of BOFT blocks), and `H` is the size of the square + blocks along the last two dimensions (In BOFT, the block size). 
+ """ + if self.training: + # Ensure the last two dimensions are the same + if x.shape[-1] != x.shape[-2]: + raise ValueError("The last two dimensions of input should be the same!") + + N, D, H, _ = x.shape + + # Randomly select one from N + n_random = torch.randint(0, N, (1,)).item() + + # Create a mask with 1s for matrices to be replaced with identity and 0s otherwise + num_to_replace = int(self.p * D) + num_zeros = D - num_to_replace + + # Generate a flat tensor with desired number of 1s and 0s + mask = torch.cat([torch.ones(num_to_replace, device=x.device), torch.zeros(num_zeros, device=x.device)]) + + # Shuffle and reshape the mask + mask = mask[torch.randperm(D)].view(1, D, 1, 1) + + full_mask = torch.zeros(N, D, 1, 1, device=x.device) + full_mask[n_random] = mask + + # Use the mask to combine original matrices and identity matrices + eye_matrix = torch.eye(H, device=x.device).repeat(N, D, 1, 1) + x = (1 - full_mask) * x + full_mask * eye_matrix + return x + + +class BOFTLayer(BaseTunerLayer): + """ + Implements the BOFT layer. + """ + + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("boft_R", "boft_s") + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("boft_block_size", "boft_block_num", "boft_dropout") + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + """ + Initializes the BOFT layer. + + Note, currently only support linear layer and convolutional layer, with further support for other layers to be + added soon. + + Parameters: + base_layer: the pretrained model layer + """ + self.base_layer = base_layer + self.boft_block_size = {} + self.boft_block_num = {} + self.boft_dropout = nn.ModuleDict({}) + self.boft_R = nn.ParameterDict({}) + self.boft_s = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.kwargs = kwargs + + base_layer = self.get_base_layer() + + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + + def set_scale(self, adapter, scale): + if adapter not in self.scaling: + # Ignore the case where the adapter is not in the layer + return + + warnings.warn("Scaling operation for BOFT not supported! Automatically set scale to 1.") + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + + warnings.warn("Scaling operation for BOFT not supported! Automatically set scale to 1.") + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + + warnings.warn("Unscaling operation for BOFT not supported! Keeping scale to 1.") + + def update_layer( + self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ): + """ + Update the linear layer with trainable BOFT weights. Override for other layer types. 
+ """ + # to be consistent with the paper notation + boft_n_butterfly_factor = boft_n_butterfly_factor - 1 + if boft_n_butterfly_factor < 0: + raise ValueError( + f"You can only specify boft_n_butterfly_factor {boft_n_butterfly_factor+1} to be a positive integer number." + ) + + # Initialize the MultiplicativeDropoutLayer for boft_dropout > 0.0. + if boft_dropout > 0.0: + boft_dropout_layer = MultiplicativeDropoutLayer(p=boft_dropout) + else: + boft_dropout_layer = nn.Identity() + self.boft_dropout.update(nn.ModuleDict({adapter_name: boft_dropout_layer})) + + if boft_block_size == 0 and boft_block_num != 0: + if self.in_features % boft_block_num != 0: + raise ValueError( + f"in_features ({self.in_features}) must be divisible by boft_block_num ({boft_block_num})!" + ) + + if boft_n_butterfly_factor != 0: + if boft_n_butterfly_factor > int(math.log2(boft_block_num)): + raise ValueError( + f"Invalid combination of boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_num ({boft_block_num})!" + ) + if boft_block_num % (2**boft_n_butterfly_factor) != 0: + raise ValueError( + f"boft_block_num ({boft_block_num}) must be a multiple of 2 raised to the power of boft_n_butterfly_factor ({boft_n_butterfly_factor+1})!" + ) + + boft_block_size = int(self.in_features // boft_block_num) + + elif boft_block_size != 0 and boft_block_num == 0: + if self.in_features % boft_block_size != 0: + raise ValueError( + f"in_features ({self.in_features}) must be divisible by boft_block_size ({boft_block_size})!" + ) + + if boft_n_butterfly_factor != 0: + if self.in_features < (boft_block_size * (2**boft_n_butterfly_factor)): + raise ValueError( + f"Invalid combination of in_features ({self.in_features}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + if self.in_features % (boft_block_size * (2**boft_n_butterfly_factor)) != 0: + raise ValueError( + f"Invalid combination of in_features ({self.in_features}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + + boft_block_num = int(self.in_features // boft_block_size) + + else: + raise ValueError( + f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously or setting both" + "to be 0, because boft_block_size x boft_block_num != in_features." + ) + + # In OFT you can specify the number of blocks to be 1 + if boft_n_butterfly_factor != 0: + if boft_block_num % 2 != 0: + raise ValueError(f"boft_block_num ({boft_block_num}) must be an even number!") + + if boft_block_size % 2 != 0: + raise ValueError(f"boft_block_size ({boft_block_size}) must be an even number!") + + # If there is no butterfly factor, then permutation matrix P will be an identity matrix. 
+ P = torch.empty((boft_n_butterfly_factor + 1, self.in_features, self.in_features)) + for i in range(boft_n_butterfly_factor + 1): + perm = self.block_butterfly_perm( + self.in_features, int(boft_block_num / (2 ** (i))), int(boft_block_size / 2), boft_n_butterfly_factor + ) + perm_mat = self.perm2mat(perm) + P[i] = perm_mat + + self.register_buffer("boft_P", P) + + self.boft_R[adapter_name] = nn.Parameter( + torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size) + ) + self.boft_s[adapter_name] = nn.Parameter(torch.ones(int(self.out_features), 1)) + + self.reset_boft_parameters(adapter_name, init_weights) + + weight = getattr(self, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + + # set the boft block size and number + self.boft_block_size[adapter_name] = boft_block_size + self.boft_block_num[adapter_name] = boft_block_num + + self.set_adapter(self.active_adapters) + + def reset_boft_parameters(self, adapter_name, init_weights): + """ + Reset the BOFT parameters. + """ + if init_weights is False: + nn.init.normal_(self.boft_R[adapter_name], mean=0.0, std=0.1) + nn.init.normal_(self.boft_s[adapter_name], mean=1.0, std=0.1) + return + + if adapter_name in self.boft_R.keys(): + if init_weights is True: + # initialize R to zero + nn.init.zeros_(self.boft_R[adapter_name]) + nn.init.ones_(self.boft_s[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_weights=}") + + def perm2mat(self, indices): + """ + Convert permutation indices to permutation matrix. + + Args: + indices: A list of indices representing the permutation. + """ + # Number of indices determines the size of the square matrix + n = len(indices) + + # Initialize a matrix of zeros + perm_mat = torch.zeros((n, n)) + + # Set the 1s according to the indices + for i, idx in enumerate(indices): + perm_mat[i, idx] = 1 + + return perm_mat + + def block_butterfly_perm(self, n, b, r=3, n_butterfly_factor=1): + """ + Define the permutation matrix for the block butterfly permutation. + + Args: + n: size of the permutation matrix + b: desired number of blocks after multiplying with the permutation matrix + r: base block size of the block diagonal matrix, e.g. 2x2, 3x3, 5x5 etc. + """ + + if n_butterfly_factor == 0: + return torch.arange(n) + + if b * r * 2 > n: + raise ValueError("Invalid number of blocks!") + + block_size = int(n // b) + indices = torch.arange(n) + + def sort_block(b, r): + step = b / r + initial_order = torch.arange(b) + sorted_order = torch.empty(b, dtype=torch.long) + + evens = torch.arange(0, step, 2) + odds = torch.arange(1, step, 2) + sorted_seq = torch.cat((evens, odds), dim=0) + for i, pos in enumerate(sorted_seq): + sorted_order[int(i * r) : int(i * r + r)] = initial_order[int(pos * r) : int(pos * r + r)] + return sorted_order + + sorted_order = sort_block(block_size, r) + + for i in range(0, n, block_size): + block_end = i + block_size + tmp_indices = indices[i:block_end] + indices[i:block_end] = tmp_indices[sorted_order] + return indices + + def cayley_batch(self, data): + """ + Perform the Cayley parametrization on a batch of skew-symmetric matrices. + + Args: + data: A batch of skew-symmetric matrices of shape (b, r, c). 
+ """ + b, r, c = data.shape + # Ensure the input matrix is skew-symmetric + skew_mat = 0.5 * (data - data.transpose(1, 2)) + id_mat = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) + + # Perform the Cayley parametrization + Q = torch.linalg.solve(id_mat + skew_mat, id_mat - skew_mat, left=False) + + return Q + + +class Linear(nn.Module, BOFTLayer): + """ + BOFT implemented in a dense layer. + """ + + def __init__( + self, + base_layer, + adapter_name: str, + boft_block_size: int = 8, + boft_block_num: int = 0, + boft_n_butterfly_factor: int = 0, + boft_dropout: float = 0.1, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: Union[bool, str] = True, + is_target_conv_1d_layer: bool = False, + **kwargs, + ) -> None: + super().__init__() + BOFTLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + + # Attempt to load the CUDA extension during model initialization + if not get_fbd_cuda(): + self.fbd_cuda_available = False + # If the CUDA extension is not available, set the butterfly factor to 1 to speed up the finetuning process + boft_n_butterfly_factor = 1 + else: + self.fbd_cuda_available = True + + self.update_layer( + adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.boft_R.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = orig_weight * boft_s + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.base_layer.weight.data = orig_weight + else: + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + orig_weight = base_layer.weight.data.clone() + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = orig_weight * boft_s + + self.base_layer.weight.data = orig_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.boft_R.keys(): + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + + self.get_base_layer().weight.data = orig_weight * (1 / boft_s) + + def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + boft_R = self.boft_R[adapter] + boft_s = self.boft_s[adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + return butterfly_oft_mat, boft_s + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye(self.in_features, device=x.device) + boft_scale = torch.ones((int(self.out_features), 1), device=x.device) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.get_base_layer().weight.data.dtype) + + orig_weight = self.get_base_layer().weight.data + orig_weight = torch.transpose(orig_weight, 0, 1) + rotated_weight = 
torch.mm(boft_rotation, orig_weight) + rotated_weight = torch.transpose(rotated_weight, 0, 1) + + scaled_rotated_weight = rotated_weight * boft_scale + + result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "boft." + rep + + +class Conv2d(nn.Module, BOFTLayer): + """ + BOFT implemented in a Conv2d layer. + """ + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + boft_block_size: int = 8, + boft_block_num: int = 0, + boft_n_butterfly_factor: int = 0, + boft_dropout: float = 0.1, + init_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + BOFTLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + + # Attempt to load the CUDA extension during model initialization + if not get_fbd_cuda(): + self.fbd_cuda_available = False + # If the CUDA extension is not available, set the butterfly factor to 1 to speed up the finetuning process + boft_n_butterfly_factor = 1 + else: + self.fbd_cuda_available = True + + self.update_layer( + adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ) + + def update_layer( + self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ): + """ + Update the conv2d layer with trainable BOFT weights. + """ + # to be consistent with the paper notation + boft_n_butterfly_factor = boft_n_butterfly_factor - 1 + if boft_n_butterfly_factor < 0: + raise ValueError( + f"You can only specify boft_n_butterfly_factor {boft_n_butterfly_factor+1} to be a positive integer number." + ) + + # Initialize the MultiplicativeDropoutLayer for boft_dropout > 0.0. + if boft_dropout > 0.0: + boft_dropout_layer = MultiplicativeDropoutLayer(p=boft_dropout) + else: + boft_dropout_layer = nn.Identity() + self.boft_dropout.update(nn.ModuleDict({adapter_name: boft_dropout_layer})) + + # layer information from the base layer + base_layer = self.get_base_layer() + conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] + + # Initialize the BOFT parameters. + if not (boft_block_size != 0) ^ (boft_block_num != 0): + raise ValueError( + f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously, because boft_block_size x boft_block_num != in_features." + ) + + if boft_block_size == 0 and boft_block_num != 0: + if conv_filter_dim % boft_block_num != 0: + raise ValueError( + f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by boft_block_num ({boft_block_num})!" + ) + + if boft_n_butterfly_factor != 0: + if boft_n_butterfly_factor > int(math.log2(boft_block_num)): + raise ValueError( + f"Invalid combination of boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_num ({boft_block_num})!" + ) + if boft_block_num % (2**boft_n_butterfly_factor) != 0: + raise ValueError( + f"boft_block_num ({boft_block_num}) must be a multiple of 2 raised to the power of boft_n_butterfly_factor ({boft_n_butterfly_factor+1})!" + ) + + boft_block_size = int(conv_filter_dim // boft_block_num) + + elif boft_block_size != 0 and boft_block_num == 0: + if conv_filter_dim % boft_block_size != 0: + raise ValueError( + f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by boft_block_size ({boft_block_size})!" 
+ ) + + if boft_n_butterfly_factor != 0: + if conv_filter_dim < (boft_block_size * (2**boft_n_butterfly_factor)): + raise ValueError( + f"Invalid combination of convolutional kernel dimension ({conv_filter_dim}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + if conv_filter_dim % (boft_block_size * (2**boft_n_butterfly_factor)) != 0: + raise ValueError( + f"Invalid combination of convolutional kernel dimension ({conv_filter_dim}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + + boft_block_num = int(conv_filter_dim // boft_block_size) + + else: + raise ValueError("Unknown error!") + + # In OFT you can specify the number of blocks to be 1 + if boft_n_butterfly_factor != 0: + if boft_block_num % 2 != 0: + raise ValueError(f"boft_block_num ({boft_block_num}) must be an even number!") + + if boft_block_size % 2 != 0: + raise ValueError(f"boft_block_size ({boft_block_size}) must be an even number!") + + # If there is no butterfly factor, then permutation matrix P will be an identity matrix. + P = torch.empty((boft_n_butterfly_factor + 1, conv_filter_dim, conv_filter_dim)) + for i in range(boft_n_butterfly_factor + 1): + perm = self.block_butterfly_perm( + conv_filter_dim, int(boft_block_num / (2 ** (i))), int(boft_block_size / 2), boft_n_butterfly_factor + ) + perm_mat = self.perm2mat(perm) + P[i] = perm_mat + + self.register_buffer("boft_P", P) + + self.boft_R[adapter_name] = nn.Parameter( + torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size) + ) + self.boft_s[adapter_name] = nn.Parameter(torch.ones(1, int(self.out_features))) + + self.reset_boft_parameters(adapter_name, init_weights) + + weight = getattr(self, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + # set the boft block size and number + self.boft_block_size[adapter_name] = boft_block_size + self.boft_block_num[adapter_name] = boft_block_num + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.boft_R.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. 
+ orig_weight = base_layer.weight.data.clone() + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = orig_weight.view( + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features + ) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = orig_weight * boft_s + orig_weight = orig_weight.view( + self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] + ) + + self.base_layer.weight.data = orig_weight + else: + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = base_layer.weight.data.clone() + orig_weight = orig_weight.view( + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features + ) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = orig_weight * boft_s + orig_weight = orig_weight.view( + self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] + ) + + self.base_layer.weight.data = orig_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.boft_R.keys(): + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = orig_weight.view( + self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], + self.out_features, + ) + orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + orig_weight = orig_weight * (1 / boft_s) + orig_weight = orig_weight.view( + self.out_features, + self.in_features, + self.get_base_layer().kernel_size[0], + self.get_base_layer().kernel_size[0], + ) + + self.get_base_layer().weight.data = orig_weight + + def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. 
+ """ + + boft_R = self.boft_R[adapter] + boft_s = self.boft_s[adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + return butterfly_oft_mat, boft_s + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device + ) + boft_scale = torch.ones((1, int(self.out_features)), device=x.device) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.base_layer.weight.data.dtype) + + orig_weight = self.base_layer.weight.data + orig_weight = orig_weight.view( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + self.out_features, + ) + rotated_weight = torch.mm(boft_rotation, orig_weight) + + scaled_rotated_weight = rotated_weight * boft_scale + + scaled_rotated_weight = scaled_rotated_weight.view( + self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0] + ) + result = F.conv2d( + input=x, + weight=scaled_rotated_weight, + bias=self.base_layer.bias, + padding=self.base_layer.padding[0], + stride=self.base_layer.stride[0], + ) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "boft." 
+ rep diff --git a/peft/tuners/boft/model.py b/peft/tuners/boft/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c975f56a9de5e15865b9ee992b45d2fb13b5230b --- /dev/null +++ b/peft/tuners/boft/model.py @@ -0,0 +1,334 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import warnings +from dataclasses import asdict +from enum import Enum +from typing import List, Optional + +import torch +from torch import nn +from tqdm import tqdm + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .config import BOFTConfig +from .layer import BOFTLayer, Conv2d, Linear + + +class BOFTModel(BaseTuner): + """ + Creates BOFT and OFT model from a pretrained transformers model. Paper: https://arxiv.org/abs/2311.06243 + https://arxiv.org/abs/2306.07280 + + Args: + model ([`transformers.PreTrainedModel`]): The model to be adapted. + config ([`BOFTConfig`]): The configuration of the BOFT model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The BOFT model. + + Example:: + + >>> import transformers >>> from transformers import AutoModelForSeq2SeqLM, BOFTConfig >>> from peft import + BOFTConfig, get_peft_model + + >>> config = BOFTConfig( ... boft_block_size=8, ... boft_n_butterfly_factor=1, ... target_modules=["query", + "value", "key", "output.dense", "mlp.fc1", "mlp.fc2"], ... boft_dropout=0.1, ... bias="boft_only", ... + modules_to_save=["classifier"], ... ) + + >>> model = transformers.Dinov2ForImageClassification.from_pretrained( ... "facebook/dinov2-large", ... + num_labels=100, ... ) >>> boft_model = get_peft_model(model, config) + + **Attributes**: + - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`BOFTConfig`]): The configuration of the BOFT model. + """ + + prefix: str = "boft_" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + def _check_new_adapter_config(self, config: BOFTConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." 
+ ) + + @staticmethod + def _check_target_module_exists(boft_config, key): + return check_target_module_exists(boft_config, key) + + def _create_and_replace( + self, + boft_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "boft_block_size": boft_config.boft_block_size, + "boft_block_num": boft_config.boft_block_num, + "boft_n_butterfly_factor": boft_config.boft_n_butterfly_factor, + "boft_dropout": boft_config.boft_dropout, + "fan_in_fan_out": boft_config.fan_in_fan_out, + "init_weights": boft_config.init_weights, + } + kwargs["bias"] = bias + + # If it is not a BOFTLayer, create a new module, else update it with new adapters + if not isinstance(target, BOFTLayer): + new_module = self._create_new_module(boft_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + boft_block_size=boft_config.boft_block_size, + boft_block_num=boft_config.boft_block_num, + boft_n_butterfly_factor=boft_config.boft_n_butterfly_factor, + boft_dropout=boft_config.boft_dropout, + init_weights=boft_config.init_weights, + ) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if self.prefix in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "boft_only": + for name, m in model.named_modules(): + if isinstance(m, BOFTLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(boft_config, adapter_name, target, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." 
+ ) + kwargs["fan_in_fan_out"] = boft_config.fan_in_fan_out = False + new_module = Linear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Conv2d): + new_module = Conv2d(target, adapter_name, **kwargs) + else: + raise ValueError( + f"Target module {target} is not supported. " + "Currently, only `torch.nn.Linear` and `torch.nn.Conv2d` are supported." + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, BOFTLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[List[str]] = None, + ): + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. 
+ """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BOFTLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the BOFT layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the boft modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/peft/tuners/ia3/__init__.py b/peft/tuners/ia3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..763d0133a285eda4cf2aa68f624ea2c1b3a447a2 --- /dev/null +++ b/peft/tuners/ia3/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + +from .config import IA3Config +from .layer import Conv2d, IA3Layer, Linear +from .model import IA3Model + + +__all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"] + + +def __getattr__(name): + if (name == "Linear8bitLt") and is_bnb_available(): + from .bnb import Linear8bitLt + + return Linear8bitLt + + if (name == "Linear4bit") and is_bnb_4bit_available(): + from .bnb import Linear4bit + + return Linear4bit + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/peft/tuners/ia3/__pycache__/__init__.cpython-310.pyc b/peft/tuners/ia3/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cabe4d57af91d6b3d2a94c96772dacc56328cdc Binary files /dev/null and b/peft/tuners/ia3/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/ia3/__pycache__/bnb.cpython-310.pyc b/peft/tuners/ia3/__pycache__/bnb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae41aba8b97dab2fcefa5fbb42ef6829f66f950d Binary files /dev/null and b/peft/tuners/ia3/__pycache__/bnb.cpython-310.pyc differ diff --git a/peft/tuners/ia3/__pycache__/config.cpython-310.pyc b/peft/tuners/ia3/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f08a4bf2e09958a999b53be0d1e5a26ee6575d3e Binary files /dev/null and b/peft/tuners/ia3/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/ia3/__pycache__/layer.cpython-310.pyc b/peft/tuners/ia3/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5ea0fd2db957b8e920124aa1d7c26f355e0c79c Binary files /dev/null and b/peft/tuners/ia3/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/ia3/__pycache__/model.cpython-310.pyc b/peft/tuners/ia3/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db19e5da9ebe37b266d23f3ab7066ddba7882ba2 Binary files /dev/null and b/peft/tuners/ia3/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/ia3/bnb.py b/peft/tuners/ia3/bnb.py new file mode 100644 index 0000000000000000000000000000000000000000..628e3ce7229528a0b3157da349b2b34153573c51 --- /dev/null +++ b/peft/tuners/ia3/bnb.py @@ -0,0 +1,129 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
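The module-level `__getattr__` in `peft/tuners/ia3/__init__.py` above is a PEP 562 hook: the bitsandbytes-backed `Linear8bitLt`/`Linear4bit` classes are imported only when first requested, and only if `is_bnb_available()` / `is_bnb_4bit_available()` report the optional dependency, so importing the tuner never hard-requires bitsandbytes. Below is a reduced sketch of the same pattern with hypothetical module and class names; it is an illustration, not part of this patch.

```python
# toy_module.py -- lazy-export pattern, hypothetical names ("heavy_backend", "HeavyClass").
import importlib.util


def _backend_available() -> bool:
    # stand-in for is_bnb_available(): check whether the optional dependency is installed
    return importlib.util.find_spec("heavy_backend") is not None


def __getattr__(name):  # PEP 562: called only when normal module attribute lookup fails
    if name == "HeavyClass" and _backend_available():
        from heavy_backend import HeavyClass  # deferred import of the heavy implementation

        return HeavyClass

    raise AttributeError(f"module {__name__} has no attribute {name}")
```

`from toy_module import HeavyClass` then succeeds only when the backend is installed, which is exactly how `from peft.tuners.ia3 import Linear8bitLt` behaves in the file above.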
+ +from typing import Any + +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + +from .layer import IA3Layer + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, IA3Layer): + # (IA)^3 implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + is_feedforward: bool, + init_ia3_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) + + # Freezing the pre-trained weight matrix + self.get_base_layer().weight.requires_grad = False + self._active_adapter = adapter_name + self.update_layer(adapter_name, init_ia3_weights) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) + if self.disable_adapters: + return self.base_layer(x) + + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + ia3_scaling *= self.ia3_l[active_adapter].flatten() + + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) + if requires_conversion: + x = x.float() + if self.is_feedforward: + result = self.base_layer(x * ia3_scaling) + expected_dtype = result.dtype + else: + result = self.base_layer(x) + expected_dtype = result.dtype + result = result * ia3_scaling + + if requires_conversion: + result = result.to(expected_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "ia3." + rep + + +if is_bnb_4bit_available(): + + class Linear4bit(torch.nn.Module, IA3Layer): + # IA3 implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + is_feedforward: bool, + init_ia3_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) + + # Freezing the pre-trained weight matrix + self.get_base_layer().weight.requires_grad = False + self._active_adapter = adapter_name + self.update_layer(adapter_name, init_ia3_weights) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # note: no check for self.merged because merging is not supported (yet) + if self.disable_adapters: + return self.base_layer(x) + + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + ia3_scaling *= self.ia3_l[active_adapter].flatten() + + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) + if requires_conversion: + x = x.float() + if self.is_feedforward: + result = self.base_layer(x * ia3_scaling) + expected_dtype = result.dtype + else: + result = self.base_layer(x) + expected_dtype = result.dtype + result = result * ia3_scaling + + result = result.clone() + # adalora.py and lora.py both suggest that this is necessary for 4-bit training on older versions of Pytorch. + # This has been duplicated here. + + if requires_conversion: + result = result.to(expected_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "ia3." + rep diff --git a/peft/tuners/ia3/config.py b/peft/tuners/ia3/config.py new file mode 100644 index 0000000000000000000000000000000000000000..322ea068d3d44fc6797354cc718f725b4999bbe4 --- /dev/null +++ b/peft/tuners/ia3/config.py @@ -0,0 +1,98 @@ +# Copyright 2023-present the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class IA3Config(PeftConfig): + """ + This is the configuration class to store the configuration of a [`IA3Model`]. + + Args: + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + feedforward_modules (`Optional[Union[List[str], str]]`): + The names of the modules to be treated as feedforward modules, as in the original paper. These modules will + have (IA)³ vectors multiplied to the input, instead of the output. `feedforward_modules` must be a name or + a subset of names present in `target_modules`. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + modules_to_save (`Optional[List[str]]`): + List of modules apart from (IA)³ layers to be set as trainable and saved in the final checkpoint. + init_ia3_weights (`bool`): + Whether to initialize the vectors in the (IA)³ layers, defaults to `True`. Setting this to `False` is + discouraged. + """ + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with (IA)³." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'." + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you should specify the target modules manually." 
+ ), + }, + ) + feedforward_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or a regex expression of module names which are feedforward" + "For example, ['output.dense']" + }, + ) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from (IA)^3 layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_ia3_weights: bool = field( + default=True, + metadata={"help": "Whether to initialize the vectors in the (IA)^3 layers."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.IA3 + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + self.feedforward_modules = ( + set(self.feedforward_modules) if isinstance(self.feedforward_modules, list) else self.feedforward_modules + ) + + # check if feedforward_modules is a subset of target_modules. run the check only if both are sets + if isinstance(self.feedforward_modules, set) and isinstance(self.target_modules, set): + if not self.feedforward_modules.issubset(self.target_modules): + raise ValueError("`feedforward_modules` should be a subset of `target_modules`") diff --git a/peft/tuners/ia3/layer.py b/peft/tuners/ia3/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..aef376860c9cd9ac244450ccfc7dc48a7fae4af8 --- /dev/null +++ b/peft/tuners/ia3/layer.py @@ -0,0 +1,310 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
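`IA3Config.__post_init__` above normalizes list-valued `target_modules` and `feedforward_modules` to sets and enforces that every feedforward name is also targeted. A hedged sketch of that behaviour follows, assuming the vendored `peft` package re-exports `IA3Config` at the top level as upstream does; the module names are the illustrative ones from the field help strings.

```python
from peft import IA3Config

# valid: "output.dense" is also targeted, so the subset check in __post_init__ passes
cfg = IA3Config(
    target_modules=["q", "v", "output.dense"],
    feedforward_modules=["output.dense"],  # (IA)^3 vector multiplies the *input* of these modules
)
print(cfg.peft_type)  # PeftType.IA3, set automatically in __post_init__

# invalid: "output.dense" is not in target_modules, so __post_init__ raises ValueError
try:
    IA3Config(target_modules=["q", "v"], feedforward_modules=["output.dense"])
except ValueError as err:
    print(err)  # "`feedforward_modules` should be a subset of `target_modules`"
```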
+ +import warnings +from typing import Any, List, Optional + +import torch +import torch.nn as nn +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils import transpose + + +class IA3Layer(BaseTunerLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ("ia3_l",) + + def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None: + self.base_layer = base_layer + self.ia3_l = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.is_feedforward = is_feedforward + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + self.in_features = in_features + self.out_features = out_features + + def update_layer(self, adapter_name, init_ia3_weights): + # This code works for linear layers, override for other layer types + # Actual trainable parameters + if self.is_feedforward: + weight = torch.randn((1, self.in_features)) + else: + weight = torch.randn((self.out_features, 1)) + self.ia3_l[adapter_name] = nn.Parameter(weight) + if init_ia3_weights: + self.reset_ia3_parameters(adapter_name) + self.to(self.get_base_layer().weight.device) + self.set_adapter(self.active_adapters) + + def reset_ia3_parameters(self, adapter_name): + if adapter_name in self.ia3_l.keys(): + # initialize learned vector with torch.ones + nn.init.constant_(self.ia3_l[adapter_name], 1.0) + + +class Linear(nn.Module, IA3Layer): + # (IA)^3 implemented in a dense layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer + is_target_conv_1d_layer: bool = False, # whether target module is a conv1d layer. useful while unloading later + init_ia3_weights: bool = True, # whether to initialize IA3 weights + **kwargs, + ) -> None: + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) + self.fan_in_fan_out = fan_in_fan_out + self.is_target_conv_1d_layer = is_target_conv_1d_layer + self._active_adapter = adapter_name + self.update_layer(adapter_name, init_ia3_weights) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() + ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + orig_dtype = base_layer.weight.data.dtype + if safe_merge: + orig_weights = base_layer.weight.data + orig_weights = torch.mul(orig_weights, ia3_l) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights.to(orig_dtype) + else: + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l).to(orig_dtype) + + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + warnings.warn("Unmerge result can be inaccurate for (IA)^3.") + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() + # Add tolerace to avoid division by zero + ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8 + orig_dtype = base_layer.weight.data.dtype + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l).to(orig_dtype) + + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + dtype = previous_dtype = x.dtype + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + dtype = self.ia3_l[active_adapter].dtype + ia3_scaling *= self.ia3_l[active_adapter].flatten() + + if self.is_feedforward: + x = x.to(dtype) + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # e.g. bf16 vs fp32. Is that okay? 
+ interm = (x * ia3_scaling).to(previous_dtype) + result = self.base_layer(interm, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + result_dtype = result.dtype + result = (result * ia3_scaling).to(result_dtype) + + return result + + +class Conv2d(nn.Module, IA3Layer): + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer + init_ia3_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + + self.update_layer(adapter_name, init_ia3_weights) + + def update_layer(self, adapter_name, init_ia3_weights): + # Actual trainable parameters + if self.is_feedforward: + weight = torch.randn((1, self.in_features, 1, 1)) + else: + weight = torch.randn((1, self.out_features, 1, 1)) + self.ia3_l[adapter_name] = nn.Parameter(weight) + if init_ia3_weights: + self.reset_ia3_parameters(adapter_name) + self.to(self.get_base_layer().weight.device) + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() + ia3_scaling = self.ia3_l[active_adapter].data + if not self.is_feedforward: + ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) + + if safe_merge: + output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone() + + if not torch.isfinite(output_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = output_weight + else: + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling) + + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + warnings.warn("Unmerge result can be inaccurate for (IA)^3.") + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.ia3_l.keys(): + base_layer = self.get_base_layer() + # divide by (IA)^3 vector. 
Add tolerace to avoid division by zero + ia3_scaling = self.ia3_l[active_adapter].data + if not self.is_feedforward: + ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8) + + if not self.is_feedforward and (base_layer.bias is not None): + scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + dtype = previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + dtype = self.ia3_l[active_adapter].dtype + ia3_scaling *= self.ia3_l[active_adapter] + + if self.is_feedforward: + x = x.to(dtype) + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # e.g. bf16 vs fp32. Is that okay? + interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) + result = self.base_layer(interm, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + result = result.to(dtype) * ia3_scaling + + result = result.to(previous_dtype) + return result diff --git a/peft/tuners/ia3/model.py b/peft/tuners/ia3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9469e7e7b2bfae50f57c93fbc3cd650db840aeba --- /dev/null +++ b/peft/tuners/ia3/model.py @@ -0,0 +1,395 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import re +import warnings +from dataclasses import asdict +from enum import Enum +from typing import Optional + +import torch +from torch import nn +from transformers.pytorch_utils import Conv1D + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .layer import Conv2d, IA3Layer, Linear + + +class IA3Model(BaseTuner): + """ + Creates a Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained + transformers model. The method is described in detail in https://arxiv.org/abs/2205.05638 + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`IA3Config`]): The configuration of the (IA)^3 model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The (IA)^3 model. + + Example: + + ```py + >>> from transformers import AutoModelForSeq2SeqLM, ia3Config + >>> from peft import IA3Model, IA3Config + + >>> config = IA3Config( + ... 
peft_type="IA3", + ... task_type="SEQ_2_SEQ_LM", + ... target_modules=["k", "v", "w0"], + ... feedforward_modules=["w0"], + ... ) + + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> ia3_model = IA3Model(config, model) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`ia3Config`]): The configuration of the (IA)^3 model. + """ + + prefix: str = "ia3_" + + def __init__(self, model, config, adapter_name): + super().__init__(model, config, adapter_name) + + @staticmethod + def _create_new_module(ia3_config, adapter_name, target, **kwargs): + # avoid eager bnb import + if is_bnb_available(): + import bitsandbytes as bnb + + from .bnb import Linear8bitLt + + if is_bnb_4bit_available(): + from .bnb import Linear4bit + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + is_feedforward = kwargs.pop("is_feedforward", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target_base_layer.state.has_fp16_weights, + "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, + "threshold": target_base_layer.state.threshold, + "index": target_base_layer.index, + } + ) + new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs) + elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs) + elif isinstance(target, torch.nn.Conv2d): + new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False + new_module = Linear(target, adapter_name, is_feedforward=is_feedforward, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True + new_module = Linear( + target, adapter_name, is_feedforward=is_feedforward, is_target_conv_1d_layer=True, **kwargs + ) + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported." 
+ ) + return new_module + + @staticmethod + def _check_target_module_exists(ia3_config, key): + return check_target_module_exists(ia3_config, key) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + def _create_and_replace( + self, + ia3_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + # check if target module is in feedforward_modules + is_feedforward = self._check_target_module_feedforward(ia3_config, current_key) + + kwargs = { + "fan_in_fan_out": ia3_config.fan_in_fan_out, + "init_ia3_weights": ia3_config.init_ia3_weights, + "is_feedforward": is_feedforward, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + + if isinstance(target, IA3Layer): + target.update_layer( + adapter_name, + ia3_config.init_ia3_weights, + ) + else: + new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _check_target_module_feedforward(ia3_config, key) -> bool: + """ + A helper private method that checks if the target module `key` matches with a feedforward module specified in + `ia3_config` + """ + if isinstance(ia3_config.feedforward_modules, str): + is_feedforward = bool(re.fullmatch(ia3_config.feedforward_modules, key)) + else: + is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules) + return is_feedforward + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + # layers with base_layer don't need the weight to be copied, as they have a reference already + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if self.prefix in name: + module.to(child.weight.device) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (IA3Layer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. 
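+
+        Example (illustrative; assumes `model` is a model wrapped by this `IA3Model`):
+
+        ```py
+        >>> model.disable_adapter_layers()  # outputs now come from the base model only
+        >>> model.enable_adapter_layers()  # re-activate the (IA)^3 vectors
+        ```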
+ """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ + for module in self.model.modules(): + if isinstance(module, IA3Layer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + def _prepare_adapter_config(self, peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]] + if peft_config.feedforward_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING: + raise ValueError("Please specify `feedforward_modules` in `peft_config`") + peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[ + model_config["model_type"] + ] + return peft_config + + def _unload_and_optionally_merge( + self, merge: bool = True, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ): + r""" + This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + safe_merge (`bool`, `optional`, defaults to `False`): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
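+            merge (`bool`, `optional`, defaults to `True`):
+                Whether to merge the active adapter weights into the base weights before the (IA)^3 modules are
+                removed. If `False`, the adapters are simply stripped and the unchanged base model is returned.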
+ """ + if getattr(self.model, "is_loaded_in_8bit", False): + raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode") + + if getattr(self.model, "is_loaded_in_4bit", False): + raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode") + + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + for key in key_list: + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + return self.model + + def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> torch.nn.Module: + r""" + This method merges the IA³ layers into the base model. This is needed if someone wants to use the base model as + a standalone model. + + Args: + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the IA³ modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in self.peft_config: + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, IA3Layer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] diff --git a/peft/tuners/ln_tuning/__init__.py b/peft/tuners/ln_tuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afaae73a4462cf8f7a31aafcbb40cd89d8627519 --- /dev/null +++ b/peft/tuners/ln_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024-present the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import LNTuningConfig +from .model import LNTuningModel + + +__all__ = ["LNTuningConfig", "LNTuningModel"] diff --git a/peft/tuners/ln_tuning/__pycache__/__init__.cpython-310.pyc b/peft/tuners/ln_tuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0fcffe94d3a9219a294589746659365cb8e295c Binary files /dev/null and b/peft/tuners/ln_tuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/ln_tuning/__pycache__/config.cpython-310.pyc b/peft/tuners/ln_tuning/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b823aa03d6464630e53452985c14909b529f480 Binary files /dev/null and b/peft/tuners/ln_tuning/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/ln_tuning/__pycache__/layer.cpython-310.pyc b/peft/tuners/ln_tuning/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..875ac9422c2dc26b69c13d09200829f56b42d092 Binary files /dev/null and b/peft/tuners/ln_tuning/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/ln_tuning/__pycache__/model.cpython-310.pyc b/peft/tuners/ln_tuning/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db5bd3931933c11436191e35203fbd8aff28c035 Binary files /dev/null and b/peft/tuners/ln_tuning/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/ln_tuning/config.py b/peft/tuners/ln_tuning/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fac6a633c6cec02dc241b5e810e0c07355c0c6c4 --- /dev/null +++ b/peft/tuners/ln_tuning/config.py @@ -0,0 +1,61 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class LNTuningConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a :class:`~peft.tuners.LNTuningModel`. + + Args: + target_modules (`Optional[Union[List[str], str]]`): + List of module names or regex expression of the module names to replace with LNTuning. For example, + '.*decoder.*' or '.*encoder.*'. If this is not specified, modules will be chosen according to the model + architecture. 
If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + modules_to_save (`Optional[Union[List[str], str]]`): + List of modules to be set as trainable and saved in the final checkpoint. For example, in Sequence + Classification or Token Classification tasks, the final layer `classifier/score` are randomly initialized + and as such need to be trainable and saved. + """ + + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with LNTuning." + "For example, '.*decoder.*' or '.*encoder.*'. " + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you shoud specify the target modules manually." + ), + }, + ) + modules_to_save: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "List of modules to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LN_TUNING diff --git a/peft/tuners/ln_tuning/layer.py b/peft/tuners/ln_tuning/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c90d4b9e618c6c343cad8279cfdfb4e2d4c6122 --- /dev/null +++ b/peft/tuners/ln_tuning/layer.py @@ -0,0 +1,117 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from copy import deepcopy +from typing import List, Optional + +import torch +import torch.nn as nn + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge + + +class LNTuningLayer(nn.Module, BaseTunerLayer): + """ + Selects a layer from the model. + """ + + adapter_layer_names = ("ln_tuning_layers",) + + def __init__(self, base_layer: nn.Module, adapter_name: str): + super().__init__() + self.base_layer = base_layer + self.ln_tuning_layers = nn.ModuleDict({}) + self.update_layer(self.base_layer, adapter_name) + self._active_adapter = adapter_name + self.merged_adapters = [] + + def update_layer(self, layer: nn.Module, adapter_name: str): + self.ln_tuning_layers[adapter_name] = deepcopy(layer) + + def enable_adapters(self, enabled: bool) -> None: + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. 
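+        If an adapter is currently merged (i.e. swapped in for the base layer), disabling the adapters
+        unmerges it first.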
+ + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if enabled: + self.set_adapter(self.active_adapters) + self._disable_adapters = False + else: + if self.merged: + self.unmerge() + # disable grads on all adapter layers + for layer_name in self.adapter_layer_names: + layer = getattr(self, layer_name) + layer.requires_grad_(False) + self._disable_adapters = True + + def merge(self, adapter_names: Optional[List[str]] = None): + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + if len(adapter_names) > 1: + raise ValueError( + f"Trying to merge {len(adapter_names)} adapters, but LN " + f"tuning does not allow merging more than one adapter at a time" + ) + merged_adapters = set(self.merged_adapters) + if merged_adapters: + warnings.warn(f"Already merged with {merged_adapters}. Unmerging first.") + self.unmerge() + + self.base_layer, self.ln_tuning_layers[adapter_names[0]] = ( + self.ln_tuning_layers[adapter_names[0]], + self.base_layer, + ) + self.merged_adapters.append(adapter_names[0]) + + def unmerge(self): + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + # popping one element is sufficient because LN + # tuning does not allow merging more than one adapter at a time. + merged_name = self.merged_adapters.pop() + self.base_layer, self.ln_tuning_layers[merged_name] = ( + self.ln_tuning_layers[merged_name], + self.base_layer, + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + if len(self.active_adapters) != 1: + raise ValueError( + f"Trying to run forward with {len(self.active_adapters)} active " + f"adapters, but LN tuning does not allow inference with more than one adapter at a time" + ) + active_adapter = self.active_adapters[0] + result = self.ln_tuning_layers[active_adapter](x, *args, **kwargs) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "ln_tuning." + rep diff --git a/peft/tuners/ln_tuning/model.py b/peft/tuners/ln_tuning/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec8eeb6200fa54ab49718292e528f7dd34dde59 --- /dev/null +++ b/peft/tuners/ln_tuning/model.py @@ -0,0 +1,201 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
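+
+# LayerNorm tuning (https://arxiv.org/abs/2312.11420) freezes the base model and trains
+# only per-adapter copies of the targeted normalization layers: `LNTuningLayer` deep-copies
+# each wrapped module, and "merging" swaps the trained copy with the base layer, so at most
+# one adapter can be merged or active at a time.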
+from __future__ import annotations + +import warnings +from typing import Optional + +from torch import nn +from torch.nn.modules import Module +from tqdm import tqdm + +from peft.config import PeftConfig +from peft.tuners.tuners_utils import BaseTuner, _get_submodules, check_target_module_exists +from peft.utils import TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, ModulesToSaveWrapper + +from .layer import LNTuningLayer + + +class LNTuningModel(BaseTuner): + """ + Creates LayerNorm tuning from a pretrained transformer model. + + The method is described in detail in https://arxiv.org/abs/2312.11420. + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`LNTuningConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + 'torch.nn.Module': The adapted model with LayerNorm tuned on. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import get_peft_model, TaskType, LNTuningConfig + + >>> peft_config = LNTuningConfig( + ... task_type=TaskType.CAUSAL_LM, + ... ) + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> model = get_peft_model(model, peft_config) + >>> model.print_trainable_parameters() + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`LNTuningConfig`]): The configuration of the Lora model. + """ + + prefix: str = "ln_tuning_" + + def __init__(self, model, config, adapter_name) -> None: + # self.adapter_name = adapter_name + super().__init__(model, config, adapter_name) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + # TODO: here need to handle the modules_to_save rather than the target_modules + @staticmethod + def _prepare_adapter_config(peft_config: PeftConfig, model_config: dict) -> PeftConfig: + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _create_and_replace( + self, + peft_config: PeftConfig, + adapter_name: str, + target: Module, + target_name: str, + parent: Module, + current_key: str, + ) -> None: + # replace the original module with a same new module + new_module = self._create_new_module(peft_config, target, adapter_name) + if adapter_name != self.active_adapter: + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + def _create_new_module( + self, + peft_config: PeftConfig, + target: Module, + adapter_name: str, + ) -> Module: + if not isinstance(target, LNTuningLayer): + new_module = LNTuningLayer(target, adapter_name) + else: + new_module = target + new_module.update_layer(target.base_layer, adapter_name) + return new_module + + def _replace_module(self, parent: Module, child_name: str, new_module: Module, child: Module) -> None: + setattr(parent, child_name, new_module) + + if hasattr(child, "base_layer"): + child = child.base_layer + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + 
new_module.state = child.state + new_module.to(child.weight.device) + + for name, module in new_module.named_modules(): + weight = child.qweight if hasattr(child, "qweight") else child.weight + module.to(weight.device) + + def _mark_only_adapters_as_trainable(self, model: Module): + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + else: + p.requires_grad = True + + def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: + return check_target_module_exists(peft_config, key) + + def _set_adapter_layers(self, enabled: bool) -> None: + for module in self.model.modules(): + if isinstance(module, (LNTuningLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: str) -> None: + for module in self.model.modules(): + if isinstance(module, LNTuningLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading adapters " + ("and merging " if merge else "") + "model" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + + return self.model + + def unload(self): + return self._unload_and_optionally_merge(merge=False) + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> nn.Module: + return self._unload_and_optionally_merge(merge=True) diff --git a/peft/tuners/loha/__init__.py b/peft/tuners/loha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f39deee17ab9cb0e24b3a98d8b54eb7eeb27c1f --- /dev/null +++ b/peft/tuners/loha/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
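+
+# LoHa (Low-Rank Hadamard product, cf. FedPara: https://arxiv.org/abs/2108.06098) builds the
+# weight update as an element-wise product of two low-rank factorizations,
+# delta_W = (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b), with an optional extra
+# low-rank core (hada_t1 / hada_t2) for large Conv2d kernels.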
+ +from .config import LoHaConfig +from .layer import Conv2d, Linear, LoHaLayer +from .model import LoHaModel + + +__all__ = ["LoHaConfig", "LoHaModel", "Conv2d", "Linear", "LoHaLayer"] diff --git a/peft/tuners/loha/__pycache__/__init__.cpython-310.pyc b/peft/tuners/loha/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38f9e7b827d56cdeaeddd2063d4aad14c17dd330 Binary files /dev/null and b/peft/tuners/loha/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/loha/__pycache__/config.cpython-310.pyc b/peft/tuners/loha/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0e96c2544adb484e03dfb5261aff7d09196135 Binary files /dev/null and b/peft/tuners/loha/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/loha/__pycache__/layer.cpython-310.pyc b/peft/tuners/loha/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38d21cd45886c8a6b64bb74180f7dd283bbcfc11 Binary files /dev/null and b/peft/tuners/loha/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/loha/__pycache__/model.cpython-310.pyc b/peft/tuners/loha/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a6946f6fb2c3971c0f5a6b3782b5d92380726cb Binary files /dev/null and b/peft/tuners/loha/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/loha/config.py b/peft/tuners/loha/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c38ba7828b59668d87113bd53a8cd6bd7bd570e2 --- /dev/null +++ b/peft/tuners/loha/config.py @@ -0,0 +1,121 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lycoris_utils import LycorisConfig +from peft.utils import PeftType + + +@dataclass +class LoHaConfig(LycorisConfig): + """ + This is the configuration class to store the configuration of a [`LoHaModel`]. + + Args: + r (`int`): + LoHa rank. + alpha (`int`): + The alpha parameter for LoHa scaling. + rank_dropout (`float`): + The dropout probability for rank dimension during training. + module_dropout (`float`): + The dropout probability for disabling LoHa modules during training. + use_effective_conv2d (`bool`): + Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper). + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. 
If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + init_weights (`bool`): + Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is + discouraged. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `alpha`. + modules_to_save (`Optional[List[str]]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + """ + + r: int = field(default=8, metadata={"help": "LoHa rank"}) + alpha: int = field(default=8, metadata={"help": "LoHa alpha"}) + rank_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for rank dimension during training"} + ) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling LoHa modules during training"} + ) + use_effective_conv2d: bool = field( + default=False, + metadata={ + "help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)' + }, + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with LoHa." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the LoHa layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from LoHA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." 
+ }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LOHA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) diff --git a/peft/tuners/loha/layer.py b/peft/tuners/loha/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b958decfad80ebb6eb5cb53f51365c2f7b5b1beb --- /dev/null +++ b/peft/tuners/loha/layer.py @@ -0,0 +1,375 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.lycoris_utils import LycorisLayer + + +class LoHaLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2") + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module): + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # LoHa info + self.hada_w1_a = nn.ParameterDict({}) + self.hada_w1_b = nn.ParameterDict({}) + self.hada_w2_a = nn.ParameterDict({}) + self.hada_w2_b = nn.ParameterDict({}) + self.hada_t1 = nn.ParameterDict({}) + self.hada_t2 = nn.ParameterDict({}) + + @property + def _available_adapters(self) -> Set[str]: + return {*self.hada_w1_a, *self.hada_w1_b, *self.hada_w2_a, *self.hada_w2_b, *self.hada_t1, *self.hada_t2} + + def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...]): + # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L130C9-L143C75 + if len(shape) == 4: + self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3])) + self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode + self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode + + self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3])) + self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode + self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode + else: + self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r)) + self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) + + self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r)) + self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) + + def reset_adapter_parameters(self, adapter_name: str): + # Original implementation performs initialization with normal distribution + # https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158 + + # FedPara paper proposes to perform He initialization, let's stick with it + # It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization + if adapter_name in 
self.hada_w1_a.keys(): + nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.zeros_(self.hada_w2_b[adapter_name]) + if adapter_name in self.hada_t1.keys(): + nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5)) + + def reset_adapter_parameters_random(self, adapter_name: str): + # Original implementation performs initialization with normal distribution + # https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158 + + # FedPara paper proposes to perform He initialization, let's stick with it + # It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization + if adapter_name in self.hada_w1_a.keys(): + nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_w2_b[adapter_name], a=math.sqrt(5)) + if adapter_name in self.hada_t1.keys(): + nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5)) + + def update_layer( + self, + adapter_name: str, + r: int, + alpha: float, + rank_dropout: float, + module_dropout: float, + init_weights: bool, + use_effective_conv2d: bool = False, + **kwargs, + ) -> None: + """Internal function to create loha adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + alpha (`float`): Alpha for the added adapter. + rank_dropout (`float`): The dropout probability for rank dimension during training. + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize weights. + use_effective_conv2d (`bool`, *optional*, defaults to `False`): + Use parameter effective decomposition for Conv2d with ksize > 1. 
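+
+        The delta weight created for this adapter is scaled by `alpha / r`, which is stored in
+        `self.scaling[adapter_name]`.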
+ """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.alpha[adapter_name] = alpha + self.scaling[adapter_name] = alpha / r + self.rank_dropout[adapter_name] = rank_dropout + self.module_dropout[adapter_name] = module_dropout + + # Determine shape of LoHa weights + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + shape = tuple(base_layer.weight.shape) + elif isinstance(base_layer, nn.Conv2d): + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) + if use_effective_conv2d: + shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size) + else: + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ) + else: + raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}") + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape) + + # Initialize weights + if init_weights: + self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) + + # Move new weights to device + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L178 + if adapter_name in self.hada_t1.keys(): + weight = make_weight_cp( + self.hada_t1[adapter_name], + self.hada_w1_a[adapter_name], + self.hada_w1_b[adapter_name], + self.hada_t2[adapter_name], + self.hada_w2_a[adapter_name], + self.hada_w2_b[adapter_name], + scale=torch.tensor(self.scaling[adapter_name]), + ) + else: + weight = make_weight( + self.hada_w1_a[adapter_name], + self.hada_w1_b[adapter_name], + self.hada_w2_a[adapter_name], + self.hada_w2_b[adapter_name], + scale=torch.tensor(self.scaling[adapter_name]), + ) + + base_layer = self.get_base_layer() + weight = weight.reshape(base_layer.weight.shape) + + # Perform rank dropout during training - drop rows of addition weights + rank_dropout = self.rank_dropout[adapter_name] + if self.training and rank_dropout: + drop = (torch.rand(weight.size(0)) > rank_dropout).to(weight.dtype) + drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device) + # TODO: Investigate if there should be a scaler like in normal dropout during training + # Original implementation doesn't have it + # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L193 + drop /= drop.mean() + weight *= drop + + return weight + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if 
(not self.training) or (self.training and torch.rand(1) > module_dropout): + result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs) + + result = result.to(previous_dtype) + return result + + +class Linear(LoHaLayer): + """LoHa implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + return F.linear(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "loha." + rep + + +class Conv2d(LoHaLayer): + """LoHa implemented in Conv2d layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + use_effective_conv2d: bool = False, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer( + adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs + ) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() + return F.conv2d( + input, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + + def __repr__(self) -> str: + rep = super().__repr__() + return "loha." 
+ rep + + +# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9 + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)): + ctx.save_for_backward(w1a, w1b, w2a, w2b, scale) + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2a @ w2b) + grad_w1a = temp @ w1b.T + grad_w1b = w1a.T @ temp + + temp = grad_out * (w1a @ w1b) + grad_w2a = temp @ w2b.T + grad_w2b = w2a.T @ temp + + del temp + return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None + + +class HadaWeightCP(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1a, w1b, t2, w2a, w2b, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1a, w1b, t2, w2a, w2b, scale) + + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1a, w1b, t2, w2a, w2b, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j k l, j r -> i r k l", t2, w2b) + rebuild = torch.einsum("i j k l, i r -> r j k l", temp, w2a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1a = torch.einsum("r j k l, i j k l -> r i", temp, grad_w) + grad_temp = torch.einsum("i j k l, i r -> r j k l", grad_w, w1a.T) + del grad_w, temp + + grad_w1b = torch.einsum("i r k l, i j k l -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j k l, j r -> i r k l", grad_temp, w1b.T) + del grad_temp + + temp = torch.einsum("i j k l, j r -> i r k l", t1, w1b) + rebuild = torch.einsum("i j k l, i r -> r j k l", temp, w1a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2a = torch.einsum("r j k l, i j k l -> r i", temp, grad_w) + grad_temp = torch.einsum("i j k l, i r -> r j k l", grad_w, w2a.T) + del grad_w, temp + + grad_w2b = torch.einsum("i r k l, i j k l -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j k l, j r -> i r k l", grad_temp, w2b.T) + del grad_temp + return grad_t1, grad_w1a, grad_w1b, grad_t2, grad_w2a, grad_w2b, None + + +def make_weight(w1a, w1b, w2a, w2b, scale): + return HadaWeight.apply(w1a, w1b, w2a, w2b, scale) + + +def make_weight_cp(t1, w1a, w1b, t2, w2a, w2b, scale): + return HadaWeightCP.apply(t1, w1a, w1b, t2, w2a, w2b, scale) diff --git a/peft/tuners/loha/model.py b/peft/tuners/loha/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1aaac9d5918959edb04e826c6a81edf02a625e --- /dev/null +++ b/peft/tuners/loha/model.py @@ -0,0 +1,114 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
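+
+# Per-layer overrides: `_create_and_replace` below matches each module key against the regex
+# keys of `config.rank_pattern` / `config.alpha_pattern`, so individual layers can use a rank
+# or alpha that differs from the global `r` / `alpha`.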
+ +import re +from itertools import chain +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Conv2d, Linear, LoHaLayer + + +class LoHaModel(LycorisTuner): + """ + Creates Low-Rank Hadamard Product model from a pretrained model. The method is partially described in + https://arxiv.org/abs/2108.06098 Current implementation heavily borrows from + https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`LoHaConfig`]): The configuration of the LoHa model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The LoHa model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import LoHaModel, LoHaConfig + + >>> config_te = LoHaConfig( + ... r=8, + ... lora_alpha=32, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... rank_dropout=0.0, + ... module_dropout=0.0, + ... init_weights=True, + ... ) + >>> config_unet = LoHaConfig( + ... r=8, + ... lora_alpha=32, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... rank_dropout=0.0, + ... module_dropout=0.0, + ... init_weights=True, + ... use_effective_conv2d=True, + ... ) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = LoHaModel(model.text_encoder, config_te, "default") + >>> model.unet = LoHaModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`LoHaConfig`]): The configuration of the LoHa model. + """ + + prefix: str = "hada_" + layers_mapping: Dict[Type[torch.nn.Module], Type[LoHaLayer]] = { + torch.nn.Conv2d: Conv2d, + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[LoHaLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha) + + if isinstance(target, LoHaLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/peft/tuners/lokr/__init__.py b/peft/tuners/lokr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..408cf2a54ae4c0befa9e3f1cad4ff93d71cfedc5 --- /dev/null +++ b/peft/tuners/lokr/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import LoKrConfig +from .layer import Conv2d, Linear, LoKrLayer +from .model import LoKrModel + + +__all__ = ["LoKrConfig", "LoKrModel", "Conv2d", "Linear", "LoKrLayer"] diff --git a/peft/tuners/lokr/__pycache__/__init__.cpython-310.pyc b/peft/tuners/lokr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b10a867a213b29ff420823e1b1c62d2ab2bc429 Binary files /dev/null and b/peft/tuners/lokr/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/lokr/__pycache__/config.cpython-310.pyc b/peft/tuners/lokr/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c9391e426b7d6130279aa192a4472074f06b02e Binary files /dev/null and b/peft/tuners/lokr/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/lokr/__pycache__/layer.cpython-310.pyc b/peft/tuners/lokr/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8243263de1845d3e276214e42027fbf341222ee Binary files /dev/null and b/peft/tuners/lokr/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/lokr/__pycache__/model.cpython-310.pyc b/peft/tuners/lokr/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..589c95fb409d59443f4eaf0e65824824527c8b6e Binary files /dev/null and b/peft/tuners/lokr/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/lokr/config.py b/peft/tuners/lokr/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d60a7463c59e114e42658965ac7c81f3fb563e --- /dev/null +++ b/peft/tuners/lokr/config.py @@ -0,0 +1,127 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lycoris_utils import LycorisConfig +from peft.utils import PeftType + + +@dataclass +class LoKrConfig(LycorisConfig): + """ + Configuration class of [`LoKrModel`]. + + Args: + r (`int`): + LoKr rank. + alpha (`int`): + The alpha parameter for LoKr scaling. + rank_dropout (`float`): + The dropout probability for rank dimension during training. + module_dropout (`float`): + The dropout probability for disabling LoKr modules during training. + use_effective_conv2d (`bool`): + Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper). + decompose_both (`bool`): + Perform rank decomposition of left kronecker product matrix. 
+ decompose_factor (`int`): + Kronecker product decomposition factor. + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + init_weights (`bool`): + Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is + discouraged. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `alpha`. + modules_to_save (`Optional[List[str]]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + """ + + r: int = field(default=8, metadata={"help": "LoKr rank"}) + alpha: int = field(default=8, metadata={"help": "LoKr alpha"}) + rank_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for rank dimension during training"} + ) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling LoKr modules during training"} + ) + use_effective_conv2d: bool = field( + default=False, + metadata={ + "help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)' + }, + ) + decompose_both: bool = field( + default=False, + metadata={"help": "Perform rank decomposition of left kronecker product matrix."}, + ) + decompose_factor: int = field(default=-1, metadata={"help": "Kronecker product decomposition factor."}) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with LoKr." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the LoKr layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. 
If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from LoKr layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LOKR diff --git a/peft/tuners/lokr/layer.py b/peft/tuners/lokr/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..28e4e5ca61bc4b6826d740fb6e1b77b583f891a7 --- /dev/null +++ b/peft/tuners/lokr/layer.py @@ -0,0 +1,409 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.lycoris_utils import LycorisLayer + + +class LoKrLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ( + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t2", + ) + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module) -> None: + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # LoKr info + self.lokr_w1 = nn.ParameterDict({}) + self.lokr_w1_a = nn.ParameterDict({}) + self.lokr_w1_b = nn.ParameterDict({}) + self.lokr_w2 = nn.ParameterDict({}) + self.lokr_w2_a = nn.ParameterDict({}) + self.lokr_w2_b = nn.ParameterDict({}) + self.lokr_t2 = nn.ParameterDict({}) + + @property + def _available_adapters(self) -> Set[str]: + return { + *self.lokr_w1, + *self.lokr_w1_a, + *self.lokr_w1_b, + *self.lokr_w2, + *self.lokr_w2_a, + *self.lokr_w2_b, + *self.lokr_t2, + } + + def create_adapter_parameters( + self, + adapter_name: str, + r: int, + shape, + use_w1: bool, + use_w2: bool, + use_effective_conv2d: bool, + ): + if use_w1: + self.lokr_w1[adapter_name] = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) + else: + self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r)) + self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0])) + + if len(shape) == 4: + # Conv2d + if use_w2: + self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:])) + elif use_effective_conv2d: + self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3])) + self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1])) # b, 1-mode + self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1])) # d, 2-mode + else: + 
self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r)) + self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3])) + else: + # Linear + if use_w2: + self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1])) + else: + self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r)) + self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1])) + + def reset_adapter_parameters(self, adapter_name: str): + if adapter_name in self.lokr_w1: + nn.init.zeros_(self.lokr_w1[adapter_name]) + else: + nn.init.zeros_(self.lokr_w1_a[adapter_name]) + nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_w2: + nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5)) + else: + nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_t2: + nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5)) + + def reset_adapter_parameters_random(self, adapter_name: str): + if adapter_name in self.lokr_w1: + nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5)) + else: + nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_w2: + nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5)) + else: + nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5)) + + if adapter_name in self.lokr_t2: + nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5)) + + def update_layer( + self, + adapter_name: str, + r: int, + alpha: float, + rank_dropout: float, + module_dropout: float, + init_weights: bool, + use_effective_conv2d: bool, + decompose_both: bool, + decompose_factor: int, + **kwargs, + ) -> None: + """Internal function to create lokr adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + alpha (`float`): Alpha for the added adapter. + rank_dropout (`float`): The dropout probability for rank dimension during training + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize adapter weights. + use_effective_conv2d (`bool`): Use parameter effective decomposition for Conv2d with ksize > 1. + decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix. + decompose_factor (`int`): Kronecker product decomposition factor. 
+ """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.alpha[adapter_name] = alpha + self.scaling[adapter_name] = alpha / r + self.rank_dropout[adapter_name] = rank_dropout + self.module_dropout[adapter_name] = module_dropout + base_layer = self.get_base_layer() + + # Determine shape of LoKr weights + if isinstance(base_layer, nn.Linear): + in_dim, out_dim = base_layer.in_features, base_layer.out_features + + in_m, in_n = factorization(in_dim, decompose_factor) + out_l, out_k = factorization(out_dim, decompose_factor) + shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + + use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) + use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2) + use_effective_conv2d = False + elif isinstance(base_layer, nn.Conv2d): + in_dim, out_dim = base_layer.in_channels, base_layer.out_channels + k_size = base_layer.kernel_size + + in_m, in_n = factorization(in_dim, decompose_factor) + out_l, out_k = factorization(out_dim, decompose_factor) + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) + + use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) + use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) + else: + raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}") + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d) + + # Initialize weights + if init_weights: + self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) + + # Move new weights to device + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + # https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224 + if adapter_name in self.lokr_w1: + w1 = self.lokr_w1[adapter_name] + else: + w1 = self.lokr_w1_a[adapter_name] @ self.lokr_w1_b[adapter_name] + + if adapter_name in self.lokr_w2: + w2 = self.lokr_w2[adapter_name] + elif adapter_name in self.lokr_t2: + w2 = make_weight_cp(self.lokr_t2[adapter_name], self.lokr_w2_a[adapter_name], self.lokr_w2_b[adapter_name]) + else: + w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name] + + # Make weights with Kronecker product + weight = make_kron(w1, w2) + weight = weight.reshape(self.get_base_layer().weight.shape) + + # Perform rank dropout during training - drop rows of addition weights + rank_dropout = self.rank_dropout[adapter_name] + if self.training and rank_dropout: + drop = (torch.rand(weight.size(0)) > rank_dropout).float() + drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device) + drop /= drop.mean() + weight *= drop + + return weight + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, 
**kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs) + + result = result.to(previous_dtype) + return result + + +class Linear(LoKrLayer): + """LoKr implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + return F.linear(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "lokr." + rep + + +class Conv2d(LoKrLayer): + """LoKr implemented in Conv2d layer""" + + def __init__( + self, + base_layer: nn.Module, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + use_effective_conv2d: bool = False, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer( + adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs + ) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() + return F.conv2d( + input, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + + def __repr__(self) -> str: + rep = super().__repr__() + return "lokr." + rep + + +# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11 + + +def factorization(dimension: int, factor: int = -1) -> Tuple[int, int]: + """Factorizes the provided number into the product of two numbers + + Args: + dimension (`int`): The number that needs to be factorized. + factor (`int`, optional): + Factorization divider. The algorithm will try to output two numbers, one of each will be as close to the + factor as possible. If -1 is provided, the decomposition algorithm would try to search dividers near the + square root of the dimension. Defaults to -1. + + Returns: + Tuple[`int`, `int`]: A tuple of two numbers, whose product is equal to the provided number. 
The first number is + always less than or equal to the second. + + Example: + ```py + >>> factorization(256, factor=-1) + (16, 16) + + >>> factorization(128, factor=-1) + (8, 16) + + >>> factorization(127, factor=-1) + (1, 127) + + >>> factorization(128, factor=4) + (4, 32) + ``` + """ + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + return m, n + if factor == -1: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + + +def make_weight_cp(t, wa, wb): + rebuild2 = torch.einsum("i j k l, i p, j r -> p r k l", t, wa, wb) # [c, d, k1, k2] + return rebuild2 + + +def make_kron(w1, w2, scale=1.0): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + rebuild = torch.kron(w1, w2) + + return rebuild * scale diff --git a/peft/tuners/lokr/model.py b/peft/tuners/lokr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..eecad8dd13d8f637ea8e6da2377473f314fa6aac --- /dev/null +++ b/peft/tuners/lokr/model.py @@ -0,0 +1,115 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from itertools import chain +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Conv2d, Linear, LoKrLayer + + +class LoKrModel(LycorisTuner): + """ + Creates Low-Rank Kronecker Product model from a pretrained model. The original method is partially described in + https://arxiv.org/abs/2108.06098 and in https://arxiv.org/abs/2309.14859 Current implementation heavily borrows + from + https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`LoKrConfig`]): The configuration of the LoKr model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The LoKr model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import LoKrModel, LoKrConfig + + >>> config_te = LoKrConfig( + ... r=8, + ... lora_alpha=32, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... rank_dropout=0.0, + ... module_dropout=0.0, + ... init_weights=True, + ... ) + >>> config_unet = LoKrConfig( + ... r=8, + ... lora_alpha=32, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... rank_dropout=0.0, + ... module_dropout=0.0, + ... init_weights=True, + ... use_effective_conv2d=True, + ... 
) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = LoKrModel(model.text_encoder, config_te, "default") + >>> model.unet = LoKrModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`LoKrConfig`]): The configuration of the LoKr model. + """ + + prefix: str = "lokr_" + layers_mapping: Dict[Type[torch.nn.Module], Type[LoKrLayer]] = { + torch.nn.Conv2d: Conv2d, + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[LoKrLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha) + + if isinstance(target, LoKrLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/peft/tuners/lora/__init__.py b/peft/tuners/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0bce2a5fae4aa90ea71210105978c7512aa47f --- /dev/null +++ b/peft/tuners/lora/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
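To tie together the LoKr pieces above (`factorization`, `make_kron`, `LoKrLayer`, `LoKrModel`): the LoKr delta weight is a Kronecker product of two small factors whose shapes come from `factorization`. A minimal sketch with hypothetical dimensions, assuming `factorization` and `make_kron` from `peft/tuners/lokr/layer.py` above are in scope:

```python
import torch

# Hypothetical Linear layer: in_features=768, out_features=512, decompose_factor=-1.
in_dim, out_dim = 768, 512
in_m, in_n = factorization(in_dim, factor=-1)     # (24, 32): divisor pair nearest sqrt(768)
out_l, out_k = factorization(out_dim, factor=-1)  # (16, 32)

w1 = torch.randn(out_l, in_m)   # small "left" Kronecker factor (lokr_w1, or w1_a @ w1_b)
w2 = torch.randn(out_k, in_n)   # "right" factor (lokr_w2, or a low-rank / CP product)
delta = make_kron(w1, w2)       # torch.kron -> shape (out_l * out_k, in_m * in_n)

assert delta.shape == (out_dim, in_dim)  # matches the base layer's weight shape
```

In the actual layer the factors are per-adapter `nn.Parameter`s; `get_delta_weight` reshapes this product to the base weight's shape (applying rank dropout during training) and `_get_delta_activations` applies it to the input alongside the frozen base layer.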
+ +from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available + +from .config import LoftQConfig, LoraConfig +from .gptq import QuantLinear +from .layer import Conv2d, Embedding, Linear, LoraLayer +from .model import LoraModel + + +__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] + + +def __getattr__(name): + if (name == "Linear8bitLt") and is_bnb_available(): + from .bnb import Linear8bitLt + + return Linear8bitLt + + if (name == "Linear4bit") and is_bnb_4bit_available(): + from .bnb import Linear4bit + + return Linear4bit + + if (name == "EetqLoraLinear") and is_eetq_available(): + from .eetq import EetqLoraLinear + + return EetqLoraLinear + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/peft/tuners/lora/__pycache__/__init__.cpython-310.pyc b/peft/tuners/lora/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43066fb91b9eadf1383ba47951994f20e2d39ebe Binary files /dev/null and b/peft/tuners/lora/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/aqlm.cpython-310.pyc b/peft/tuners/lora/__pycache__/aqlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b59d028a4488cbcae29e3a77b7d661436db436b Binary files /dev/null and b/peft/tuners/lora/__pycache__/aqlm.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/awq.cpython-310.pyc b/peft/tuners/lora/__pycache__/awq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..053a6d09147a23d30304bb5fef5197befccd6bcc Binary files /dev/null and b/peft/tuners/lora/__pycache__/awq.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/bnb.cpython-310.pyc b/peft/tuners/lora/__pycache__/bnb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9fd73ad08603f81003f37e15d5ceead1bf4856c Binary files /dev/null and b/peft/tuners/lora/__pycache__/bnb.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/config.cpython-310.pyc b/peft/tuners/lora/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57c08959f8d6993f9d72d225d1cdd628e1a3db71 Binary files /dev/null and b/peft/tuners/lora/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/eetq.cpython-310.pyc b/peft/tuners/lora/__pycache__/eetq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a06827872f2a6ed7c5a77caedf48bf8a33a8d42 Binary files /dev/null and b/peft/tuners/lora/__pycache__/eetq.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/gptq.cpython-310.pyc b/peft/tuners/lora/__pycache__/gptq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e5810cc1116a71af71f58eb379add607fa68605 Binary files /dev/null and b/peft/tuners/lora/__pycache__/gptq.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/hqq.cpython-310.pyc b/peft/tuners/lora/__pycache__/hqq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8177da62bb7d826672a032fbb1a11729e5cfb976 Binary files /dev/null and b/peft/tuners/lora/__pycache__/hqq.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/layer.cpython-310.pyc b/peft/tuners/lora/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fceffa4daff88aa6fd925ae12722183e8e47388b Binary files 
/dev/null and b/peft/tuners/lora/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/model.cpython-310.pyc b/peft/tuners/lora/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c1956d00d5eab0679f5bf25ddbb69720abc4297 Binary files /dev/null and b/peft/tuners/lora/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/lora/__pycache__/tp_layer.cpython-310.pyc b/peft/tuners/lora/__pycache__/tp_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..881b5359c888dc05903ae134f1c7213b7b23d374 Binary files /dev/null and b/peft/tuners/lora/__pycache__/tp_layer.cpython-310.pyc differ diff --git a/peft/tuners/lora/aqlm.py b/peft/tuners/lora/aqlm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8e90e62bb439467ea0300954e1f673e0e431ae --- /dev/null +++ b/peft/tuners/lora/aqlm.py @@ -0,0 +1,100 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + +from peft.import_utils import is_aqlm_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + + +if is_aqlm_available(): + from aqlm import QuantizedLinear + + +class AqlmLoraLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor): + # note: logic differs from default Linear because merging is not supported + result = self.base_layer(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102 + # def reset_lora_parameters(self, adapter_name): + # if adapter_name in self.lora_A.keys(): + # torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight) + # torch.nn.init.zeros_(self.lora_B[adapter_name].weight) + + +def dispatch_aqlm( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear): + new_module = AqlmLoraLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.codes + + return new_module diff --git a/peft/tuners/lora/awq.py b/peft/tuners/lora/awq.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f5bf3978f18ea58b52e5df1bcaf25f1e44fd40 --- /dev/null +++ b/peft/tuners/lora/awq.py @@ -0,0 +1,108 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib.metadata as importlib_metadata +from typing import Any, Optional + +import packaging.version +import torch + +from peft.import_utils import is_auto_awq_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + + +if is_auto_awq_available(): + from awq.modules.linear import WQLinear_GEMM + + +class AwqLoraLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor): + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +def dispatch_awq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM): + # Raise the error only at the dispatch level + AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") + version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + + if AUTOAWQ_MINIMUM_VERSION > version_autoawq: + raise ImportError( + f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " + f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." + ) + + new_module = AwqLoraLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight + + return new_module diff --git a/peft/tuners/lora/bnb.py b/peft/tuners/lora/bnb.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f8cf3f6dd88b99fd308b73a998b17f76eac513 --- /dev/null +++ b/peft/tuners/lora/bnb.py @@ -0,0 +1,508 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Optional + +import bitsandbytes as bnb +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight +from peft.utils.other import transpose + +from .layer import LoraLayer + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = False + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. 
+ """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter not in self.lora_A.keys(): + continue + + warnings.warn( + "Merge lora module to 8-bit linear may get different generations due to rounding errors." + ) + lora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + + # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8 + # dequantization directly + output = dequantize_bnb_weight(weight, state=state) + if not self.use_dora[active_adapter]: + w_data = output.to(lora_data.dtype).to(lora_data.device) + lora_data + else: + # handle dora + # since output already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = dora_factor.view(-1, 1) * (output + lora_data) + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + warnings.warn( + "Unmerge lora module to 8-bit linear may get different generations due to rounding errors." + ) + lora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + output = dequantize_bnb_weight(weight, state=state) + + if not self.use_dora[active_adapter]: + w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = output.data / dora_factor.view(-1, 1) - lora_data + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + + def get_delta_weight(self, adapter): + return ( + transpose( + self.lora_B[adapter].weight @ self.lora_A[adapter].weight, + False, + ) + * self.scaling[adapter] + ) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + output = lora_B(lora_A(dropout(sub_batch))) * scaling + if requires_conversion: + output = output.to(expected_dtype) + result[sub_batch_indices_list[i]] += output + + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + if not self.use_dora[active_adapter]: + output = lora_B(lora_A(dropout(x))) * scaling + else: + output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + if requires_conversion: + output = output.to(expected_dtype) + + result = result + output + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_8bit = kwargs.get("loaded_in_8bit", False) + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target.state.has_fp16_weights, + "memory_efficient_backward": target.state.memory_efficient_backward, + "threshold": target.state.threshold, + "index": target.index, + } + ) + new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs) + + return new_module + + +if is_bnb_4bit_available(): + + class Linear4bit(torch.nn.Module, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = False + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter not in self.lora_A.keys(): + continue + + warnings.warn( + "Merge lora module to 4-bit linear may get different generations due to rounding errors." + ) + # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + lora_data = self.get_delta_weight(active_adapter) + + output = dequantize_bnb_weight(weight, state=weight.quant_state) + if not self.use_dora[active_adapter]: + w_data = output + lora_data + else: + # handle dora + # since output already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = dora_factor.view(-1, 1) * (output + lora_data) + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + warnings.warn( + "Unmerge lora module to 4-bit linear may get different generations due to rounding errors." + ) + + lora_data = self.get_delta_weight(active_adapter) + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + output = dequantize_bnb_weight(weight, state=weight.quant_state) + + if not self.use_dora[active_adapter]: + w_data = output - lora_data + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = output.data / dora_factor.view(-1, 1) - lora_data + + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + + def get_delta_weight(self, adapter): + return ( + transpose( + self.lora_B[adapter].weight @ self.lora_A[adapter].weight, + False, + ) + * self.scaling[adapter] + ) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + output = lora_B(lora_A(dropout(sub_batch))) * scaling + if requires_conversion: + output = output.to(expected_dtype) + result[sub_batch_indices_list[i]] += output + + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + # As per Tim Dettmers, for 4bit, we need to defensively clone here. + # The reason is that in some cases, an error can occur that backprop + # does not work on a manipulated view. This issue may be solved with + # newer PyTorch versions but this would need extensive testing to be + # sure. + result = result.clone() + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + output = lora_B(lora_A(dropout(x))) * scaling + else: + output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + if requires_conversion: + output = output.to(expected_dtype) + + result = result + output + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_4bit = kwargs.get("loaded_in_4bit", False) + if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + new_module = Linear4bit(target, adapter_name, **fourbit_kwargs) + + return new_module diff --git a/peft/tuners/lora/config.py b/peft/tuners/lora/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3317c5b753100a349e0d5096335246042189d069 --- /dev/null +++ b/peft/tuners/lora/config.py @@ -0,0 +1,310 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class LoftQConfig: + """ + This is the sub-configuration class to store the configuration of a [`LoraModel`]. + + Args: + bits_pattern (`dict`): The mapping from layer names or regexp expression to bits which are different from the + default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}. + bits (`int`): Quantization bits for LoftQ. + iter (`int`): Alternating iterations for LoftQ. + fake (`bool`): True: use fp16/fp32; used for first time to save weights. False: use bitsandbytes 4bit linear + models. weights can't be saved. Recommend to set to True, save the weights and load the saved weights in 4 + bits. + """ + + loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) + loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + + +@dataclass +class LoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`LoraModel`]. + + Args: + r (`int`): + Lora attention dimension (the "rank"). + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. 
If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + lora_alpha (`int`): + The alpha parameter for Lora scaling. + lora_dropout (`float`): + The dropout probability for Lora layers. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + use_rslora (`bool`): + When set to True, uses Rank-Stabilized LoRA which + sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it was proven to work better. + Otherwise, it will use the original default value of `lora_alpha/r`. + modules_to_save (`List[str]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + init_lora_weights (`bool` | `Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"]`): + How to initialize the weights of the adapter layers. Passing True (default) results in the default + initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian + initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to + completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing + 'pissa' results in the initialization of PiSSA, which converge more rapidly than LoRA and ultimately + achieve superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to + further enhancements. Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA + initialization, where [number of iters] indicates the number of subspace iterations to perform FSVD, and + must be a nonnegative integer. When the [number of iters] is set to 16, it can complete the initialization + of a 7b model within seconds, and the training effect is approximately equivalent to using SVD. For more + information, see Principal Singular values and Singular vectors + Adaptation. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `lora_alpha`. + megatron_config (`Optional[dict]`): + The TransformerConfig arguments for Megatron. It is used to create LoRA's parallel linear layer. You can + get it like this, `core_transformer_config_from_args(get_args())`, these two functions being from Megatron. + The arguments will be used to initialize the TransformerConfig of Megatron. 
You need to specify this + parameter when you want to apply LoRA to the ColumnParallelLinear and RowParallelLinear layers of megatron. + megatron_core (`Optional[str]`): + The core module from Megatron to use, defaults to `"megatron.core"`. + loftq_config (`Optional[LoftQConfig]`): + The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights + and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a + quantized model in this case, as LoftQ will quantize the model itself. + use_dora (`bool`): + Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights + into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is + handled by a separate learnable parameter. This can improve the performance of LoRA especially at low + ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure + LoRA, so it is recommended to merge weights for inference. For more information, see + https://arxiv.org/abs/2402.09353. + layer_replication (`List[Tuple[int, int]]`): + Build a new stack of layers by stacking the original model layers according to the ranges specified. This + allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will + all have separate LoRA adapters attached to them. + """ + + r: int = field(default=8, metadata={"help": "Lora attention dimension"}) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with LoRA." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'." + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you should specify the target modules manually." + ), + }, + ) + lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) + lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: Literal["none", "all", "lora_only"] = field( + default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"} + ) + use_rslora: bool = field( + default=False, + metadata={ + "help": ( + "When set to True, uses Rank-Stabilized LoRA doi.org/10.48550/arXiv.2312.03732" + " which sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it" + " was proven to work better. Otherwise, it will use the original default" + " value of `lora_alpha/r`." + ) + }, + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_lora_weights: bool | Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"] = field( + default=True, + metadata={ + "help": ( + "How to initialize the weights of the LoRA layers. 
Passing True (default) results in the default " + "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " + "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " + "to False leads to completely random initialization and is discouraged." + "Passing 'pissa' results in PiSSA initialization." + "Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA initialization, " + "where [number of iters] indicates the number of subspace iterations to perform fsvd, and must be a nonnegative integer." + "Pass `'loftq'` to use LoftQ initialization" + ), + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index. " + "This only works when target_modules is a list of str." + }, + ) + layers_pattern: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + "This only works when target_modules is a list of str." + }, + ) + rank_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" + ) + }, + ) + alpha_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" + ) + }, + ) + megatron_config: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The TransformerConfig from Megatron. It is used to create LoRA's parallel linear layer." + "You can get it like this, `core_transformer_config_from_args(get_args())`, " + "these two functions being from Megatron." + "You need to specify this parameter when you want to apply LoRA to the ColumnParallelLinear and " + "RowParallelLinear layers of megatron." + "It should be noted that we may not be able to use the `save_pretrained` and `from_pretrained` " + "functions, because TransformerConfig may not necessarily be serialized." + "But when using megatron, we can use `get_peft_model_state_dict` function and " + "megatron's framework, they can also save and load models and configurations." + ) + }, + ) + megatron_core: Optional[str] = field( + default="megatron.core", + metadata={ + "help": ( + "The core module from Megatron, it is used to create LoRA's parallel linear layer. " + "It only needs to be passed in when you need to use your own modified megatron core module. " + "Otherwise, it will use the default value `megatron.core`. " + ) + }, + ) + # dict type is used when loading config.json + loftq_config: Union[LoftQConfig, dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The configuration of LoftQ. If this is passed, then LoftQ will be used to quantize the backbone " + "weights and initialize Lora layers. Also set `init_lora_weights='loftq'` in this case." 
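+                # Illustrative pairing (a sketch, not upstream documentation): something like
+                #   LoraConfig(init_lora_weights="loftq", loftq_config=LoftQConfig(loftq_bits=4, loftq_iter=1), ...)
+                # `__post_init__` below converts a LoftQConfig instance to a plain dict via `vars()`.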
+ ) + }, + ) + use_dora: bool = field( + default=False, + metadata={ + "help": ( + "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the " + "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " + "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " + "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger" + "overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, " + "see https://arxiv.org/abs/2402.09353." + ) + }, + ) + # Enables replicating layers in a model to expand it to a larger model. + layer_replication: Optional[list[tuple[int, int]]] = field( + default=None, + metadata={ + "help": ( + "This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. " + "The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with " + "a module list in the model which it modifies to expand the number of modules. " + "Base weights are shared so the memory usage is close to the original model. The intended use is these base weights " + "remain fixed during finetuning but each layer has a separate LoRA adapter so the layers can be specialed via " + "the adapter layers fit during fine tuning." + "The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n" + " Original model has 5 layers labelled by their position in the model: `[0, 1, 2, 3, 4]`\n" + " layer_replication: `[[0, 4], [2, 5]]`\n" + " Final model will have this arrangement of original layers: `[0, 1, 2, 3, 2, 3, 4]`\n" + "This format is based on what is used for pass-through merges in mergekit. It makes it simple to select sequential " + "ranges of a model and stack them while reusing layers at either end of each sequence." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + + if self.use_dora and self.megatron_config: + raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.") + + # handle init_lora_weights and loftq_config + if self.init_lora_weights == "loftq": + import importlib + + if not importlib.util.find_spec("scipy"): + raise ImportError("The required package 'scipy' is not installed. 
Please install it to continue.") + if self.loftq_config is None: + raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + + # convert loftq_config to dict + if self.loftq_config and not isinstance(self.loftq_config, dict): + self.loftq_config = vars(self.loftq_config) diff --git a/peft/tuners/lora/eetq.py b/peft/tuners/lora/eetq.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf42c68141e15b3543cb2587be7fd5a527be685 --- /dev/null +++ b/peft/tuners/lora/eetq.py @@ -0,0 +1,104 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional + +import torch + +from peft.import_utils import is_eetq_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + + +if is_eetq_available(): + from eetq import EetqLinear + + class EetqLoraLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor): + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + return result + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + raise AttributeError("Merging LoRA layers is not supported for Eetq layers.") + + def unmerge(self) -> None: + raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.") + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +def dispatch_eetq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_eetq_available() and isinstance(target_base_layer, EetqLinear): + new_module = EetqLoraLinear(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + + return new_module diff --git a/peft/tuners/lora/gptq.py b/peft/tuners/lora/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..333dfa6feb7595e185ae81f540aaa18fc1f2233a --- /dev/null +++ b/peft/tuners/lora/gptq.py @@ -0,0 +1,114 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import get_auto_gptq_quant_linear + + +class QuantLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def forward(self, x: torch.Tensor): + # note: logic differs from default Linear because merging is not supported + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102 + # def reset_lora_parameters(self, adapter_name): + # if adapter_name in self.lora_A.keys(): + # torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight) + # torch.nn.init.zeros_(self.lora_B[adapter_name].weight) + + +def dispatch_gptq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + + if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear): + new_module = QuantLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight + + return new_module diff --git a/peft/tuners/lora/hqq.py b/peft/tuners/lora/hqq.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a51e7b27907ac99de34b8b7f1f0111d2ebe385 --- /dev/null +++ b/peft/tuners/lora/hqq.py @@ -0,0 +1,247 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +import warnings +from typing import Any, Optional + +import torch + +from peft.import_utils import is_hqq_available +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .layer import LoraLayer + + +if is_hqq_available(): + from hqq.core.quantize import HQQLinear + + class HqqLoraLinear(torch.nn.Module, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = False + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. 
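+
+            Example (an illustrative sketch, not upstream documentation; `lora_module` is a
+            hypothetical HqqLoraLinear instance):
+
+                lora_module.merge(safe_merge=True)  # merge active adapters, checking for NaNs first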
+ """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter not in self.lora_A.keys(): + continue + + layer = self.get_base_layer() + quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta} + lora_data = self.get_delta_weight(active_adapter) + + output = layer.dequantize() + if not self.use_dora[active_adapter]: + w_data = output + lora_data + else: + # handle dora + # since output already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = dora_factor.view(-1, 1) * (output + lora_data) + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device) + quant_config.pop("offload_meta", None) + new_hqq_layer.quantize(w_data, **quant_config) + self.base_layer = new_hqq_layer + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + + lora_data = self.get_delta_weight(active_adapter) + layer = self.get_base_layer() + quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta} + output = layer.dequantize() + + if not self.use_dora[active_adapter]: + w_data = output - lora_data + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + w_data = output.data / dora_factor.view(-1, 1) - lora_data + + new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device) + quant_config.pop("offload_meta", None) + new_hqq_layer.quantize(w_data, **quant_config) + self.base_layer = new_hqq_layer + + def get_delta_weight(self, adapter): + return ( + transpose( + self.lora_B[adapter].weight @ self.lora_A[adapter].weight, + False, + ) + * self.scaling[adapter] + ) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + output = lora_B(lora_A(dropout(sub_batch))) * scaling + if requires_conversion: + output = output.to(expected_dtype) + result[sub_batch_indices_list[i]] += output + + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + if not self.use_dora[active_adapter]: + output = lora_B(lora_A(dropout(x))) * scaling + else: + output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + if requires_conversion: + output = output.to(expected_dtype) + + result = result + output + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +def dispatch_hqq(target: torch.nn.Module, adapter_name: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_hqq_available() and isinstance(target_base_layer, HQQLinear): + new_module = HqqLoraLinear(target_base_layer, adapter_name, **kwargs) + + return new_module diff --git a/peft/tuners/lora/layer.py b/peft/tuners/lora/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..01272a49d4849d737a05ab463eb4f514e3a4c985 --- /dev/null +++ b/peft/tuners/lora/layer.py @@ -0,0 +1,1357 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import warnings +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight, gather_params_ctx +from peft.utils.other import transpose + +from .config import LoraConfig + +from einops import rearrange + +class LoraLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer + self.r = {} + self.lora_alpha = {} + self.scaling = {} + self.lora_dropout = nn.ModuleDict({}) + self.lora_A = nn.ModuleDict({}) + self.lora_B = nn.ModuleDict({}) + # For Embedding layer + self.lora_embedding_A = nn.ParameterDict({}) + self.lora_embedding_B = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.use_dora: dict[str, bool] = {} + self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA + self._caches: dict[str, Any] = {} + self.kwargs = kwargs + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": + # AQLM QuantLinear + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + # Awq layers + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + + def update_layer( + self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False + ): + # This code works for linear layers, 
override for other layer types + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) + self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + # check weight and qweight (for GPTQ) + for weight_name in ("weight", "qweight"): + weight = getattr(self.get_base_layer(), weight_name, None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + break + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name, init_lora_weights): + if init_lora_weights is False: + return + + if adapter_name in self.lora_A.keys(): + if init_lora_weights is True: + # initialize A the same way as the default for nn.Linear and B to zero + # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 + nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + elif init_lora_weights.lower() == "gaussian": + nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_lora_weights=}") + nn.init.zeros_(self.lora_B[adapter_name].weight) + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.linear and b to zero + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + def loftq_init(self, adapter_name): + from peft.utils.loftq_utils import loftq_init + + weight = self.get_base_layer().weight + kwargs = { + "num_bits": self.kwargs.get("loftq_bits", 4), + "reduced_rank": self.r[adapter_name], + "num_iter": self.kwargs.get("loftq_iter", 1), + } + + qweight, lora_A, lora_B = loftq_init(weight, **kwargs) + if adapter_name in self.lora_A.keys(): + # initialize A the same way as the default for nn.Linear and B to zero + self.lora_A[adapter_name].weight.data = lora_A + self.lora_B[adapter_name].weight.data = lora_B + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.linear and b to zero + self.lora_embedding_A[adapter_name].weight.data = lora_A + self.lora_embedding_B[adapter_name].weight.data = lora_B + self.get_base_layer().weight.data = qweight + + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaling * lora_weight + weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) + return weight_norm + + def dora_init(self, adapter_name: str) -> None: + lora_A = 
self.lora_A[adapter_name] + lora_B = self.lora_B[adapter_name] + scaling = self.scaling[adapter_name] + with gather_params_ctx(self.get_base_layer()): + weight = self.get_base_layer().weight + quant_state = getattr(self.get_base_layer(), "state", None) + weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb + if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds. + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + else: + lora_weight = lora_B.weight @ lora_A.weight + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + self.lora_magnitude_vector = nn.ParameterDict() + self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True) + # add lora_magnitude_vector to the list of learnable parameters + self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) + + def _cache_store(self, key: str, value: Any) -> None: + self._caches[key] = value + + def _cache_pop(self, key: str) -> Any: + value = self._caches.pop(key) + return value + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + lora_weight = lora_B.weight @ lora_A.weight + magnitude = self.lora_magnitude_vector[active_adapter] + weight = self.get_base_layer().weight + quant_state = getattr(self.get_base_layer(), "state", None) + weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb + weight = weight.to(x.dtype) + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + mag_norm_scale = (magnitude / weight_norm).view(1, -1) + result_dora = (mag_norm_scale - 1) * ( + F.linear(x, transpose(weight, self.fan_in_fan_out)) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again. 
+ # This is only correct if dropout=0, otherwise results will differ: + # https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771 + # bias = self.get_base_layer().bias + # if bias is not None: + # result = result - bias + # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling + # if bias is not None: + # result = result + bias + + return result_dora + + def set_scale(self, adapter, scale): + if adapter not in self.scaling: + # Ignore the case where the adapter is not in the layer + return + self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter] + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + self.scaling[active_adapter] *= scale + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + if scale is None: + self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter] + else: + self.scaling[active_adapter] /= scale + + def _check_forward_args(self, x, *args, **kwargs): + """Check if the arguments are compatible with the configs and state of the model""" + adapter_names = kwargs.get("adapter_names", None) + if adapter_names is None: + return + + if len(x) != len(adapter_names): + msg = ( + "Length of `adapter_names` should be the same as the number of inputs, but got " + f"{len(adapter_names)} and {len(x)} respectively." + ) + raise ValueError(msg) + + if self.merged: + # It is unclear what would be the right thing to do if users pass adapter_names and there are merged + # adapters. Therefore, it is better to raise an error in this case. + msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." + raise ValueError(msg) + + unique_adapters = set(self.active_adapters) + for adapter_name in unique_adapters: + if self.use_dora.get(adapter_name, False): + msg = "Cannot pass `adapter_names` when DoRA is enabled." + raise ValueError(msg) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
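+        # Descriptive sketch of the routing below: for adapter_names == ["a", "b", "a"], each unique
+        # adapter collects the row indices where it occurs ("a" -> [0, 2], "b" -> [1]), and each LoRA
+        # branch reads and updates only those rows of the result.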
+ result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype) + lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling + result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype) + + return result + + +# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# and modified to work with PyTorch FSDP + + +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + + +class Linear(nn.Module, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. 
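+                    # Merge sketch (descriptive comment): for plain LoRA the merged weight is simply
+                    #   W + delta_weight, with delta_weight = scaling * B @ A (transposed for fan_in_fan_out);
+                    # for DoRA it is rescaled per output channel,
+                    #   (magnitude / ||W + delta_weight||) * (W + delta_weight),
+                    # and the norm is cached so unmerge() can undo the scaling later.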
+ orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight + weight.data = weight_orig + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.lora_B[adapter].weight.device + dtype = self.lora_B[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. 
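+        # Descriptive note: the delta returned below is transpose(B @ A, fan_in_fan_out) * scaling,
+        # where scaling is lora_alpha / r, or lora_alpha / sqrt(r) when use_rslora=True (see update_layer).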
+ cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + result = result.to(torch_result_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +class Embedding(nn.Module, LoraLayer): + # LoRA implemented in a Embedding layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + weight_A = torch.randn((r, self.in_features)) + weight_B = torch.randn((self.out_features, r)) + self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A) + self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + base_layer = 
self.get_base_layer() + weight = getattr(base_layer, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + self.to(base_layer.weight.device, dtype=weight.dtype) + + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_embedding_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + orig_weights = orig_weights + self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_embedding_A.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.lora_embedding_B[adapter].device + dtype = self.lora_embedding_A[adapter].dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. + cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + weight_A = self.lora_embedding_A[adapter] + weight_B = self.lora_embedding_B[adapter] + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter] + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_embedding_A[adapter] = weight_A.to(dtype) + self.lora_embedding_B[adapter] = weight_B.to(dtype) + + return output_tensor + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. 
This is an + # extra argument that allows mixing different adapters in the same batch at inference time. + result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_embedding_A.keys(): + continue + + embedding_A = self.lora_embedding_A[active_adapter].T + embedding_B = self.lora_embedding_B[active_adapter].T + scaling = self.scaling[active_adapter] + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + after_A = self._embed(sub_batch, embedding_A) + result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling + + return result + + def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + base_layer = self.get_base_layer() + return F.embedding( + input, + weight, + padding_idx=base_layer.padding_idx, + max_norm=base_layer.max_norm, + norm_type=base_layer.norm_type, + scale_grad_by_freq=base_layer.scale_grad_by_freq, + sparse=base_layer.sparse, + ) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # TODO: no dtype conversion here, unlike in Linear, is that correct? + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_embedding_A: + continue + embedding_A = self.lora_embedding_A[active_adapter].T + embedding_B = self.lora_embedding_B[active_adapter].T + scaling = self.scaling[active_adapter] + after_A = self._embed(x, embedding_A) + result = result + (after_A @ embedding_B) * scaling + result = result.to(torch_result_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +class Conv2d(nn.Module, LoraLayer): + # Lora implemented in a conv2d layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + weight = getattr(base_layer, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + self.to(base_layer.weight.device, dtype=weight.dtype) + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights inside the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. 
We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight + weight.data = weight_orig + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. 
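+        # The merged update computed below is delta_W = scaling * (B * A): for 1x1 kernels this reduces to a plain
+        # matrix product of the squeezed weights, while for larger kernels lora_A's weight is convolved with
+        # lora_B's weight (see the kohya_ss reference below).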
+ cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + F.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, channel-wise + weight = weight + scaling * lora_weight + # the following is needed to have compatibility with the 4D weight tensors of Conv2D + weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0) + return weight_norm + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + base_layer = self.get_base_layer() + weight = base_layer.weight + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + magnitude = self.lora_magnitude_vector[active_adapter] + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. 
This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + mag_norm_scale = magnitude / weight_norm + result_dora = (mag_norm_scale - 1) * ( + F.conv2d( + x, + weight, + bias=None, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + return result_dora + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +class InflatedConv3d(nn.Module, LoraLayer): + # Lora implemented in a conv2d layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + weight = getattr(base_layer, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update 
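+            # move the freshly created lora_A / lora_B parameters to the device and dtype of the frozen base weight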
+ self.to(base_layer.weight.device, dtype=weight.dtype) + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights inside the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach() + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight + weight.data = weight_orig + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. + cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + F.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, channel-wise + weight = weight + scaling * lora_weight + # the following is needed to have compatibility with the 4D weight tensors of Conv2D + weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).transpose(1, 0) + return weight_norm + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + base_layer = self.get_base_layer() + weight = base_layer.weight + lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + magnitude = self.lora_magnitude_vector[active_adapter] + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. 
This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + mag_norm_scale = magnitude / weight_norm + result_dora = (mag_norm_scale - 1) * ( + F.conv2d( + x, + weight, + bias=None, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + return result_dora + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + ori_dim = x.ndim + if ori_dim == 5: + frames = x.shape[2] + x = rearrange(x, "b c f h w -> (b f) c h w") + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + result = result.to(torch_result_dtype) + if ori_dim == 5: + result = rearrange(result, "(b f) c h w -> b c f h w", f=frames) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +def dispatch_default( + target: torch.nn.Module, + adapter_name: str, + lora_config: LoraConfig, + **kwargs, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Embedding): + embedding_kwargs = kwargs.copy() + embedding_kwargs.pop("fan_in_fan_out", None) + embedding_kwargs.update(lora_config.loftq_config) + new_module = Embedding(target, adapter_name, **embedding_kwargs) + elif 'InflatedConv3d' in str(type(target_base_layer)): + kwargs.update(lora_config.loftq_config) + new_module = InflatedConv3d(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Conv2d): + kwargs.update(lora_config.loftq_config) + new_module = Conv2d(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " "Setting fan_in_fan_out to True." 
+ ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) + + return new_module diff --git a/peft/tuners/lora/model.py b/peft/tuners/lora/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6148bdcead73a828b5e9da59f812ab4f92edf616 --- /dev/null +++ b/peft/tuners/lora/model.py @@ -0,0 +1,886 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import operator +import re +import warnings +from contextlib import contextmanager +from dataclasses import asdict, replace +from enum import Enum +from functools import partial, reduce +from itertools import chain +from typing import Literal, Optional + +import torch +from torch import nn +from tqdm import tqdm + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import ( + BaseTuner, + BaseTunerLayer, + check_target_module_exists, + onload_layer, + replicate_layers, +) +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _freeze_adapter, + _get_submodules, + get_peft_model_state_dict, + get_quantization_config, +) +from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties + +from .aqlm import dispatch_aqlm +from .awq import dispatch_awq +from .config import LoraConfig +from .eetq import dispatch_eetq +from .gptq import dispatch_gptq +from .hqq import dispatch_hqq +from .layer import Conv2d, LoraLayer, dispatch_default +from .tp_layer import dispatch_megatron + + +def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): + # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference + kwargs["adapter_names"] = adapter_names + return args, kwargs + + +class LoraModel(BaseTuner): + """ + Creates Low Rank Adapter (LoRA) model from a pretrained transformers model. + + The method is described in detail in https://arxiv.org/abs/2106.09685. + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`LoraConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The Lora model. + + Example: + + ```py + >>> from transformers import AutoModelForSeq2SeqLM + >>> from peft import LoraModel, LoraConfig + + >>> config = LoraConfig( + ... task_type="SEQ_2_SEQ_LM", + ... r=8, + ... lora_alpha=32, + ... target_modules=["q", "v"], + ... lora_dropout=0.01, + ... ) + + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> lora_model = LoraModel(model, config, "default") + ``` + + ```py + >>> import torch + >>> import transformers + >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + >>> rank = ... 
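+    >>> # `rank` is left elided above; it is only used below as the device index in `device_map={"": rank}`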
+ >>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"] + >>> config = LoraConfig( + ... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" + ... ) + >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True) + + >>> tokenizer = transformers.AutoTokenizer.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... bos_token="[BOS]", + ... eos_token="[EOS]", + ... unk_token="[UNK]", + ... pad_token="[PAD]", + ... mask_token="[MASK]", + ... ) + >>> model = transformers.GPTJForCausalLM.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... pad_token_id=tokenizer.eos_token_id, + ... use_cache=False, + ... device_map={"": rank}, + ... torch_dtype=torch.float16, + ... quantization_config=quantization_config, + ... ) + >>> model = prepare_model_for_kbit_training(model) + >>> lora_model = get_peft_model(model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. + """ + + prefix: str = "lora_" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + def _check_new_adapter_config(self, config: LoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + @staticmethod + def _check_target_module_exists(lora_config, key): + return check_target_module_exists(lora_config, key) + + def _prepare_model(self, peft_config: LoraConfig, model: nn.Module): + r""" + A private method to modify the model structure before adapter is applied. + + Args: + peft_config (`PeftConfig`): + The prepared adapter config. + model (`nn.Module`): + The model that is going to be adapted. 
+ """ + if peft_config.layer_replication: + replicate_layers(model, peft_config.layer_replication) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key) + r = lora_config.rank_pattern.get(target_name_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha) + + kwargs = { + "r": r, + "lora_alpha": alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "use_rslora": lora_config.use_rslora, + "use_dora": lora_config.use_dora, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + + quant_methods = ["gptq", "aqlm", "awq"] + for quant_method in quant_methods: + quantization_config = get_quantization_config(self.model, method=quant_method) + if quantization_config is not None: + kwargs[f"{quant_method}_quantization_config"] = quantization_config + + # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it + from peft.tuners.adalora import AdaLoraLayer + + if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer): + target.update_layer( + adapter_name, + r, + lora_alpha=alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + use_rslora=lora_config.use_rslora, + use_dora=lora_config.use_dora, + ) + else: + new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if (self.prefix in name) or ("ranknum" in name): + weight = ( + child.qweight + if hasattr(child, "qweight") + else child.W_q + if hasattr(child, "W_q") + else child.weight + ) + module.to(weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias 
== "lora_only": + for m in model.modules(): + if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, + # because the first match is always used. Therefore, the default layers should be checked last. + dispatchers = [] + + # avoid eager bnb import + if is_bnb_available(): + from .bnb import dispatch_bnb_8bit + + dispatchers.append(dispatch_bnb_8bit) + + if is_bnb_4bit_available(): + from .bnb import dispatch_bnb_4bit + + dispatchers.append(dispatch_bnb_4bit) + + dispatchers.extend( + [ + dispatch_eetq, + dispatch_aqlm, + dispatch_awq, + dispatch_gptq, + dispatch_hqq, + dispatch_megatron, + dispatch_default, + ] + ) + + new_module = None + for dispatcher in dispatchers: + new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs) + if new_module is not None: # first match wins + break + + if new_module is None: + # no module could be matched + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`." + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled: bool = True) -> None: + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. 
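+
+        Example (a minimal sketch; the adapter names are placeholders for adapters added beforehand):
+
+        ```py
+        >>> model.set_adapter("adapter_1")  # make a single adapter active (and trainable)
+        >>> model.set_adapter(["adapter_1", "adapter_2"])  # or activate several adapters at once
+        ```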
+ """ + for module in self.model.modules(): + if isinstance(module, LoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @contextmanager + def _enable_peft_forward_hooks(self, *args, **kwargs): + # If adapter_names is passed as an argument, we inject it into the forward arguments. + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is None: + # nothing to do + yield + return + + if self.training: + raise ValueError("Cannot pass `adapter_names` when the model is in training mode.") + + hook_handles = [] + for module in self.modules(): + if isinstance(module, LoraLayer): + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + yield + + for handle in hook_handles: + handle.remove() + + def _check_merge_allowed(self): + """Verify that the configuration supports merging. + + Currently gptq quantization and replicated layers do not support merging. + """ + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge LORA layers when the model is gptq quantized") + if self.peft_config.get("layer_replication"): + raise ValueError("Cannot merge LORA layers when base model layers are replicated") + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + self._check_merge_allowed() + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + with onload_layer(target): + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + return self.model + + def _check_add_weighted_adapter( + self, adapters: list[str], combination_type: str, svd_rank: int | None + ) -> tuple[str, int, str]: + """ + Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying + model. 
+ """ + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these + # modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they + # have modules for the adapters to be merged. + modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)] + problematic_wrappers = [ + wrapper + for wrapper in modules_to_save_wrappers + if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 + ] + if problematic_wrappers: + raise ValueError( + "Cannot add weighted adapters if they target the same module with modules_to_save, but found " + f"{len(problematic_wrappers)} such instance(s)." + ) + + # if there is only one adapter, we can only use linear merging + combination_type = "linear" if len(adapters) == 1 else combination_type + + adapters_ranks = [self.peft_config[adapter].r for adapter in adapters] + if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"): + # all adapters ranks should be same, new rank is just this value + if len(set(adapters_ranks)) != 1: + raise ValueError( + "All adapters must have the same r value when using combination_type linear, ties, dare_ties or " + "dare_linear." + ) + new_rank = adapters_ranks[0] + elif combination_type == "cat": + # adapters ranks may be different, new rank is sum of all ranks + # be careful, because output adapter rank may be really big if mixing a lot of adapters + new_rank = sum(adapters_ranks) + elif combination_type.endswith("svd"): + # new rank is the max of all ranks of the adapters if not provided + new_rank = svd_rank or max(adapters_ranks) + else: + raise ValueError(f"Invalid combination_type: {combination_type}") + + target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters] + if not target_module_types: + raise ValueError(f"Found no adapter matching the names in {adapters}") + if len(set(target_module_types)) > 1: + raise ValueError( + "all adapter configs should follow the same target modules type. " + "Combining adapters with `target_modules` type being a mix of list/set and string is not supported." + ) + + if target_module_types[0] == str: + new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters) + elif target_module_types[0] == set: + new_target_modules = reduce( + operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters) + ) + else: + raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules") + + return combination_type, new_rank, new_target_modules + + def add_weighted_adapter( + self, + adapters: list[str], + weights: list[float], + adapter_name: str, + combination_type: str = "svd", + svd_rank: int | None = None, + svd_clamp: int | None = None, + svd_full_matrices: bool = True, + svd_driver: str | None = None, + density: float | None = None, + majority_sign_method: Literal["total", "frequency"] = "total", + ) -> None: + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. 
+ + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. + svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + svd_driver (`str`, *optional*): + Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be + one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd` + documentation. Defaults to None. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + + if adapter_name in list(self.peft_config.keys()): + return + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter( + adapters=adapters, + combination_type=combination_type, + svd_rank=svd_rank, + ) + + self.peft_config[adapter_name] = replace( + self.peft_config[adapters[0]], + r=new_rank, + lora_alpha=new_rank, + target_modules=new_target_modules, + ) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? 
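+        # the combined adapter is only written to below and is not meant to be trained further, hence it is frozen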
+ _freeze_adapter(self.model, adapter_name) + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + if adapter_name in target.lora_A: + target_lora_A = target.lora_A[adapter_name].weight + target_lora_B = target.lora_B[adapter_name].weight + elif adapter_name in target.lora_embedding_A: + target_lora_A = target.lora_embedding_A[adapter_name] + target_lora_B = target.lora_embedding_B[adapter_name] + else: + continue + + target_lora_A.data = target_lora_A.data * 0.0 + target_lora_B.data = target_lora_B.data * 0.0 + if combination_type == "cat": + loras_A, loras_B = [], [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter]) + loras_B.append(current_adapter_lora_B.data) + + if len(loras_A) == 0: + raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.") + loras_A = torch.cat(loras_A, dim=0) + loras_B = torch.cat(loras_B, dim=1) + target_lora_A.data[: loras_A.shape[0], :] = loras_A + target_lora_B.data[:, : loras_B.shape[1]] = loras_B + elif combination_type in [ + "svd", + "ties_svd", + "dare_linear_svd", + "dare_ties_svd", + "magnitude_prune_svd", + ]: + target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + svd_clamp, + full_matrices=svd_full_matrices, + driver=svd_driver, + ) + elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]: + target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter( + combination_type, adapters, weights, target, density, majority_sign_method + ) + + def _svd_generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + clamp=None, + full_matrices=True, + driver=None, + ): + valid_adapters = [] + valid_weights = [] + is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters) + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A or adapter in target.lora_embedding_A: + valid_adapters.append(adapter) + valid_weights.append(weight * target.scaling[adapter]) + + # if no valid adapter, nothing to do + if len(valid_adapters) == 0: + raise ValueError("No matching LoRAs found. 
Please raise an issue on Github.") + delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters] + valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device) + if combination_type == "svd": + delta_weight = task_arithmetic(delta_weight, valid_weights) + elif combination_type == "ties_svd": + delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear_svd": + delta_weight = dare_linear(delta_weight, valid_weights, density) + elif combination_type == "dare_ties_svd": + delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune_svd": + delta_weight = magnitude_prune(delta_weight, valid_weights, density) + else: + raise ValueError(f"Invalid value passed to combination type: {combination_type}") + + conv2d = isinstance(target, Conv2d) + if conv2d: + conv2d_1x1 = target.weight.size()[2:4] == (1, 1) + if not conv2d_1x1: + delta_weight = delta_weight.flatten(start_dim=1) + else: + delta_weight = delta_weight.squeeze() + if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding: + delta_weight = delta_weight.T + + # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131 + U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver) + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ torch.diag(S) + Vh = Vh[:new_rank, :] + if clamp is not None: + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp) + low_val = -hi_val + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(target_lora_B.data.shape) + Vh = Vh.reshape(target_lora_A.data.shape) + return Vh, U + + def _generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + target, + density, + majority_sign_method, + ): + # account weights for LoRA A and B layers. 
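+        # Each user weight (times the adapter's scaling) is split as a square root over the A and B deltas, so the
+        # product B @ A of the merged matrices carries the full `weight * scaling` factor exactly once.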
+ valid_weights = [] + lora_A_deltas = [] + lora_B_deltas = [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + valid_weights.append(math.sqrt(weight * target.scaling[adapter])) + lora_A_deltas.append(current_adapter_lora_A.data) + lora_B_deltas.append(current_adapter_lora_B.data) + valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device) + lora_deltas = [lora_A_deltas, lora_B_deltas] + dtype = lora_A_deltas[0].dtype + for i, task_tensors in enumerate(lora_deltas): + if combination_type == "linear": + lora_deltas[i] = task_arithmetic(task_tensors, valid_weights) + elif combination_type == "ties": + lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear": + lora_deltas[i] = dare_linear(task_tensors, valid_weights, density) + elif combination_type == "dare_ties": + lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune": + lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density) + else: + raise ValueError("Invalid combination type") + lora_deltas = [delta.to(dtype) for delta in lora_deltas] + return lora_deltas + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. 
This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def subtract_pissa_init( + self, output_state_dict: dict[str, torch.Tensor], adapter_name: str = "pissa_init", kwargs=None + ): + """ + This function can calculate the updates of the PiSSA by comparing the parameters of the PiSSA adapter in + `output_state_dict` with the initial values of PiSSA in `adapter_name`, thus converting PiSSA to LoRA. + """ + for name, param in self.model.named_parameters(): + if ( + param.data.dtype != torch.float32 + and param.data.dtype != torch.float16 + and param.data.dtype != torch.bfloat16 + ): + warnings.warn( + r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); " + "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. " + "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. " + ) + pissa_init_state_dict = get_peft_model_state_dict( + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + ) + tensors_lora = {} + for name in output_state_dict.keys(): + ## W = W^{res} + A_0 \times B_0, + ## W + \Delta W = W^{res} + A \times B, + ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'. + if "lora_A" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=0 + ) + elif "lora_B" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], -pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=1 + ) + + return tensors_lora diff --git a/peft/tuners/lora/tp_layer.py b/peft/tuners/lora/tp_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..f98b42c15f0a0fa929c2dd13a94a8ffc19bf9467 --- /dev/null +++ b/peft/tuners/lora/tp_layer.py @@ -0,0 +1,230 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import warnings +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.init as init + +from peft.tuners.tuners_utils import BaseTunerLayer + +from .layer import LoraLayer + + +class LoraParallelLinear(nn.Module, LoraLayer): + """ + When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes + consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear + layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B, + while lora_A is still a complete linear layer. 
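+
+    Note that DoRA is not supported for this layer; constructing it with `use_dora=True` raises a `ValueError`.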
+ """ + + def __init__( + self, + base_layer, + adapter_name: str, + backend, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer=base_layer) + + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + self.backend = backend + self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + + megatron_config = kwargs["megatron_config"] + parallel_linear_kwargs = {"megatron_config": megatron_config} + init_method = init.xavier_normal_ + if hasattr(megatron_config, "init_method"): + init_method = megatron_config.init_method + input_is_parallel = True + gather_output = False + if isinstance(base_layer, self.backend.RowParallelLinear): + input_is_parallel = base_layer.input_is_parallel + else: + gather_output = base_layer.gather_output + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + init_method=init_method, + input_is_parallel=input_is_parallel, + gather_output=gather_output, + **parallel_linear_kwargs, + ) + + self.is_target_conv_1d_layer = False + + def update_layer( + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + use_rslora, + use_dora=False, + init_method=init.xavier_normal_, + input_is_parallel=True, + gather_output=False, + **parallel_linear_kwargs, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + megatron_config = parallel_linear_kwargs["megatron_config"] + # lora needs to be forced to upgrade to 32-bit precision, otherwise it will overflow + megatron_config.params_dtype = torch.float32 + if self.is_parallel_a: + lora_a = self.backend.RowParallelLinear( + input_size=self.in_features, + output_size=r, + bias=False, + input_is_parallel=input_is_parallel, + skip_bias_add=True, + init_method=init_method, + config=megatron_config, + ) + lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32) + else: + lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32) + lora_b = self.backend.ColumnParallelLinear( + input_size=r, + output_size=self.out_features, + bias=False, + gather_output=gather_output, + init_method=init_method, + config=megatron_config, + ) + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + if use_rslora: + self.scaling[adapter_name] = lora_alpha / (r**0.5) + else: + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def forward(self, x: 
torch.Tensor, *args: Any, **kwargs: Any): + previous_dtype = x.dtype + # If weight is used for matrix multiplication here, the final aggregation operation of the original + # parallel_linear layer will be missing, so we need to directly call its forward function to obtain the + # output of the original parallel_linear layer. + if self.disable_adapters: + if self.merged: + self.unmerge() + result, bias = self.base_layer(x, *args, **kwargs) + elif self.merged: + result, bias = self.base_layer(x, *args, **kwargs) + else: + result, bias = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + lora_result = lora_A(dropout(x)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_B(lora_result) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + + result = result.to(previous_dtype) + return result, bias + + +def dispatch_megatron( + target: torch.nn.Module, + adapter_name: str, + lora_config, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if lora_config.megatron_config: + megatron_core = importlib.import_module(lora_config.megatron_core) + else: + megatron_core = None + + if megatron_core and isinstance( + target_base_layer, + (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear), + ): + megatron_kwargs = kwargs.copy() + megatron_config = lora_config.megatron_config + if isinstance(megatron_config, dict): + transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig + megatron_config = transformer_config_class(**lora_config.megatron_config) + megatron_kwargs["megatron_config"] = megatron_config + if megatron_kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` " + "or `RowParallelLinear`. " + "Setting fan_in_fan_out to False." + ) + megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + new_module = LoraParallelLinear( + base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs + ) + + return new_module diff --git a/peft/tuners/lycoris_utils.py b/peft/tuners/lycoris_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9f02a67a89a694d6ec2a71d420f9ba470ea236 --- /dev/null +++ b/peft/tuners/lycoris_utils.py @@ -0,0 +1,429 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
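A brief aside on the adapter math in `LoraParallelLinear.update_layer` and `forward` above: the base parallel layer returns `(output, bias)`, and each active adapter adds `lora_B(lora_A(dropout(x))) * scaling`, where `scaling` is `lora_alpha / r` by default or `lora_alpha / sqrt(r)` when `use_rslora` is set. The sketch below is illustrative only: plain `nn.Linear` layers stand in for Megatron's `RowParallelLinear`/`ColumnParallelLinear`, so the tensor-parallel sharding is omitted and all names are ad hoc.

```python
# Minimal sketch (not the Megatron path): plain nn.Linear layers stand in for the
# parallel linears to show how the adapter delta is composed with the base output
# and how the scaling factor is chosen.
import torch
import torch.nn as nn

def lora_delta(x, lora_A, lora_B, dropout, lora_alpha, r, use_rslora=False):
    scaling = lora_alpha / (r ** 0.5) if use_rslora else lora_alpha / r
    return lora_B(lora_A(dropout(x))) * scaling

base = nn.Linear(16, 32, bias=False)     # stands in for the parallel linear
lora_A = nn.Linear(16, 4, bias=False)    # r = 4
lora_B = nn.Linear(4, 32, bias=False)
nn.init.zeros_(lora_B.weight)            # standard LoRA init: the delta starts at zero
dropout = nn.Identity()

x = torch.randn(2, 16)
result = base(x) + lora_delta(x, lora_A, lora_B, dropout, lora_alpha=8, r=4)
print(result.shape)  # torch.Size([2, 32])
```

In the real class, whichever factor matches the base layer's parallel dimension is itself a parallel linear (lora_A for `RowParallelLinear`, lora_B for `ColumnParallelLinear`), which is what keeps the input and output shardings consistent.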
+from __future__ import annotations + +import warnings +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm + +from peft.config import PeftConfig +from peft.utils import ( + ModulesToSaveWrapper, + _get_submodules, +) + +from .tuners_utils import BaseTuner, BaseTunerLayer, check_adapters_to_merge, check_target_module_exists + + +@dataclass +class LycorisConfig(PeftConfig): + r""" + A base config for LyCORIS like adapters + """ + + rank_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" + ) + }, + ) + alpha_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `alpha`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" + ) + }, + ) + + +class LycorisLayer(BaseTunerLayer): + r""" + A base layer for LyCORIS like adapters + """ + + # adapter_layer_names needs to be defined on the child class + other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout") + + def __init__(self, base_layer: nn.Module) -> None: + self.base_layer = base_layer + self.r = {} + self.alpha = {} + self.scaling = {} + self.rank_dropout = {} + self.module_dropout = {} + + # Tuner info + self._disable_adapters = False + self.merged_adapters = [] + + @property + @abstractmethod + def _available_adapters(self) -> set[str]: + ... + + def _init_empty_weights(self, cls, *args, **kwargs) -> None: + # A helper method that allows to initialize the layer of the given class without spending time to initialize the + # model weights. The implementation is inspired by + # https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html but this function cannot be used + # directly. + # Instead of this approach, it would be possible to bypass the __init__ of the class but that runs the risk of + # omitting important logic inside that __init__. + kwargs = kwargs.copy() + final_device = kwargs.pop("device", "cpu") + cls.__init__(self, *args, device="meta", **kwargs) + self.to_empty(device=final_device) + + @abstractmethod + def create_adapter_parameters(self, adapter_name: str, r: int, **kwargs): + ... + + # TODO: refactor LoRA to use the same approach + @abstractmethod + def _get_delta_activations(self, adapter_name: str, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Activations added on top of the base layer output (i.e. after the base layer forward pass)""" + + @abstractmethod + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + ... + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. 
+ """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + if safe_merge: + orig_weights = base_layer.weight.data.clone() + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + @abstractmethod + def reset_adapter_parameters(self, adapter_name: str): + ... + + def set_scale(self, adapter, scale): + if adapter not in self._available_adapters: + # Ignore the case where the adapter is not in the layer + return + self.scaling[adapter] = scale * self.alpha[adapter] / self.r[adapter] + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + self.scaling[active_adapter] *= scale + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + if scale is None: + self.scaling[active_adapter] = self.alpha[active_adapter] / self.r[active_adapter] + else: + self.scaling[active_adapter] /= scale + + @abstractmethod + def update_layer(self, adapter_name: str, r: int, alpha: float, **kwargs): + ... + + +class LycorisTuner(BaseTuner): + r""" + A base tuner for LyCORIS like adapters + """ + + prefix: str + layers_mapping: dict[type[torch.nn.Module], type[LycorisLayer]] + + def __init__(self, model, config, adapter_name): + super().__init__(model, config, adapter_name) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + @staticmethod + def _check_target_module_exists(config, key): + return check_target_module_exists(config, key) + + @abstractmethod + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[LycorisLayer, nn.Module], + target_name, + parent, + current_key, + ): + ... 
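To make the merge/unmerge contract above concrete: `merge` adds `get_delta_weight(adapter)` into the base weight (on a cloned copy with a NaN check when `safe_merge=True`), and `unmerge` pops merged adapters in reverse order and subtracts the same deltas, restoring the original weights. A minimal, framework-free sketch of that round trip (names here are ad hoc, not PEFT API):

```python
# Round-trip check of the merge / unmerge bookkeeping described above.
import torch

base_weight = torch.randn(4, 4)
deltas = {"adapter_a": torch.randn(4, 4), "adapter_b": torch.randn(4, 4)}

merged = base_weight.clone()
merged_adapters = []
for name, delta in deltas.items():      # merge(): add each adapter's delta weight
    merged += delta
    merged_adapters.append(name)

while merged_adapters:                  # unmerge(): pop and subtract in LIFO order
    merged -= deltas[merged_adapters.pop()]

assert torch.allclose(merged, base_weight, atol=1e-6)
```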
+ + @classmethod + def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn.Module, **kwargs) -> LycorisLayer: + # Find corresponding subtype of provided target module + new_module_cls = None + for subtype, target_cls in cls.layers_mapping.items(): + if ( + hasattr(target, "base_layer") + and isinstance(target.get_base_layer(), subtype) + and isinstance(target, BaseTunerLayer) + ): + # nested tuner layers are allowed + new_module_cls = target_cls + break + elif isinstance(target, subtype): + new_module_cls = target_cls + break + + # We didn't find corresponding type, so adapter for this layer is not supported + if new_module_cls is None: + supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys()) + raise ValueError( + f"Target module of type {type(target)} not supported, " + f"currently only adapters for {supported_modules} are supported" + ) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Conv2d): + new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) + else: + supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys()) + raise ValueError( + f"Target module of type {type(target)} not supported, " + f"currently only adapters for {supported_modules} are supported" + ) + + return new_module + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + raise ValueError("Please specify `target_modules` in `peft_config`") + return peft_config + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if self.prefix in name: + module.to(child.weight.device) + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def _unload_and_optionally_merge( + self, + merge: bool = True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge LOHA layers when the model is gptq quantized") + + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if 
hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + return self.model + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + self._set_adapter_layers(enabled=False) + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the adapter layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ + for module in self.model.modules(): + if isinstance(module, LycorisLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (`str`): Name of the adapter to be deleted. 
+ """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LycorisLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] diff --git a/peft/tuners/mixed/__init__.py b/peft/tuners/mixed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2955d7258ddcf76b47b38fd6fd5ebeb3d1d6110c --- /dev/null +++ b/peft/tuners/mixed/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model import COMPATIBLE_TUNER_TYPES, MixedModel + + +__all__ = ["COMPATIBLE_TUNER_TYPES", "MixedModel"] diff --git a/peft/tuners/mixed/__pycache__/__init__.cpython-310.pyc b/peft/tuners/mixed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434f0a0afa3e450f9a81fae0c2646177ad31e28b Binary files /dev/null and b/peft/tuners/mixed/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/mixed/__pycache__/model.cpython-310.pyc b/peft/tuners/mixed/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af9320b2768cd109102bfa54778cf2fca1147d3 Binary files /dev/null and b/peft/tuners/mixed/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/mixed/model.py b/peft/tuners/mixed/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d292ffec37820d9c88e374cf9891ec1485e41b8b --- /dev/null +++ b/peft/tuners/mixed/model.py @@ -0,0 +1,339 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
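Usage-wise, the `merge_and_unload` / `unload` pair defined above is normally reached through a `PeftModel` wrapping a LyCORIS-style adapter. A hedged sketch, assuming the `peft` and `transformers` packages are available; the model name and `target_modules` are placeholders:

```python
# Sketch: attach a LoHa adapter, then either fold it into the base weights or drop it.
from transformers import AutoModelForCausalLM
from peft import LoHaConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model
config = LoHaConfig(r=8, alpha=16, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base, config)

# ... training ...

merged = peft_model.merge_and_unload(safe_merge=True)  # base model with deltas folded in
# or: pristine = peft_model.unload()                   # base model, adapters discarded
```

`merge_and_unload(safe_merge=True)` applies the NaN check from the merge path shown earlier, while `unload()` simply strips the adapter layers without touching the base weights.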
+from __future__ import annotations + +import warnings +from typing import Any, Optional, Union + +from torch import nn +from tqdm import tqdm + +from peft.tuners import adalora, loha, lokr, lora, oft +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + PeftType, + _get_submodules, + get_auto_gptq_quant_linear, +) + + +# Collection of constants used for all tuners +COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT) +PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix, oft.OFTModel.prefix] +Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig] +Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer, oft.OFTLayer) + + +class MixedModel(BaseTuner): + """ + A class that allows to mix different types of adapters in a single model. + + Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument + `mixed=True`. + + Args: + model (:obj:`nn.Module`): + The model to be tuned. + config (:obj:`PeftConfig`): + The config of the model to be tuned. The adapter type must be compatible. + adapter_name (:obj:`str`): + The name of the first adapter. + """ + + def __init__(self, model: nn.Module, config: Configs, adapter_name: str) -> None: + super().__init__(model, config, adapter_name) + + def _check_new_adapter_config(self, config: Configs) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + if not isinstance(config, Configs.__args__): + raise ValueError( + f"{self.__class__.__name__} only supports {COMPATIBLE_TUNER_TYPES} configs, but got {type(config)}." + ) + + biases = (getattr(config, "bias", None) for config in self.peft_config) + biases = [bias for bias in biases if bias not in (None, "none")] + if len(biases) > 1: + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." 
+ ) + + @staticmethod + def _check_target_module_exists(config: Configs, key: str): + return check_target_module_exists(config, key) + + def _create_and_replace( + self, + config: Configs, + *args: Any, + **kwargs: Any, + ) -> None: + if isinstance(config, adalora.AdaLoraConfig): + adalora.AdaLoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lora.LoraConfig): + lora.LoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, loha.LoHaConfig): + loha.LoHaModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, oft.OFTConfig): + oft.OFTModel._create_and_replace(self, config, *args, **kwargs) + else: + raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + + def _replace_module(self, parent, child_name, new_module, child) -> None: + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.get_base_layer() + elif hasattr(child, "quant_linear_module"): + # TODO maybe not necessary to have special treatment? + child = child.quant_linear_module + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if any(prefix in name for prefix in PREFIXES): + module.to(child.weight.device) + if "ranknum" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if not any(prefix in n for prefix in PREFIXES): + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = getattr(self.peft_config[active_adapter], "bias", "none") + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + # TODO: check if this is needed for other supported types + for m in model.modules(): + if isinstance(m, Layers) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise ValueError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(config, adapter_name, target, **kwargs): + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + if (gptq_quantization_config is not None) or (AutoGPTQQuantLinear is not None): + raise ValueError(f"GPTQ quantization not supported for {config.peft_type.value} (yet).") + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + if loaded_in_8bit or loaded_in_4bit: + raise ValueError(f"8bit and 4bit quantization not supported for {config.peft_type.value} (yet).") + + if isinstance(config, adalora.AdaLoraConfig): + new_module = adalora.AdaLoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, 
lora.LoraConfig): + new_module = lora.LoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, loha.LoHaConfig): + new_module = loha.LoHaModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, oft.OFTConfig): + new_module = oft.OFTModel._create_new_module(config, adapter_name, target, **kwargs) + else: + raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = getattr(self.peft_config[active_adapter], "bias", "none") + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + for module in self.model.modules(): + if isinstance(module, Layers): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge layers when the model is gptq quantized") + + def merge_recursively(module): + # helper function to recursively merge the base_layer of the target + path = [] + layer = module + while hasattr(layer, "base_layer"): + path.append(layer) + layer = layer.base_layer + for layer_before, layer_after in zip(path[:-1], path[1:]): + layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names) + layer_before.base_layer = layer_after.base_layer + module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + desc = "Unloading " + ("and merging " if merge else "") + "model" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + merge_recursively(target) + self._replace_module(parent, 
target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + return self.model + + def add_weighted_adapter(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError(f"Weighted adapters are not supported for {self.__class__.__name__} (yet).") + + def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (Union[str, list[str]]): Name of the adapter(s) to delete. + """ + if isinstance(adapter_name, str): + adapter_names = [adapter_name] + else: + adapter_names = adapter_name + + mismatched = set(adapter_names) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + for adapter_name in adapter_names: + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BaseTunerLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> nn.Module: + r""" + This method merges the layers into the base model. This is needed if someone wants to use the base model as a + standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def generate(self, *args: Any, **kwargs: Any): + return self.model.generate(*args, **kwargs) diff --git a/peft/tuners/multitask_prompt_tuning/__init__.py b/peft/tuners/multitask_prompt_tuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..214f7722486485bea4ede3b5c1a433aac447dd2b --- /dev/null +++ b/peft/tuners/multitask_prompt_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit +from .model import MultitaskPromptEmbedding + + +__all__ = ["MultitaskPromptTuningConfig", "MultitaskPromptTuningInit", "MultitaskPromptEmbedding"] diff --git a/peft/tuners/multitask_prompt_tuning/__pycache__/__init__.cpython-310.pyc b/peft/tuners/multitask_prompt_tuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c70362a684b92008a222000a7b82f3d250b5932 Binary files /dev/null and b/peft/tuners/multitask_prompt_tuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/multitask_prompt_tuning/__pycache__/config.cpython-310.pyc b/peft/tuners/multitask_prompt_tuning/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3dd0bc44bc0f6d8cd38cea2faec86a9965700b Binary files /dev/null and b/peft/tuners/multitask_prompt_tuning/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/multitask_prompt_tuning/__pycache__/model.cpython-310.pyc b/peft/tuners/multitask_prompt_tuning/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0031da8fb1055866a6be8e70886a8cb1df880e5f Binary files /dev/null and b/peft/tuners/multitask_prompt_tuning/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/multitask_prompt_tuning/config.py b/peft/tuners/multitask_prompt_tuning/config.py new file mode 100644 index 0000000000000000000000000000000000000000..67a3c323a299063900d42a6e464672898b13be7c --- /dev/null +++ b/peft/tuners/multitask_prompt_tuning/config.py @@ -0,0 +1,61 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
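Stepping back to the `MixedModel` defined earlier: as its docstring notes, it is normally reached via `get_peft_model(..., mixed=True)` rather than instantiated directly. A usage sketch under that assumption (model name, adapter names, and `target_modules` are placeholders):

```python
# Sketch: combine a LoRA adapter and a LoHa adapter on one base model.
from transformers import AutoModelForCausalLM
from peft import LoHaConfig, LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"])
loha_cfg = LoHaConfig(r=4, alpha=8, target_modules=["v_proj"])

# mixed=True routes through the mixed tuner so different adapter types can coexist
model = get_peft_model(base, lora_cfg, adapter_name="lora", mixed=True)
model.add_adapter("loha", loha_cfg)
model.set_adapter(["lora", "loha"])  # activate both adapters at once
```

`set_adapter`, `delete_adapter`, and `merge_and_unload` shown above then operate across both adapter types.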
+ +import enum +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.tuners.prompt_tuning import PromptTuningConfig +from peft.utils import PeftType + + +class MultitaskPromptTuningInit(str, enum.Enum): + # initialize prompt with text + TEXT = "TEXT" + # initialize prompt with random matrix + RANDOM = "RANDOM" + # average the prefix and column matrices obtained during source training + AVERAGE_SOURCE_TASKS = "AVERAGE_SOURCE_TASKS" + # pick prefix and column matrices for a particular task obtained during source training + EXACT_SOURCE_TASK = "EXACT_SOURCE_TASK" + # only use the prompt embeddings trained during source training + ONLY_SOURCE_SHARED = "ONLY_SOURCE_SHARED" + + +@dataclass +class MultitaskPromptTuningConfig(PromptTuningConfig): + prompt_tuning_init: Union[MultitaskPromptTuningInit, str] = field( + default=MultitaskPromptTuningInit.RANDOM, + metadata={ + "help": ( + "How to initialize the prompt tuning parameters. Can be one of TEXT, RANDOM, AVERAGE_SOURCE_TASKS, " + "EXACT_SOURCE_TASK, ONLY_SOURCE_SHARED." + ), + }, + ) + prompt_tuning_init_state_dict_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The path of source state dict. This is required when training the downstream target prompt from " + "the pretrained source prompt" + ), + }, + ) + prompt_tuning_init_task: Optional[int] = field(default=0, metadata={"help": "source task id for initialization"}) + num_ranks: Optional[int] = field(default=1, metadata={"help": "ranks"}) + num_tasks: Optional[int] = field(default=1, metadata={"help": "number of tasks"}) + + def __post_init__(self): + self.peft_type = PeftType.MULTITASK_PROMPT_TUNING diff --git a/peft/tuners/multitask_prompt_tuning/model.py b/peft/tuners/multitask_prompt_tuning/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a74ec47db44800a6300cafb9732b5725d340aae5 --- /dev/null +++ b/peft/tuners/multitask_prompt_tuning/model.py @@ -0,0 +1,119 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from peft.tuners.prompt_tuning import PromptEmbedding +from peft.utils import TaskType + +from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit + + +# This code is adapted for the paper: https://arxiv.org/abs/2303.02861 and +# constitutes the work done at MIT-IBM Watson Research Lab. 
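Before the embedding class below, a configuration sketch for the fields just defined (the checkpoint path and sizes are placeholders): a target-task config that initializes from a source checkpoint via `EXACT_SOURCE_TASK`, which the `MultitaskPromptEmbedding` defined next loads as a shared prompt plus per-task low-rank factors.

```python
# Target-stage config: copy the prompt and low-rank factors of source task 0.
from peft import MultitaskPromptTuningConfig, MultitaskPromptTuningInit, TaskType

target_config = MultitaskPromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=50,
    num_tasks=1,
    num_ranks=1,
    prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
    prompt_tuning_init_task=0,  # which source task to copy
    prompt_tuning_init_state_dict_path="source_prompt/adapter_model.bin",  # placeholder path
)
```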
+ + +class MultitaskPromptEmbedding(PromptEmbedding): + def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings): + super().__init__(config, word_embeddings) + + self.num_tasks = config.num_tasks + self.num_ranks = config.num_ranks + self.num_virtual_tokens = config.num_virtual_tokens + + self.num_transformer_submodules = config.num_transformer_submodules + if self.num_transformer_submodules is None: + self.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1 + + self.token_dim = config.token_dim + + total_virtual_tokens = self.num_virtual_tokens * self.num_transformer_submodules + + self.prefix_task_cols = torch.nn.Parameter( + torch.normal( + mean=0, + std=0.02, + size=(self.num_tasks, total_virtual_tokens, self.num_ranks), + ) + ) + self.prefix_task_rows = torch.nn.Parameter( + torch.normal( + mean=0, + std=0.02, + size=(self.num_tasks, self.num_ranks, self.token_dim), + ) + ) + + if config.prompt_tuning_init in [ + MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, + MultitaskPromptTuningInit.EXACT_SOURCE_TASK, + MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, + ]: + if config.prompt_tuning_init_state_dict_path is None: + raise ValueError( + f"prompt_tuning_init_state_dict_path needs to be specified with {config.prompt_tuning_init} " + "init method" + ) + + if config.prompt_tuning_init_state_dict_path.endswith(".safetensors"): + from safetensors.torch import load_file + + state_dict: dict = load_file(config.prompt_tuning_init_state_dict_path) + else: + state_dict: dict = torch.load( + config.prompt_tuning_init_state_dict_path, + map_location=word_embeddings.weight.device, + ) + + if config.prompt_tuning_init in [ + MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, + MultitaskPromptTuningInit.EXACT_SOURCE_TASK, + ]: + prefix_task_cols_: torch.Tensor = state_dict["prefix_task_cols"] + prefix_task_rows_: torch.Tensor = state_dict["prefix_task_rows"] + + if config.prompt_tuning_init == MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS: + prefix_task_cols_ = prefix_task_cols_.mean(0, keepdim=True) + prefix_task_rows_ = prefix_task_rows_.mean(0, keepdim=True) + elif config.prompt_tuning_init == MultitaskPromptTuningInit.EXACT_SOURCE_TASK: + prefix_task_cols_ = prefix_task_cols_[config.prompt_tuning_init_task, ...].unsqueeze(0) + prefix_task_rows_ = prefix_task_rows_[config.prompt_tuning_init_task, ...].unsqueeze(0) + + state_dict = { + "embedding.weight": state_dict["prompt_embeddings"], + "prefix_task_cols": prefix_task_cols_, + "prefix_task_rows": prefix_task_rows_, + } + + self.load_state_dict(state_dict, strict=True) + elif config.prompt_tuning_init == MultitaskPromptTuningInit.ONLY_SOURCE_SHARED: + state_dict = { + "embedding.weight": state_dict["prompt_embeddings"], + } + + self.load_state_dict(state_dict, strict=False) + + def forward(self, indices, task_ids): + if task_ids is None: + raise ValueError("task_ids cannot be None") + + prompt_embeddings = self.embedding(indices) + + task_cols = torch.index_select(self.prefix_task_cols, 0, task_ids) + task_rows = torch.index_select(self.prefix_task_rows, 0, task_ids) + task_prompts = torch.matmul(task_cols, task_rows) + + prompt_embeddings *= task_prompts + + return prompt_embeddings diff --git a/peft/tuners/oft/__init__.py b/peft/tuners/oft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52ac7131e24bd5cf39bf97ab6336ed1f1d46e152 --- /dev/null +++ b/peft/tuners/oft/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import OFTConfig +from .layer import Conv2d, Linear, OFTLayer +from .model import OFTModel + + +__all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"] diff --git a/peft/tuners/oft/__pycache__/__init__.cpython-310.pyc b/peft/tuners/oft/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd40ffc99e05acbdce560604801f65856d3e8a55 Binary files /dev/null and b/peft/tuners/oft/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/oft/__pycache__/config.cpython-310.pyc b/peft/tuners/oft/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..067776feff3a353cdd0c9531c68b674918405fb2 Binary files /dev/null and b/peft/tuners/oft/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/oft/__pycache__/layer.cpython-310.pyc b/peft/tuners/oft/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6394092daea9cb9c1f27e8a5d3137d86ae1a70c Binary files /dev/null and b/peft/tuners/oft/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/oft/__pycache__/model.cpython-310.pyc b/peft/tuners/oft/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e0bccb0d4a10e525210462ee1be022816ab63f3 Binary files /dev/null and b/peft/tuners/oft/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/oft/config.py b/peft/tuners/oft/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3b9a4401abd6a17840bc6944baaa9f0085fb39 --- /dev/null +++ b/peft/tuners/oft/config.py @@ -0,0 +1,119 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lycoris_utils import LycorisConfig +from peft.utils import PeftType + + +@dataclass +class OFTConfig(LycorisConfig): + """ + This is the configuration class to store the configuration of a [`OFTModel`]. + + Args: + r (`int`): OFT rank. + module_dropout (`int`): The dropout probability for disabling OFT modules during training. + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. 
When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding + the output layer. If this is not specified, modules will be chosen according to the model architecture. If + the architecture is not known, an error will be raised -- in this case, you should specify the target + modules manually. + init_weights (`bool`): + Whether to perform initialization of OFT weights. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + modules_to_save (`List[str]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + coft (`bool`): + Whether to use the constrained variant of OFT or not, off by default. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): + Whether to share the OFT parameters between blocks or not. This is `False` by default. + """ + + r: int = field(default=8, metadata={"help": "OFT rank"}) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"} + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with OFT." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the OFT layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from OFT layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." 
+ }, + ) + coft: bool = field( + default=False, + metadata={"help": "Whether to use the constrained variant of OFT or not."}, + ) + eps: float = field( + default=6e-5, + metadata={ + "help": "The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True." + }, + ) + block_share: bool = field( + default=False, + metadata={"help": "Whether to share the OFT parameters between blocks or not."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.OFT + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) diff --git a/peft/tuners/oft/layer.py b/peft/tuners/oft/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4427304b5a739116f0dbca5582603a932518980 --- /dev/null +++ b/peft/tuners/oft/layer.py @@ -0,0 +1,388 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Any, List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from peft.tuners.lycoris_utils import LycorisLayer, check_adapters_to_merge + + +class OFTLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ("oft_r",) + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module): + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # OFT info + self.oft_r = nn.ParameterDict({}) + self.coft = {} + self.eps = {} + self.block_share = {} + + @property + def _available_adapters(self) -> Set[str]: + return {*self.oft_r} + + def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool): + if block_share: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + else: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + + def reset_adapter_parameters(self, adapter_name: str): + nn.init.zeros_(self.oft_r[adapter_name]) + + def reset_adapter_parameters_random(self, adapter_name: str): + nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5)) + + def update_layer( + self, + adapter_name: str, + r: int, + module_dropout: float, + init_weights: bool, + coft: bool = False, + eps: float = 6e-5, + block_share: bool = False, + **kwargs, + ) -> None: + """Internal function to create oft adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize weights. + coft (`bool`): Whether to use the constrained variant of OFT or not. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): Whether to share the OFT parameters between blocks or not. 
+ """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.module_dropout[adapter_name] = module_dropout + self.coft[adapter_name] = coft + self.block_share[adapter_name] = block_share + + # Determine shape of OFT weights + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + shape = tuple(base_layer.weight.shape) + elif isinstance(base_layer, nn.Conv2d): + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ) + else: + raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}") + + self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r) + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape, block_share) + + # Initialize weights + if init_weights: + self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) + + # Move new weights to device + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def unscale_layer(self, scale=None) -> None: + # scale is not used + pass + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + + orig_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = orig_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + orig_weights = torch.transpose(orig_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if orig_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]] + new_weights = torch.mm(orig_weights, delta_weight) + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = torch.transpose(new_weights, 0, 1) + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + + if safe_merge and not torch.isfinite(new_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = new_weights + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + new_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + new_weights = torch.transpose(new_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if new_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]] + delta_inv = torch.inverse(delta_weight) + orig_weights = torch.mm(new_weights, delta_inv) + + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = torch.transpose(orig_weights, 0, 1) + orig_weights = orig_weights.reshape( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + base_layer.weight.data = orig_weights + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + rank = self.r[adapter_name] + coft = self.coft[adapter_name] + eps = self.eps[adapter_name] + opt_r = self.oft_r[adapter_name] + + if coft: + with torch.no_grad(): + opt_r.copy_(self._project_batch(opt_r, eps=eps)) + + orth_rotate = self._cayley_batch(opt_r) + weight = self._block_diagonal(orth_rotate, rank) + + return weight + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144 + def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor: + b, r, c = data.shape + # Ensure the input matrix is skew-symmetric + skew = 0.5 * (data - data.transpose(1, 2)) + I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741 + + # Perform the Cayley parametrization + Q = torch.bmm(I - skew, torch.inverse(I + skew)) + + return Q + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155 + def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor: + if oft_r.shape[0] == 1: + # block share + blocks = [oft_r[0, ...] for i in range(rank)] + else: + blocks = [oft_r[i, ...] 
for i in range(rank)] + + # Use torch.block_diag to create the block diagonal matrix + A = torch.block_diag(*blocks) + + return A + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52 + def _project_batch(self, oft_r, eps=1e-5): + # scaling factor for each of the smaller block matrix + eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0])) + I = ( # noqa: E741 + torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype) + .unsqueeze(0) + .expand_as(oft_r) + ) + diff = oft_r - I + norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True) + mask = (norm_diff <= eps).bool() + out = torch.where(mask, oft_r, I + eps * (diff / norm_diff)) + return out + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + if len(result.shape) == 4: + result = result.permute(0, 2, 3, 1) + + base_layer = self.get_base_layer() + base_bias = base_layer.bias + if base_bias is not None: + # Bias should be added after OFT forward + result = result - base_bias.data + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = self._get_delta_activations(active_adapter, result, *args, **kwargs) + + if base_bias is not None: + result = result + base_bias.data + if len(result.shape) == 4: + result = result.permute(0, 3, 1, 2) + + result = result.to(previous_dtype) + return result + + +class Linear(OFTLayer): + """OFT implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." 
+ rep + + +class Conv2d(OFTLayer): + """OFT implemented in Conv2d layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." + rep diff --git a/peft/tuners/oft/model.py b/peft/tuners/oft/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd96325c6f0a6d7fd87b77e033d5bf49a9050752 --- /dev/null +++ b/peft/tuners/oft/model.py @@ -0,0 +1,106 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Conv2d, Linear, OFTLayer + + +class OFTModel(LycorisTuner): + """ + Creates Orthogonal Finetuning model from a pretrained model. The method is described in + https://arxiv.org/abs/2306.07280 + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`OFTConfig`]): The configuration of the OFT model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The OFT model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import OFTModel, OFTConfig + + >>> config_te = OFTConfig( + ... r=8, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + >>> config_unet = OFTConfig( + ... r=8, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default") + >>> model.unet = OFTModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`OFTConfig`]): The configuration of the OFT model. 
+ """ + + prefix: str = "oft_" + layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = { + torch.nn.Conv2d: Conv2d, + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[OFTLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(config.rank_pattern.keys()) + target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + + if isinstance(target, OFTLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/peft/tuners/p_tuning/__init__.py b/peft/tuners/p_tuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd3a6ba3e4442354302c5bfe3da75f1d6f69d02 --- /dev/null +++ b/peft/tuners/p_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import PromptEncoderConfig, PromptEncoderReparameterizationType +from .model import PromptEncoder + + +__all__ = ["PromptEncoder", "PromptEncoderConfig", "PromptEncoderReparameterizationType"] diff --git a/peft/tuners/p_tuning/__pycache__/__init__.cpython-310.pyc b/peft/tuners/p_tuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1c324a0bf1c8673da68e2a63a8509b0df16c476 Binary files /dev/null and b/peft/tuners/p_tuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/p_tuning/__pycache__/config.cpython-310.pyc b/peft/tuners/p_tuning/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a95dc0c3c6ea52c33a2eca8e1be1c28f9ad9bf3a Binary files /dev/null and b/peft/tuners/p_tuning/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/p_tuning/__pycache__/model.cpython-310.pyc b/peft/tuners/p_tuning/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334098e141cec92cf1e431fb9e4f2329ed766218 Binary files /dev/null and b/peft/tuners/p_tuning/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/p_tuning/config.py b/peft/tuners/p_tuning/config.py new file mode 100644 index 0000000000000000000000000000000000000000..75deffb4299df4178e80d74dde47b0470ea06c25 --- /dev/null +++ b/peft/tuners/p_tuning/config.py @@ -0,0 +1,59 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from dataclasses import dataclass, field +from typing import Union + +from peft.config import PromptLearningConfig +from peft.utils import PeftType + + +class PromptEncoderReparameterizationType(str, enum.Enum): + MLP = "MLP" + LSTM = "LSTM" + + +@dataclass +class PromptEncoderConfig(PromptLearningConfig): + """ + This is the configuration class to store the configuration of a [`PromptEncoder`]. + + Args: + encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]): + The type of reparameterization to use. + encoder_hidden_size (`int`): The hidden size of the prompt encoder. + encoder_num_layers (`int`): The number of layers of the prompt encoder. + encoder_dropout (`float`): The dropout probability of the prompt encoder. + """ + + encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( + default=PromptEncoderReparameterizationType.MLP, + metadata={"help": "How to reparameterize the prompt encoder"}, + ) + encoder_hidden_size: int = field( + default=None, + metadata={"help": "The hidden size of the prompt encoder"}, + ) + encoder_num_layers: int = field( + default=2, + metadata={"help": "The number of layers of the prompt encoder"}, + ) + encoder_dropout: float = field( + default=0.0, + metadata={"help": "The dropout of the prompt encoder"}, + ) + + def __post_init__(self): + self.peft_type = PeftType.P_TUNING diff --git a/peft/tuners/p_tuning/model.py b/peft/tuners/p_tuning/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ade2b1128158376c134441687803b85d444cfb96 --- /dev/null +++ b/peft/tuners/p_tuning/model.py @@ -0,0 +1,130 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py +# with some refactor +import warnings + +import torch + +from .config import PromptEncoderConfig, PromptEncoderReparameterizationType + + +class PromptEncoder(torch.nn.Module): + """ + The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. + + Args: + config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. + + Example: + + ```py + >>> from peft import PromptEncoder, PromptEncoderConfig + + >>> config = PromptEncoderConfig( + ... peft_type="P_TUNING", + ... task_type="SEQ_2_SEQ_LM", + ... num_virtual_tokens=20, + ... token_dim=768, + ... num_transformer_submodules=1, + ... num_attention_heads=12, + ... num_layers=12, + ... encoder_reparameterization_type="MLP", + ... 
encoder_hidden_size=768, + ... ) + + >>> prompt_encoder = PromptEncoder(config) + ``` + + **Attributes**: + - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder. + - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`. + - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and + `encoder_reparameterization_type="LSTM"`. + - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. + - **input_size** (`int`) -- The input size of the prompt encoder. + - **output_size** (`int`) -- The output size of the prompt encoder. + - **hidden_size** (`int`) -- The hidden size of the prompt encoder. + - **total_virtual_tokens** (`int`): The total number of virtual tokens of the + prompt encoder. + - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): The encoder type of the prompt + encoder. + + + Input shape: (`batch_size`, `total_virtual_tokens`) + + Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) + """ + + def __init__(self, config): + super().__init__() + self.token_dim = config.token_dim + self.input_size = self.token_dim + self.output_size = self.token_dim + self.hidden_size = config.encoder_hidden_size + self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules + self.encoder_type = config.encoder_reparameterization_type + + # embedding + self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) + if not config.inference_mode: + if self.encoder_type == PromptEncoderReparameterizationType.LSTM: + lstm_dropout = config.encoder_dropout + num_layers = config.encoder_num_layers + # LSTM + self.lstm_head = torch.nn.LSTM( + input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=num_layers, + dropout=lstm_dropout, + bidirectional=True, + batch_first=True, + ) + + self.mlp_head = torch.nn.Sequential( + torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.hidden_size * 2, self.output_size), + ) + + elif self.encoder_type == PromptEncoderReparameterizationType.MLP: + encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers + if config.encoder_num_layers != encoder_num_layers_default: + warnings.warn( + f"for {self.encoder_type.value}, the argument `encoder_num_layers` is ignored. " + f"Exactly {encoder_num_layers_default} MLP layers are used." + ) + layers = [ + torch.nn.Linear(self.input_size, self.hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(self.hidden_size, self.hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(self.hidden_size, self.output_size), + ] + self.mlp_head = torch.nn.Sequential(*layers) + + else: + raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") + + def forward(self, indices): + input_embeds = self.embedding(indices) + if self.encoder_type == PromptEncoderReparameterizationType.LSTM: + output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) + elif self.encoder_type == PromptEncoderReparameterizationType.MLP: + output_embeds = self.mlp_head(input_embeds) + else: + raise ValueError("Prompt encoder type not recognized. 
Please use one of MLP (recommended) or LSTM.") + + return output_embeds diff --git a/peft/tuners/poly/__init__.py b/peft/tuners/poly/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f368695edbd7fb7bb3c68d9e918bd16752b873 --- /dev/null +++ b/peft/tuners/poly/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import PolyConfig +from .layer import Linear, PolyLayer +from .model import PolyModel + + +__all__ = ["Linear", "PolyConfig", "PolyLayer", "PolyModel"] diff --git a/peft/tuners/poly/__pycache__/__init__.cpython-310.pyc b/peft/tuners/poly/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5cae1fca5c0a71e744b407516795e939543abbd Binary files /dev/null and b/peft/tuners/poly/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/poly/__pycache__/config.cpython-310.pyc b/peft/tuners/poly/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bbd45c0a12c26bebee972bd81825a019a63f517 Binary files /dev/null and b/peft/tuners/poly/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/poly/__pycache__/layer.cpython-310.pyc b/peft/tuners/poly/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..540089adc2af803e75a0770eba8aadb507aa2298 Binary files /dev/null and b/peft/tuners/poly/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/poly/__pycache__/model.cpython-310.pyc b/peft/tuners/poly/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04f570b472a334507220c2124eb2ff705505f5a Binary files /dev/null and b/peft/tuners/poly/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/poly/__pycache__/router.cpython-310.pyc b/peft/tuners/poly/__pycache__/router.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c45e551f21b88093299f44e189b2a82e1cbaa71 Binary files /dev/null and b/peft/tuners/poly/__pycache__/router.cpython-310.pyc differ diff --git a/peft/tuners/poly/config.py b/peft/tuners/poly/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3abbc93b022dd53b5fd5c373b029dba9084a0b9b --- /dev/null +++ b/peft/tuners/poly/config.py @@ -0,0 +1,89 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
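As a quick sanity check for the `PromptEncoder` defined above, the following minimal sketch maps a batch of virtual-token indices to prompt embeddings. It assumes the vendored `peft` package is importable (e.g., the repo root is on `PYTHONPATH`); the config values are taken from the class docstring example.

```py
# Minimal shape check for PromptEncoder (MLP reparameterization branch).
import torch
from peft import PromptEncoder, PromptEncoderConfig

config = PromptEncoderConfig(
    peft_type="P_TUNING",
    task_type="SEQ_2_SEQ_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    num_attention_heads=12,
    num_layers=12,
    encoder_reparameterization_type="MLP",
    encoder_hidden_size=768,
)
prompt_encoder = PromptEncoder(config)

indices = torch.arange(20).unsqueeze(0)  # (batch_size=1, total_virtual_tokens=20)
output = prompt_encoder(indices)
print(output.shape)  # torch.Size([1, 20, 768]) == (batch_size, total_virtual_tokens, token_dim)
```

With `encoder_reparameterization_type="LSTM"`, the same call routes the embeddings through the bidirectional LSTM head before the MLP head, producing the same output shape.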
+ +from dataclasses import dataclass, field +from typing import List, Literal, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class PolyConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`PolyModel`]. + - [Polytropon (Poly)](https://arxiv.org/abs/2202.13914) + - [Multi-Head Routing (MHR)](https://arxiv.org/abs/2211.03831) + + Args: + r (`int`): Attention dimension of each Lora in Poly. + target_modules (`Union[List[str],str]`): The names of the modules to apply Poly to. + modules_to_save (`List[str]`): List of modules apart from Poly layers to be set as trainable + and saved in the final checkpoint. + init_weights (bool): Whether to perform initialization of Poly weights. + poly_type (`Literal["poly"]`): The variant of the Poly module to use. Currently, only "poly" + is supported. + n_tasks (`int`): The number of tasks in a multitasking scenario. + n_skills (`int`): The number of skills (LoRA) in each Poly layer. + n_splits (`int`): The number of splits within each LoRA of a Poly layer. A value greater + than 1 indicates the use of Multi-Head Routing (MHR). + """ + + r: int = field(default=8, metadata={"help": "Lora attention dimension"}) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with Poly." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from Poly layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the Poly layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + poly_type: Literal["poly"] = field( + default="poly", + metadata={"help": 'Type of Poly modules to be used. Currently only "poly" is supported.'}, + ) + n_tasks: int = field( + default=1, + metadata={"help": "Number of tasks in multitasking scenario."}, + ) + n_skills: int = field( + default=4, + metadata={"help": "Number of skills (LoRA) in each Poly layer."}, + ) + n_splits: int = field( + default=1, + metadata={"help": "Number of splits within each LoRA of a Poly layer."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.POLY + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) diff --git a/peft/tuners/poly/layer.py b/peft/tuners/poly/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..debb40beee29b1cfdf2072a293d4c61042280227 --- /dev/null +++ b/peft/tuners/poly/layer.py @@ -0,0 +1,171 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any + +import torch +import torch.nn as nn + +from peft.tuners.tuners_utils import BaseTunerLayer + +from .config import PolyConfig +from .router import get_router + + +class PolyLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("poly_lora_A", "poly_lora_B", "poly_router") + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("r", "n_tasks", "n_skills", "n_splits") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.n_tasks = {} + self.n_skills = {} + self.n_splits = {} + self.poly_type = {} + self.poly_router = nn.ModuleDict() + self.poly_lora_A = nn.ParameterDict() + self.poly_lora_B = nn.ParameterDict() + self.kwargs = kwargs + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + + def update_layer(self, adapter_name, poly_config): + if poly_config.r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {poly_config.r}") + + self.r[adapter_name] = poly_config.r + self.n_tasks[adapter_name] = poly_config.n_tasks + self.n_skills[adapter_name] = poly_config.n_skills + self.n_splits[adapter_name] = poly_config.n_splits + self.poly_type[adapter_name] = poly_config.poly_type + + self.poly_lora_A[adapter_name] = nn.Parameter( + torch.empty( + poly_config.n_splits, + poly_config.n_skills, + self.in_features // poly_config.n_splits, + poly_config.r, + ) + ) + self.poly_lora_B[adapter_name] = nn.Parameter( + torch.empty( + poly_config.n_splits, + poly_config.n_skills, + poly_config.r, + self.out_features // poly_config.n_splits, + ) + ) + self.poly_router[adapter_name] = get_router(poly_config) + + self.reset_poly_parameters(adapter_name, init_weights=poly_config.init_weights) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def reset_poly_parameters(self, adapter_name, init_weights): + if adapter_name in self.poly_lora_A.keys(): + # initialize A the same way as the default for nn.Linear + # https://github.com/microsoft/mttl/blob/ce4ca51dbca73be656feb9b3e5233633e3c5dec7/mttl/models/poly.py#L269 + n_splits, n_skills, d, r = self.poly_lora_A[adapter_name].shape + for skill in range(n_skills): + for split in range(n_splits): + param = torch.empty((r, d)) + torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5)) + self.poly_lora_A[adapter_name].data[split, skill, :, :] = param.T + + if init_weights: + # initialize B to zero + torch.nn.init.zeros_(self.poly_lora_B[adapter_name]) + else: + # initialize B the same way as the default 
for nn.Linear + n_splits, n_skills, r, d = self.poly_lora_B[adapter_name].shape + for skill in range(n_skills): + for split in range(n_splits): + param = torch.empty((d, r)) + torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5)) + self.poly_lora_B[adapter_name].data[split, skill, :, :] = param.T + + # initialized router + self.poly_router[adapter_name].reset() + + +class Linear(nn.Module, PolyLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + poly_config: PolyConfig, + **kwargs, + ) -> None: + super().__init__() + PolyLayer.__init__(self, base_layer, **kwargs) + + self._active_adapter = adapter_name + self.update_layer(adapter_name, poly_config) + + def forward(self, x: torch.Tensor, *args: Any, task_ids: torch.Tensor = None, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + if self.disable_adapters: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.poly_lora_A.keys(): + continue + + r = self.r[active_adapter] + poly_router = self.poly_router[active_adapter] + poly_lora_A = self.poly_lora_A[active_adapter] + poly_lora_B = self.poly_lora_B[active_adapter] + + # Combine the output of LoRAs + # https://github.com/microsoft/mttl/blob/ce4ca51dbca73be656feb9b3e5233633e3c5dec7/mttl/models/poly.py#L293 + mixing_weights = poly_router(task_ids=task_ids, input_ids=x) + bs, n_splits, n_skills = mixing_weights.size() + + # A is n_splits, n_skills, D // n_splits, rank + # we want bs, n_splits, D // n_splits, rank + A = torch.einsum("bqs,qsdr->bqdr", (mixing_weights, poly_lora_A)) + B = torch.einsum("bqs,qsrd->bqrd", (mixing_weights, poly_lora_B)) + + A = A.reshape(bs, self.in_features, r) + B = B.transpose(1, 2).reshape(bs, r, self.out_features) + + x = x.to(A.dtype) + result += x.bmm(A).bmm(B) / r + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "poly." + rep diff --git a/peft/tuners/poly/model.py b/peft/tuners/poly/model.py new file mode 100644 index 0000000000000000000000000000000000000000..df637b12af3e9a1ea8c18e5c96631e102cbfdd00 --- /dev/null +++ b/peft/tuners/poly/model.py @@ -0,0 +1,187 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
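The mixing step in the Poly `Linear.forward` above can be illustrated in isolation. The sketch below is only a shape check: random routing weights stand in for the `PolyRouter` output, and the tensor names and einsum strings mirror the layer code.

```py
# Standalone sketch of the Poly/MHR mixing step: per-split routing weights
# select a weighted combination of LoRA "skills", which is then applied as a
# batched low-rank update on top of the frozen base layer output.
import torch

bs, seq_len = 2, 10
n_splits, n_skills, r = 4, 8, 16
in_features, out_features = 64, 32          # both divisible by n_splits

# Stand-in for the router output (the real layer gets this from PolyRouter).
mixing_weights = torch.softmax(torch.randn(bs, n_splits, n_skills), dim=-1)
poly_lora_A = torch.randn(n_splits, n_skills, in_features // n_splits, r)
poly_lora_B = torch.randn(n_splits, n_skills, r, out_features // n_splits)

# Same contractions as in Linear.forward.
A = torch.einsum("bqs,qsdr->bqdr", mixing_weights, poly_lora_A).reshape(bs, in_features, r)
B = torch.einsum("bqs,qsrd->bqrd", mixing_weights, poly_lora_B).transpose(1, 2).reshape(bs, r, out_features)

x = torch.randn(bs, seq_len, in_features)
delta = x.bmm(A).bmm(B) / r                 # low-rank update added to the base output
print(delta.shape)                          # torch.Size([2, 10, 32])
```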
+ +from contextlib import contextmanager +from dataclasses import asdict +from enum import Enum +from typing import Any + +import torch +from torch import nn + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, +) + +from .config import PolyConfig +from .layer import Linear, PolyLayer + + +class PolyModel(BaseTuner): + prefix: str = "poly_" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + @staticmethod + def _check_target_module_exists(poly_config, key): + return check_target_module_exists(poly_config, key) + + def _create_and_replace( + self, + poly_config: PolyConfig, + adapter_name: str, + target: nn.Module, + target_name: str, + parent: nn.Module, + **optional_kwargs: Any, + ): + if isinstance(target, PolyLayer): + target.update_layer(adapter_name, poly_config) + else: + new_module = self._create_new_module( + poly_config, + adapter_name, + target, + ) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if (self.prefix in name) or ("ranknum" in name): + weight = child.qweight if hasattr(child, "qweight") else child.weight + module.to(weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + @staticmethod + def _create_new_module(poly_config, adapter_name, target, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + return Linear(target, adapter_name, poly_config, **kwargs) + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`." 
+ ) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (PolyLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, PolyLayer): + module.set_adapter(adapter_name) + + def _prepare_adapter_config(self, peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _register_pre_hooks(self, task_ids): + """Helper method to register pre hooks.""" + if task_ids is None: + return [] + + def pre_hook(_, args, kwargs): + kwargs["task_ids"] = task_ids + return args, kwargs + + handles = [] + + for module in self.model.modules(): + if isinstance(module, Linear): + handle = module.register_forward_pre_hook(pre_hook, with_kwargs=True) + handles.append(handle) + + return handles + + @contextmanager + def _manage_pre_hooks(self, task_ids): + """Context manager to handle the lifecycle of pre hooks.""" + handles = self._register_pre_hooks(task_ids) + try: + yield + finally: + for handle in handles: + handle.remove() + + def forward(self, *args, task_ids=None, **kwargs): + with self._manage_pre_hooks(task_ids): + return self.model(*args, **kwargs) + + def generate(self, *args, task_ids=None, **kwargs): + with self._manage_pre_hooks(task_ids): + return self.model.generate(*args, **kwargs) diff --git a/peft/tuners/poly/router.py b/peft/tuners/poly/router.py new file mode 100644 index 0000000000000000000000000000000000000000..0249398a9fc36d53bc0b4f022a8410514688a9f1 --- /dev/null +++ b/peft/tuners/poly/router.py @@ -0,0 +1,83 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
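`PolyModel.forward` and `generate` rely on temporarily registered forward pre-hooks to thread `task_ids` down to every Poly `Linear`. The standalone sketch below shows the same mechanism on a plain module; the module name is hypothetical, and `with_kwargs=True` requires PyTorch 2.0 or newer.

```py
# How _register_pre_hooks injects `task_ids`: a kwargs-aware forward pre-hook
# adds the argument right before the wrapped module's forward runs.
import torch
from torch import nn


class TaskAwareLinear(nn.Module):          # hypothetical stand-in for poly.Linear
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x, task_ids=None):
        print("received task_ids:", task_ids)
        return self.fc(x)


module = TaskAwareLinear()
task_ids = torch.tensor([0, 1])


def pre_hook(_module, args, kwargs):
    kwargs["task_ids"] = task_ids          # same injection as PolyModel._register_pre_hooks
    return args, kwargs


handle = module.register_forward_pre_hook(pre_hook, with_kwargs=True)
module(torch.randn(2, 8))                  # forward now sees task_ids without the caller passing it
handle.remove()                            # PolyModel removes its handles when _manage_pre_hooks exits
```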
+ +from abc import ABC, abstractmethod + +import torch +from torch import nn +from torch.distributions.relaxed_bernoulli import RelaxedBernoulli + +from .config import PolyConfig + + +EPS = 1e-12 + + +def get_router(poly_config: PolyConfig) -> nn.Module: + if poly_config.poly_type == "poly": + return PolyRouter(poly_config) + else: + raise ValueError( + f"Unsupported poly_type: {poly_config.poly_type}. " + "Currently, only the following types are supported: " + "`poly`." + ) + + +class Router(nn.Module, ABC): + @abstractmethod + def reset(self): + ... + + @abstractmethod + def forward(self, task_ids: torch.Tensor, input_ids: torch.Tensor): + ... + + +class PolyRouter(Router): + # It's a simplified implementation of + # https://github.com/microsoft/mttl/blob/ce4ca51dbca73be656feb9b3e5233633e3c5dec7/mttl/models/poly.py#L138 + def __init__(self, poly_config: PolyConfig): + super().__init__() + + self.poly_type = poly_config.poly_type + self.n_tasks = poly_config.n_tasks + self.n_skills = poly_config.n_skills + self.n_splits = poly_config.n_splits + + self.module_logits = nn.Parameter(torch.empty((self.n_tasks, self.n_splits * self.n_skills))) + + def reset(self): + torch.nn.init.uniform_(self.module_logits, -1e-3, 1e-3) + + def forward(self, task_ids: torch.Tensor, input_ids: torch.Tensor): + if task_ids is None: + raise ValueError("task_ids should not be None.") + if task_ids.max().item() >= self.n_tasks: + raise ValueError(f"Only {self.n_tasks} tasks available. Found task id = {task_ids.max().item()}") + + # move task id to input's device + task_ids = task_ids.to(self.module_logits.device) + + module_logits = self.module_logits[task_ids] + module_logits = module_logits.view(-1, self.n_splits, self.n_skills) + + if self.training: + module_logits = RelaxedBernoulli(temperature=1.0, logits=module_logits).rsample() + else: + module_logits = torch.sigmoid(module_logits) + + module_weights = module_logits / (module_logits.sum(dim=-1, keepdim=True) + EPS) + + return module_weights diff --git a/peft/tuners/prefix_tuning/__init__.py b/peft/tuners/prefix_tuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28f4bedbb43bcf2b22146d60e0e1f2fe7b19d9eb --- /dev/null +++ b/peft/tuners/prefix_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
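The `PolyRouter` above turns per-task logits into normalized mixing weights over skills, one distribution per split. A small sketch in evaluation mode (sigmoid path instead of relaxed-Bernoulli sampling), assuming the vendored `peft` package and its `peft.tuners.poly` subpackage import cleanly:

```py
# Routing sketch: one (n_splits, n_skills) weight matrix per example, selected by task id.
import torch
from peft.tuners.poly import PolyConfig
from peft.tuners.poly.router import get_router

config = PolyConfig(n_tasks=3, n_skills=4, n_splits=2)
router = get_router(config)
router.reset()                     # initialize module_logits uniformly in [-1e-3, 1e-3]
router.eval()                      # sigmoid path; training mode samples from RelaxedBernoulli

task_ids = torch.tensor([0, 2])    # one task id per example in the batch
weights = router(task_ids=task_ids, input_ids=None)
print(weights.shape)               # torch.Size([2, 2, 4]) == (batch, n_splits, n_skills)
print(weights.sum(dim=-1))         # ~1.0 per split after normalization
```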
+ +from .config import PrefixTuningConfig +from .model import PrefixEncoder + + +__all__ = ["PrefixTuningConfig", "PrefixEncoder"] diff --git a/peft/tuners/prefix_tuning/__pycache__/__init__.cpython-310.pyc b/peft/tuners/prefix_tuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce5ca54fd323aa906e2ef6843a3c0099b5efc94 Binary files /dev/null and b/peft/tuners/prefix_tuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/prefix_tuning/__pycache__/config.cpython-310.pyc b/peft/tuners/prefix_tuning/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7065dab677330f2f55a47c7915c1b0960c824d90 Binary files /dev/null and b/peft/tuners/prefix_tuning/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/prefix_tuning/__pycache__/model.cpython-310.pyc b/peft/tuners/prefix_tuning/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76eb57d16b6ba10c7f816708adee9308832baea4 Binary files /dev/null and b/peft/tuners/prefix_tuning/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/prefix_tuning/config.py b/peft/tuners/prefix_tuning/config.py new file mode 100644 index 0000000000000000000000000000000000000000..39203ff7beb571f067331798051e085a49273211 --- /dev/null +++ b/peft/tuners/prefix_tuning/config.py @@ -0,0 +1,41 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from peft.config import PromptLearningConfig +from peft.utils import PeftType + + +@dataclass +class PrefixTuningConfig(PromptLearningConfig): + """ + This is the configuration class to store the configuration of a [`PrefixEncoder`]. + + Args: + encoder_hidden_size (`int`): The hidden size of the prompt encoder. + prefix_projection (`bool`): Whether to project the prefix embeddings. + """ + + encoder_hidden_size: int = field( + default=None, + metadata={"help": "The hidden size of the encoder"}, + ) + prefix_projection: bool = field( + default=False, + metadata={"help": "Whether to project the prefix tokens"}, + ) + + def __post_init__(self): + self.peft_type = PeftType.PREFIX_TUNING diff --git a/peft/tuners/prefix_tuning/model.py b/peft/tuners/prefix_tuning/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd51892a3cc074406791f6bc7d1b088d25148e3 --- /dev/null +++ b/peft/tuners/prefix_tuning/model.py @@ -0,0 +1,80 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py +# with some refactor +import torch + + +class PrefixEncoder(torch.nn.Module): + r""" + The `torch.nn` model to encode the prefix. + + Args: + config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. + + Example: + + ```py + >>> from peft import PrefixEncoder, PrefixTuningConfig + + >>> config = PrefixTuningConfig( + ... peft_type="PREFIX_TUNING", + ... task_type="SEQ_2_SEQ_LM", + ... num_virtual_tokens=20, + ... token_dim=768, + ... num_transformer_submodules=1, + ... num_attention_heads=12, + ... num_layers=12, + ... encoder_hidden_size=768, + ... ) + >>> prefix_encoder = PrefixEncoder(config) + ``` + + **Attributes**: + - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. + - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if + `prefix_projection` is `True`. + - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. + + Input shape: (`batch_size`, `num_virtual_tokens`) + + Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + token_dim = config.token_dim + num_layers = config.num_layers + encoder_hidden_size = config.encoder_hidden_size + num_virtual_tokens = config.num_virtual_tokens + if self.prefix_projection and not config.inference_mode: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) + self.transform = torch.nn.Sequential( + torch.nn.Linear(token_dim, encoder_hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), + ) + else: + self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.transform(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values diff --git a/peft/tuners/prompt_tuning/__init__.py b/peft/tuners/prompt_tuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71795b61d819573ff41770e6d49c750e6c51b0ae --- /dev/null +++ b/peft/tuners/prompt_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
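For the `PrefixEncoder` above, the output packs past key/value states for every layer, which is why the last dimension is `2 * num_layers * token_dim`. A minimal sketch with `prefix_projection=False` (the default), reusing the docstring's config values and assuming the vendored `peft` package is importable:

```py
# Shape check for PrefixEncoder without prefix projection: a single embedding
# table directly produces the flattened past key/value states.
import torch
from peft import PrefixEncoder, PrefixTuningConfig

config = PrefixTuningConfig(
    peft_type="PREFIX_TUNING",
    task_type="SEQ_2_SEQ_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    num_attention_heads=12,
    num_layers=12,
    encoder_hidden_size=768,
)
prefix_encoder = PrefixEncoder(config)

prefix = torch.arange(20).unsqueeze(0)   # (batch_size=1, num_virtual_tokens=20)
past_key_values = prefix_encoder(prefix)
print(past_key_values.shape)             # torch.Size([1, 20, 18432]) == (1, 20, 2 * 12 * 768)
```

With `prefix_projection=True`, a `token_dim`-sized embedding is instead passed through the two-layer MLP (`transform`) to produce the same output shape.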
+ +from .config import PromptTuningConfig, PromptTuningInit +from .model import PromptEmbedding + + +__all__ = ["PromptTuningConfig", "PromptEmbedding", "PromptTuningInit"] diff --git a/peft/tuners/prompt_tuning/__pycache__/__init__.cpython-310.pyc b/peft/tuners/prompt_tuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd772ea0a8b961a34d6b2f7af1d187b537a5739 Binary files /dev/null and b/peft/tuners/prompt_tuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/prompt_tuning/__pycache__/config.cpython-310.pyc b/peft/tuners/prompt_tuning/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc31497e99ce70f7d18b2e18ec17dd62d009e95e Binary files /dev/null and b/peft/tuners/prompt_tuning/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/prompt_tuning/__pycache__/model.cpython-310.pyc b/peft/tuners/prompt_tuning/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0096fe7a5c97927a1b124419bdf920fca8aa9840 Binary files /dev/null and b/peft/tuners/prompt_tuning/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/prompt_tuning/config.py b/peft/tuners/prompt_tuning/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d9987e112abc389f623a6ae4f7df90bd70c8439a --- /dev/null +++ b/peft/tuners/prompt_tuning/config.py @@ -0,0 +1,86 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PromptLearningConfig +from peft.utils import PeftType + + +class PromptTuningInit(str, enum.Enum): + TEXT = "TEXT" + RANDOM = "RANDOM" + + +@dataclass +class PromptTuningConfig(PromptLearningConfig): + """ + This is the configuration class to store the configuration of a [`PromptEmbedding`]. + + Args: + prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. + prompt_tuning_init_text (`str`, *optional*): + The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. + tokenizer_name_or_path (`str`, *optional*): + The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. + tokenizer_kwargs (`dict`, *optional*): + The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if `prompt_tuning_init` is + `TEXT`. + """ + + prompt_tuning_init: Union[PromptTuningInit, str] = field( + default=PromptTuningInit.RANDOM, + metadata={"help": "How to initialize the prompt tuning parameters"}, + ) + prompt_tuning_init_text: Optional[str] = field( + default=None, + metadata={ + "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer to use for prompt tuning initialization. 
Only used if prompt_tuning_init is `TEXT`" + }, + ) + + tokenizer_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if prompt_tuning_init is " + "`TEXT`" + ), + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.PROMPT_TUNING + if (self.prompt_tuning_init == PromptTuningInit.TEXT) and not self.tokenizer_name_or_path: + raise ValueError( + f"When prompt_tuning_init='{PromptTuningInit.TEXT.value}', " + f"tokenizer_name_or_path can't be {self.tokenizer_name_or_path}." + ) + if (self.prompt_tuning_init == PromptTuningInit.TEXT) and self.prompt_tuning_init_text is None: + raise ValueError( + f"When prompt_tuning_init='{PromptTuningInit.TEXT.value}', " + f"prompt_tuning_init_text can't be {self.prompt_tuning_init_text}." + ) + if self.tokenizer_kwargs and (self.prompt_tuning_init != PromptTuningInit.TEXT): + raise ValueError( + f"tokenizer_kwargs only valid when using prompt_tuning_init='{PromptTuningInit.TEXT.value}'." + ) diff --git a/peft/tuners/prompt_tuning/model.py b/peft/tuners/prompt_tuning/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ce9b6bc409e0b26f846ea159464bdb953f8b0c4e --- /dev/null +++ b/peft/tuners/prompt_tuning/model.py @@ -0,0 +1,91 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch + +from peft.utils.integrations import gather_params_ctx + +from .config import PromptTuningInit + + +class PromptEmbedding(torch.nn.Module): + """ + The model to encode virtual tokens into prompt embeddings. + + Args: + config ([`PromptTuningConfig`]): The configuration of the prompt embedding. + word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. + + **Attributes**: + - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. + + Example: + + ```py + >>> from peft import PromptEmbedding, PromptTuningConfig + + >>> config = PromptTuningConfig( + ... peft_type="PROMPT_TUNING", + ... task_type="SEQ_2_SEQ_LM", + ... num_virtual_tokens=20, + ... token_dim=768, + ... num_transformer_submodules=1, + ... num_attention_heads=12, + ... num_layers=12, + ... prompt_tuning_init="TEXT", + ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", + ... tokenizer_name_or_path="t5-base", + ... 
) + + >>> # t5_model.shared is the word embeddings of the base model + >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) + ``` + + Input Shape: (`batch_size`, `total_virtual_tokens`) + + Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) + """ + + def __init__(self, config, word_embeddings): + super().__init__() + + total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules + self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) + if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode: + from transformers import AutoTokenizer + + tokenizer_kwargs = config.tokenizer_kwargs or {} + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs) + init_text = config.prompt_tuning_init_text + init_token_ids = tokenizer(init_text)["input_ids"] + # Trim or iterate until num_text_tokens matches total_virtual_tokens + num_text_tokens = len(init_token_ids) + if num_text_tokens > total_virtual_tokens: + init_token_ids = init_token_ids[:total_virtual_tokens] + elif num_text_tokens < total_virtual_tokens: + num_reps = math.ceil(total_virtual_tokens / num_text_tokens) + init_token_ids = init_token_ids * num_reps + init_token_ids = init_token_ids[:total_virtual_tokens] + init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device) + with gather_params_ctx(word_embeddings.parameters()): + word_embedding_weights = word_embeddings(init_token_ids).detach().clone() + word_embedding_weights = word_embedding_weights.to(torch.float32) + self.embedding.weight = torch.nn.Parameter(word_embedding_weights) + + def forward(self, indices): + # Just get embeddings + prompt_embeddings = self.embedding(indices) + return prompt_embeddings diff --git a/peft/tuners/tuners_utils.py b/peft/tuners/tuners_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a6ba53dcfc0132e886dfe657a027d82bb302b4 --- /dev/null +++ b/peft/tuners/tuners_utils.py @@ -0,0 +1,830 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +import logging +import os +import re +import warnings +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Optional, Union + +import torch +from accelerate.hooks import AlignDevicesHook +from accelerate.utils import named_module_tensors, offload_state_dict +from torch import nn +from transformers import PreTrainedModel +from transformers.pytorch_utils import Conv1D + +from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND + +from ..config import PeftConfig +from ..utils import ModulesToSaveWrapper, _get_submodules + + +logger = logging.getLogger(__name__) + + +@contextmanager +def onload_layer(layer): + r""" + A utility for modifying a module containing one or more tuners and a base layer, any of which are offloaded to the + CPU or disk. 
Moves a module's sub-modules to the execution device before some action is performed, after that the + base layer state dictionary is re-assigned (if that layer was offloaded to the disk) and finally the parameters are + offloaded. + + If the module has no offloaded sub-modules, this function does nothing. + + Args: + layer ('torch.nn.Module'): + layer with tuners to be merged + """ + + offloaded_modules = [] + for name, module in layer.named_modules(): + if name in ["", "base_layer"]: + continue + if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload: + module._hf_hook.pre_forward(module) + offloaded_modules.append(module) + + base_layer_offload = False + if hasattr(layer, "base_layer") and ( + hasattr(layer.base_layer, "_hf_hook") + and isinstance(layer.base_layer._hf_hook, AlignDevicesHook) + and layer.base_layer._hf_hook.offload + ): + # check if the base layer is disk-offloaded (must contain a 'dataset' and an offload index) + if torch.device("meta") in layer.base_layer._hf_hook.original_devices.values() and hasattr( + layer.base_layer._hf_hook.weights_map, "dataset" + ): + # find the disk-offload index (maps modules to safetensors) from the `dataset` (OffloadedWeightsLoader object) + index = layer.base_layer._hf_hook.weights_map.dataset.index + module_name = list(dict(layer.base_layer._hf_hook.weights_map.dataset).keys())[0] # any module will do + file_name = index[module_name]["safetensors_file"] + base_name_arr = [] + # get effective dir name + for i in os.path.split(file_name): + if "--" in i: + base_name_arr.append(i) + break + base_name_arr.append(i) + base_name = os.path.join(*base_name_arr) + safetensors_filename = base_name + "-merged" + layer.base_layer._hf_hook.pre_forward(layer.base_layer) + base_layer_offload = True + + yield + + for module in offloaded_modules: + module._hf_hook.post_forward(module, torch.tensor([])) + + if base_layer_offload: + # re-make weights map (must be on cpu to send params to the disk via memmap if disk offload) + layer.base_layer._hf_hook.weights_map = { + name: param.to("cpu") for name, param in named_module_tensors(layer.base_layer) + } + # offload weights map to disk if original device is the disk + if torch.device("meta") in layer.base_layer._hf_hook.original_devices.values() and hasattr( + layer.base_layer._hf_hook.weights_map, "dataset" + ): + # rewrite directory with merged weights + offload_state_dict(safetensors_filename, layer.base_layer._hf_hook.weights_map) + layer.base_layer._hf_hook.post_forward(layer.base_layer, torch.tensor([])) + + +class BaseTuner(nn.Module, ABC): + r""" + A base tuner model that provides the common methods and attributes for all tuners that are injectable into a + torch.nn.Module + + For adding a new Tuner class, one needs to overwrite the following methods: + + - **_prepare_adapter_config**: + A private method to eventually prepare the adapter config, for example in case the field `target_modules` is + missing. + - **_create_and_replace**: + A private method to create and replace the target module with the adapter module. + - **_check_target_module_exists**: + A private helper method to check if the passed module's key name matches any of the target modules in the + adapter_config. + + The easiest is to check what is done in the `peft.tuners.lora.LoraModel` class. + + Attributes: + model (`torch.nn.Module`): + The model to which the adapter tuner layers will be attached. + forward (`Callable`): + The forward method of the model. 
+ peft_config (`Union[`PeftConfig`, dict[str, PeftConfig]]`): + The adapter configuration object, it should be a dictionary of `str` to `PeftConfig` objects. One can also + pass a PeftConfig object and a new adapter will be created with the default name `adapter` or create a new + dictionary with a key `adapter_name` and a value of that peft config. + config (`dict[str, Any]`): + The model configuration object, it should be a dictionary of `str` to `Any` objects. + targeted_module_names (`list[str]`): + The list of module names that were actually adapted. Can be useful to inspect if you want to quickly + double-check that the `config.target_modules` where specified correctly. + """ + + def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], adapter_name: str) -> None: + super().__init__() + + self.model = model + self.targeted_module_names: list[str] = [] + + # For advanced developers, if you want to attach multiple adapters to your + # model, just add a `peft_config` dict attribute to your model. + if not hasattr(self, "peft_config"): + self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config + else: + logger.info( + "Already found a `peft_config` attribute in the model. This will lead to having multiple adapters" + " in the model. Make sure to know what you are doing!" + ) + if isinstance(peft_config, PeftConfig): + self.peft_config[adapter_name] = peft_config + else: + # user is adding a dict of PeftConfigs + self.peft_config.update(peft_config) + + self.active_adapter: str | list[str] = adapter_name + self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name) + self.inject_adapter(self.model, adapter_name) + + # Copy the peft_config in the injected model. + self.model.peft_config = self.peft_config + + @property + def active_adapters(self) -> list[str]: + if isinstance(self.active_adapter, str): + return [self.active_adapter] + # is already a list of str + return self.active_adapter + + def forward(self, *args: Any, **kwargs: Any): + return self.model.forward(*args, **kwargs) + + def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None: + r""" + A hook to be called before the adapter is injected into the model. This method can be overridden by child + classes to perform any pre-injection operations. + + Args: + model (`nn.Module`): + The model to be adapted. + config (`PeftConfig`): + The adapter config. + adapter_name (`str`): + The adapter name. + """ + pass + + @abstractmethod + def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> PeftConfig: + r""" + A private method to eventually prepare the adapter config. For transformers based models, if + `peft_config.target_modules` is None, we can automatically infer the target modules from the + `TRANSFORMERS_MODELS_TO_XXX_TARGET_MODULES_MAPPING`. This method can be further refactored in the future to + automatically infer it for all tuner models. + + Check out `peft.tuner.lora.LoraModel._prepare_adapter_config` for an example. + + Args: + peft_config (`PeftConfig`): + The adapter config. + model_config (`dict`): + The transformers model config, that config should contain the `model_type` key. + """ + ... + + def _prepare_model(self, peft_config: PeftConfig, model: nn.Module): + r""" + A private method to modify the model structure before adapter is applied. + + See `peft.tuner.lora.LoraModel._prepare_model` for an example. 
+ + Args: + peft_config (`PeftConfig`): + The prepared adapter config. + model (`nn.Module`): + The model that is going to be adapted. + """ + pass + + @abstractmethod + def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool: + r""" + A helper private method to check if the passed module's key name matches any of the target modules in the + `peft_config.target_modules` list. If it does, return `True`, else return `False`. + + Args: + peft_config (`PeftConfig`): + The adapter config. + key (`str`): + The module's key name. + """ + ... + + @abstractmethod + def _create_and_replace( + self, + peft_config: PeftConfig, + adapter_name: str, + target: nn.Module, + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + r""" + Inplace replacement of the target module with the adapter layer. This method needs to be overridden by all the + tuner classes. + + Check `peft.tuners.lora.LoraModel._create_and_replace` for an example. + + Args: + peft_config (`PeftConfig`): + The adapter config. + adapter_name (`str`): + The adapter name. + target (`nn.Module`): + The target module. + target_name (`str`): + The target module's name. + parent (`nn.Module`): + The parent module. + current_key (`str`): + The key of the current target being adapted. + """ + ... + + @abstractmethod + def _mark_only_adapters_as_trainable(self, model: nn.Module): + r""" + A helper method to mark only the adapter layers as trainable (i.e. module.requires_grad = False) This needs to + be overridden for all tuner classes to match the correct key names. + + Check `peft.tuners.lora.LoraModel._mark_only_adapters_as_trainable` for an example. + """ + ... + + @abstractmethod + def disable_adapter_layers(self) -> None: + """ + Disable all adapters in-place. + """ + ... + + @abstractmethod + def enable_adapter_layers(self) -> None: + """ + Enable all adapters in-place + """ + ... + + def _check_new_adapter_config(self, config: PeftConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + pass + + def _check_merge_allowed(self): + """Helper method to check whether the adapter can be merged. + + Raise a ValueError if it is not possible to merge the adapter with the given configuration. + """ + pass + + def inject_adapter(self, model: nn.Module, adapter_name: str): + r""" + Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the + hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed. + + The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class. + + Args: + model (`nn.Module`): + The model to be tuned. + adapter_name (`str`): + The adapter name. + """ + peft_config = self.peft_config[adapter_name] + # Note: If possible, all checks should be performed *at the start of this method*. + # This way, we can raise early if something goes wrong, without leaving the model + # in a bad (half-initialized) state. 
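As the class docstring above explains, a new tuner mainly has to provide the target matching, the module replacement, and the trainability/enable hooks. The skeleton below sketches that surface with elided bodies; the class name and `prefix` are hypothetical, and `peft.tuners.lora.LoraModel` remains the reference implementation.

```python
class MyTuner(BaseTuner):
    prefix: str = "mytuner_"  # hypothetical prefix used to recognise this adapter's parameters

    def _prepare_adapter_config(self, peft_config, model_config):
        ...  # e.g. fill peft_config.target_modules from a model_type -> module-names mapping

    @staticmethod
    def _check_target_module_exists(peft_config, key):
        return check_target_module_exists(peft_config, key)  # reuse the shared matcher defined later in this module

    def _create_and_replace(self, peft_config, adapter_name, target, target_name, parent, current_key):
        ...  # build an adapter layer wrapping `target`, then setattr(parent, target_name, new_module)

    def _mark_only_adapters_as_trainable(self, model):
        for name, param in model.named_parameters():
            if self.prefix not in name:
                param.requires_grad = False

    def enable_adapter_layers(self):
        ...  # typically calls enable_adapters(True) on every BaseTunerLayer

    def disable_adapter_layers(self):
        ...  # typically calls enable_adapters(False) on every BaseTunerLayer
```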
+ self._check_new_adapter_config(peft_config) + + _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None + _has_modules_to_save = False + + model_config = getattr(model, "config", {"model_type": "custom"}) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + + peft_config = self._prepare_adapter_config(peft_config, model_config) + + self._prepare_model(peft_config, model) + is_target_modules_in_base_model = False + key_list = [key for key, _ in model.named_modules()] + + # update peft_config.target_modules if required + peft_config = _maybe_include_all_linear_layers(peft_config, model) + + for key in key_list: + # Check for modules_to_save in case + if _check_for_modules_to_save and any( + key.endswith(f"{module_to_save}") for module_to_save in peft_config.modules_to_save + ): + # Optionally set the modules to save + parent, target, target_name = _get_submodules(model, key) + + if not isinstance(target, ModulesToSaveWrapper): + new_module = ModulesToSaveWrapper(target, adapter_name) + setattr(parent, target_name, new_module) + else: + target.update(adapter_name) + + _has_modules_to_save = True + continue + + if not self._check_target_module_exists(peft_config, key): + continue + + self.targeted_module_names.append(key) + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(model, key) + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) + + if not is_target_modules_in_base_model: + raise ValueError( + f"Target modules {peft_config.target_modules} not found in the base model. " + f"Please check the target modules and try again." + ) + + # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is + # added, and it targets different layer(s) than the first adapter (which is active), then those different + # layers will be activated, which we don't want. + self.set_adapter(self.active_adapters) + self._mark_only_adapters_as_trainable(model) + + if self.peft_config[adapter_name].inference_mode: + for n, p in model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + if _has_modules_to_save: + if not hasattr(model, "modules_to_save"): + model.modules_to_save = set(peft_config.modules_to_save) + else: + model.modules_to_save.update(set(peft_config.modules_to_save)) + + def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None: + """ + This method merges the adapter layers into the base model. + + Merging adapters can lead to a speed up of the forward pass. A copy of the adapter weights is still kept in + memory, which is required to unmerge the adapters. In order to merge the adapter weights without keeping them + in memory, please call `merge_and_unload`. + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. 
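Once `inject_adapter` has run, `targeted_module_names` records exactly which modules were wrapped, which is a convenient way to double-check a `target_modules` setting. A short sketch; the checkpoint and LoRA config are only an example.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
peft_model = get_peft_model(base, LoraConfig(target_modules=["q_proj", "v_proj"]))

# The BaseTuner instance lives at peft_model.base_model; every matched key was recorded there.
print(len(peft_model.base_model.targeted_module_names))   # 24 for opt-125m (12 layers x 2 projections)
print(peft_model.base_model.targeted_module_names[:2])
```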
+ """ + self._check_merge_allowed() + for module in self.model.modules(): + if isinstance(module, BaseTunerLayer): + with onload_layer(module): + module.merge(adapter_names=adapter_names) + + def unmerge_adapter(self): + """ + This method unmerges all merged adapter layers from the base model. + """ + for module in self.model.modules(): + if isinstance(module, BaseTunerLayer): + with onload_layer(module): + module.unmerge() + + def _unloading_checks(self, adapter_names: Optional[list[str]]): + adapters_to_consider = adapter_names or self.active_adapters + is_modules_to_save_available = any( + self.peft_config[adapter].modules_to_save for adapter in adapters_to_consider + ) + if is_modules_to_save_available and len(adapters_to_consider) > 1: + raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.") + + +class BaseTunerLayer(ABC): + r""" + A tuner layer mixin that provides the common methods and attributes for all tuners. + + Args: + is_pluggable (`bool`, *optional*): + Whether the adapter layer can be plugged to any pytorch module + active_adapters (Union[List[`str`], `str`], *optional*): + The name of the active adapter. + """ + + active_adapter = None + + # All names of layers that may contain adapter (trainable) weights + adapter_layer_names: tuple[str, ...] = () + # All names of other parameters that may contain adapter-related parameters + other_param_names: tuple[str, ...] = () + + # indicates whether all adapters should be disabled + _disable_adapters: bool = False + + # the currently active adapter(s) + _active_adapter: str | list[str] = "default" + + # List all merged adapters + merged_adapters: list[str] = [] + + def get_base_layer(self) -> nn.Module: + """ + (Recursively) get the base_layer. + + This is necessary for the case that the tuner layer wraps another tuner layer. + + """ + base_layer = self + while hasattr(base_layer, "base_layer"): + base_layer = base_layer.base_layer + return base_layer + + @property + def weight(self) -> torch.Tensor: + # This is required for some transformers code, e.g. for T5, weight is accessed as: + # self.wo.weight + # where "wo" is the adapter layer. 
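The two methods are symmetric: merging folds the adapter delta into the base weights for faster inference (with `onload_layer` temporarily bringing offloaded sub-modules onto the execution device), and unmerging restores the originals. Continuing the sketch above, with toy token ids that are purely illustrative:

```python
import torch

inputs = {"input_ids": torch.tensor([[2, 100, 200, 300]])}  # toy ids, not a real prompt

peft_model.merge_adapter()                 # adapter weights folded into the base Linear weights
merged_out = peft_model(**inputs)

peft_model.unmerge_adapter()               # base weights restored, adapter applied on the fly again
unmerged_out = peft_model(**inputs)

torch.testing.assert_close(merged_out.logits, unmerged_out.logits, atol=1e-4, rtol=1e-4)
```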
+ # https://github.com/huggingface/transformers/blob/78f6ed6c70b29c1560780e3869a7ad4c6b3d2710/src/transformers + # /models/t5/modeling_t5.py#L292 + base_layer = self.get_base_layer() + if hasattr(base_layer, "qweight"): + # QuantLinear + weight = base_layer.qweight + else: + # Other layers + weight = base_layer.weight + return weight + + @property + def bias(self) -> torch.Tensor: + base_layer = self.get_base_layer() + return base_layer.bias + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + raise NotImplementedError + + def unmerge(self) -> None: + raise NotImplementedError + + @property + def merged(self) -> bool: + return bool(self.merged_adapters) + + @property + def disable_adapters(self) -> bool: + # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method + return self._disable_adapters + + @property + def active_adapter(self) -> str | list[str]: + # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method + return self._active_adapter + + def _get_available_adapters(self) -> set[str]: + """Return all adapter names that can be found on this module.""" + adapters = set() + for layer_name in self.adapter_layer_names: + module = getattr(self, layer_name) + if not isinstance(module, (nn.ModuleDict, nn.ParameterDict)): + continue + adapters.update(set(module.keys())) + return adapters + + @property + def active_adapters(self): + if isinstance(self.active_adapter, str): + return [self.active_adapter] + # is already a list of str + return self.active_adapter + + def enable_adapters(self, enabled: bool) -> None: + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. + + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if enabled: + self.set_adapter(self.active_adapters) + self._disable_adapters = False + else: + # disable grads on all adapter layers + for layer_name in self.adapter_layer_names: + layer = getattr(self, layer_name) + layer.requires_grad_(False) + self._disable_adapters = True + + def set_adapter(self, adapter_names: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `List[str]`): Name of the adapter(s) to be activated. + """ + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + # Deactivate grads on the inactive adapter and activate grads on the active adapter + for layer_name in self.adapter_layer_names: + module_dict = getattr(self, layer_name) + for key, layer in module_dict.items(): + if key in adapter_names: + # Note: It is possible that not a single layer is called with requires_grad_(True) here. This may + # happen if a completely different adapter layer is being activated. 
+ layer.requires_grad_(True) + else: + layer.requires_grad_(False) + + self._active_adapter = adapter_names + + def _all_available_adapter_names(self) -> list[str]: + """Return a sorted list of all available adapter names""" + adapter_names = set() + for name in self.adapter_layer_names + self.other_param_names: + # we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter + # names + attr = getattr(self, name) + if hasattr(attr, "keys"): + adapter_names.update(attr.keys()) + return sorted(adapter_names) + + def delete_adapter(self, adapter_name: str) -> None: + """ + Delete an adapter from the layer + + This should be called on all adapter layers, or else we will get an inconsistent state. + + This method will also set a new active adapter if the deleted adapter was an active adapter. It is important + that the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers. + + Args: + adapter_name (`str`): The name of the adapter to delete + + """ + for attr in self.adapter_layer_names + self.other_param_names: + if adapter_name in getattr(self, attr): + del getattr(self, attr)[adapter_name] + + if adapter_name in self.active_adapters: + # choose a new active adapter + active_adapters = self.active_adapters[:] + active_adapters.remove(adapter_name) + if active_adapters: + self.set_adapter(active_adapters) + else: + # no active adapters left, set a new default adapter + # here we get the list of all adapters existing adapter names and choose the first one + remaining_adapters = self._all_available_adapter_names() + if not remaining_adapters: + self.set_adapter([]) + else: + new_active_adapter = remaining_adapters[0] + warnings.warn( + f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to " + f"{new_active_adapter}." + ) + self.set_adapter(remaining_adapters[0]) + + +def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: + """A helper method to check if the passed module's key name matches any of the target modules in the adapter_config. 
+ + Args: + config (`LoraConfig` | `LycorisConfig`): A config to match target modules from + key (`str`): A key to search any matches in config + + Returns: + `bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or + None if no match found + """ + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, key) + elif key in config.target_modules: + # this module is specified directly in target_modules + target_module_found = True + else: + target_module_found = any(key.endswith(f".{target_key}") for target_key in config.target_modules) + + layer_indexes = getattr(config, "layers_to_transform", None) + layers_pattern = getattr(config, "layers_pattern", None) + + is_using_layer_indexes = layer_indexes is not None and ( + len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True + ) + if is_using_layer_indexes and target_module_found: + layer_index = None + # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave + # For now, empty layers_pattern means any layer pattern is ok + if layers_pattern is None or len(layers_pattern) == 0: + layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + else: + layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + for pattern in layers_pattern: + layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) + if layer_index is not None: + break + + if layer_index is None: + target_module_found = False + else: + layer_index = int(layer_index.group(1)) + if isinstance(layer_indexes, int): + target_module_found = layer_index == layer_indexes + else: + target_module_found = layer_index in layer_indexes + + return target_module_found + + +def inspect_matched_modules(tuner: BaseTuner, adapter_name: str = "default") -> dict: + """ + A helper function to inspect the set of matched and unmatched modules for a PEFT model and the given adapter. + """ + config = tuner.peft_config[adapter_name] + key_list = [key for key, _ in tuner.model.named_modules()] + module_dict = {"matched": [], "unmatched": []} + for key in key_list: + if tuner._check_target_module_exists(config, key): + module_dict["matched"].append(key) + else: + module_dict["unmatched"].append(key) + return module_dict + + +def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module) -> PeftConfig: + """ + Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from + the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py + """ + + # if `target_modules` is a string, convert to lower case and check if it matches "all-linear" + if not ( + isinstance(peft_config.target_modules, str) + and peft_config.target_modules.lower() == INCLUDE_LINEAR_LAYERS_SHORTHAND + ): + return peft_config + + if not isinstance(model, PreTrainedModel): + raise ValueError( + f"Only instances of PreTrainedModel support `target_modules={INCLUDE_LINEAR_LAYERS_SHORTHAND!r}`" + ) + + linear_classes = (torch.nn.Linear, Conv1D) + + linear_module_names = set() + for name, module in model.named_modules(): + # match with all linear classes. 
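`check_target_module_exists` accepts either a single regex (checked with `re.fullmatch`) or a list of names (checked as `.`-separated suffixes), optionally narrowed by `layers_to_transform`. A small illustration with stand-in config objects:

```python
from types import SimpleNamespace

list_cfg = SimpleNamespace(target_modules=["q_proj", "v_proj"])
regex_cfg = SimpleNamespace(target_modules=r".*decoder.*\.(q_proj|v_proj)")

key = "model.decoder.layers.0.self_attn.q_proj"
print(check_target_module_exists(list_cfg, key))                            # True, suffix match on ".q_proj"
print(check_target_module_exists(regex_cfg, key))                           # re.Match object, i.e. truthy
print(check_target_module_exists(list_cfg, "model.decoder.layers.0.fc1"))   # False
```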
+ if isinstance(module, linear_classes): + names = name.rsplit(".", 1)[-1] # get the base name + linear_module_names.add(names) + + # ignore the last classification head for text generation models + output_emb = model.get_output_embeddings() + if output_emb is not None: + last_module_name = [name for name, module in model.named_modules() if module is output_emb][0] + linear_module_names -= {last_module_name} + peft_config.target_modules = linear_module_names + return peft_config + + +def check_adapters_to_merge(module: BaseTunerLayer, adapter_names: Optional[list[str]] = None) -> list[str]: + """ + Helper function to check which adapters should be merged. + + Only return those adapters that are not already merged. Give a warning if some or all of the adapters are already + merged. + + """ + if adapter_names is None: + adapter_names = module.active_adapters + if isinstance(adapter_names, str): + raise ValueError(f"adapter_names should be a list of strings, got {adapter_names!r}.") + + if module.merged: + merged_adapters = set(module.merged_adapters) + adapter_names = [name for name in adapter_names if name not in merged_adapters] + + if adapter_names: + warnings.warn( + f"Already following adapters were merged {','.join(module.merged_adapters)}. " + f"You are now additionally merging {','.join(adapter_names)}." + ) + else: + warnings.warn("All adapters are already merged, nothing to do.") + + return adapter_names + + +def clone_module(module: nn.Module, share_weights=False): + """Clone a module in a pytorch model. + + Clones a module of a model, optionally sharing all the parameters between the original and the clone. Simplifies + reusing a module when manipulating the architecture of a model. + """ + clone = copy.deepcopy(module) + + def _share_weights(src: nn.Module, dst: nn.Module): + for name, param in src.named_parameters(recurse=False): + dst.register_parameter(name, param) + + if share_weights: + for name, submodule in module.named_modules(): + _share_weights(submodule, clone.get_submodule(name)) + + return clone + + +def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]): + """Replicate layers in a transfomer model with weight sharing. + + This function looks for a module list attribute at model[(.model)*].layers and replicates the layers in the module + list according to the layer map. For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3, + 4]` and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`. + """ + while hasattr(model, "model"): + model = model.model + # Some variants of the bert model nest the main model under the bert attribute. + if hasattr(model, "bert"): + model = model.bert + + model_type = None + layers: nn.ModuleList = None + if hasattr(model, "layers"): + model_type = "llama" + layers = model.layers + elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"): + model_type = "bert" + layers = model.encoder.layer + elif hasattr(model, "h"): + model_type = "falcon" + layers = model.h + if not model_type or not isinstance(layers, nn.ModuleList): + raise ValueError( + "Could not locate the layers attribute in the model. " + "Expected Llama, Bert or Falcon compatible architectures." + ) + + new_layers = [] + for start, end in layer_map: + for i in range(start, end): + current_idx = len(new_layers) + new_layers.append(clone_module(layers[i], share_weights=True)) + # This is a hack needed to work around the layer_idx introduced in HF transformers. 
+ for submodule in new_layers[-1].modules(): + if hasattr(submodule, "layer_idx"): + submodule.layer_idx = current_idx + layers = nn.ModuleList(new_layers) + if model_type == "llama": + model.layers = layers + elif model_type == "bert": + model.encoder.layer = layers + elif model_type == "falcon": + model.h = layers + else: + raise ValueError("Unexpected model type, need to handle post-processing of layers.") + if hasattr(model.config, "num_hidden_layers"): # Common to Llama, Bert, Falcon. + model.config.num_hidden_layers = len(new_layers) diff --git a/peft/tuners/vera/__init__.py b/peft/tuners/vera/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf358818342066788f129c78cfaa98317c5e68cd --- /dev/null +++ b/peft/tuners/vera/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import VeraConfig +from .layer import Linear, VeraLayer +from .model import VeraModel + + +__all__ = ["VeraConfig", "VeraLayer", "Linear", "VeraModel"] diff --git a/peft/tuners/vera/__pycache__/__init__.cpython-310.pyc b/peft/tuners/vera/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98bb1fc671c6aee8f8399772dc0a778d63c8d03 Binary files /dev/null and b/peft/tuners/vera/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/tuners/vera/__pycache__/buffer_dict.cpython-310.pyc b/peft/tuners/vera/__pycache__/buffer_dict.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ed9f2b7aad142f0f80a3dcc7cd0500ed83f4c8 Binary files /dev/null and b/peft/tuners/vera/__pycache__/buffer_dict.cpython-310.pyc differ diff --git a/peft/tuners/vera/__pycache__/config.cpython-310.pyc b/peft/tuners/vera/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46a6cb46cab34fa5ba742595ce48f4921348b2bb Binary files /dev/null and b/peft/tuners/vera/__pycache__/config.cpython-310.pyc differ diff --git a/peft/tuners/vera/__pycache__/layer.cpython-310.pyc b/peft/tuners/vera/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e80108e0d094dc75bdf37ba20b272c931d6272f Binary files /dev/null and b/peft/tuners/vera/__pycache__/layer.cpython-310.pyc differ diff --git a/peft/tuners/vera/__pycache__/model.cpython-310.pyc b/peft/tuners/vera/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb93e9e45be0353a8291755bfe4e5ee395f26eac Binary files /dev/null and b/peft/tuners/vera/__pycache__/model.cpython-310.pyc differ diff --git a/peft/tuners/vera/buffer_dict.py b/peft/tuners/vera/buffer_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..80b2307d1d9a8c65bfc0fee1a0bc7560cf4166db --- /dev/null +++ b/peft/tuners/vera/buffer_dict.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
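Both helpers above revolve around weight sharing: `clone_module(..., share_weights=True)` re-registers the original parameters on the copy, and `replicate_layers` uses that to deepen a decoder stack without adding new weights. A sketch; the tiny checkpoint name is a placeholder for any small Llama-style model.

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

lin = nn.Linear(4, 4)
shared = clone_module(lin, share_weights=True)
assert shared.weight is lin.weight             # the same Parameter object, not a copy

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
n = model.config.num_hidden_layers
replicate_layers(model, [(0, n), (n - 1, n)])  # append one extra, weight-tied copy of the last layer
print(model.config.num_hidden_layers)          # n + 1
```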
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html + +# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands + +from __future__ import annotations + +import collections +from collections import OrderedDict + +import torch +from torch.nn import Module + + +class BufferDict(Module): + r""" + Holds buffers in a dictionary. + + BufferDict can be indexed like a regular Python dictionary, but buffers it contains are properly registered, and + will be visible by all Module methods. `torch.nn.BufferDict` is an **ordered** dictionary that respects + + * the order of insertion, and + * in `torch.nn.BufferDict.update`, the order of the merged `OrderedDict` or another `torch.nn.BufferDict` (the + argument to `torch.nn.BufferDict.update`). + + Note that `torch.nn.BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not + preserve the order of the merged mapping. + + Args: + buffers (iterable, optional): + a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string, + `torch.Tensor`) + + ```python + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.buffers = nn.BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)}) + + def forward(self, x, choice): + x = self.buffers[choice].mm(x) + return x + ``` + """ + + def __init__(self, buffers=None, persistent: bool = False): + r""" + Args: + buffers (`dict`): + A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type + (string, `torch.Tensor`). + """ + super().__init__() + if buffers is not None: + self.update(buffers) + + self.persistent = persistent + + def __getitem__(self, key): + return self._buffers[key] + + def __setitem__(self, key, buffer): + self.register_buffer(key, buffer, persistent=self.persistent) + + def __delitem__(self, key): + del self._buffers[key] + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.keys()) + + def __contains__(self, key): + return key in self._buffers + + def clear(self): + """Remove all items from the BufferDict.""" + self._buffers.clear() + + def pop(self, key): + r"""Remove key from the BufferDict and return its buffer. + + Args: + key (`str`): + Key to pop from the BufferDict + """ + v = self[key] + del self[key] + return v + + def keys(self): + r"""Return an iterable of the BufferDict keys.""" + return self._buffers.keys() + + def items(self): + r"""Return an iterable of the BufferDict key/value pairs.""" + return self._buffers.items() + + def values(self): + r"""Return an iterable of the BufferDict values.""" + return self._buffers.values() + + def update(self, buffers): + r""" + Update the `torch.nn.BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing + keys. + + Note: + If `buffers` is an `OrderedDict`, a `torch.nn.BufferDict`, or an iterable of key-value pairs, the order of + new elements in it is preserved. + + Args: + buffers (iterable): + a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type + (string, `torch.Tensor`). 
+ """ + if not isinstance(buffers, collections.abc.Iterable): + raise TypeError( + "BuffersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(buffers).__name__ + ) + + if isinstance(buffers, collections.abc.Mapping): + if isinstance(buffers, (OrderedDict, BufferDict)): + for key, buffer in buffers.items(): + self[key] = buffer + else: + for key, buffer in sorted(buffers.items()): + self[key] = buffer + else: + for j, p in enumerate(buffers): + if not isinstance(p, collections.abc.Iterable): + raise TypeError( + "BufferDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + if not len(p) == 2: + raise ValueError( + "BufferDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + self[p[0]] = p[1] + + def extra_repr(self): + child_lines = [] + for k, p in self._buffers.items(): + size_str = "x".join(str(size) for size in p.size()) + device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})" + parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]" + child_lines.append(" (" + k + "): " + parastr) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("BufferDict should not be called.") diff --git a/peft/tuners/vera/config.py b/peft/tuners/vera/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1601c8c0e66715ef0404cc4ac8e338197a1e10c2 --- /dev/null +++ b/peft/tuners/vera/config.py @@ -0,0 +1,157 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class VeraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`VeraModel`]. + + Paper: https://arxiv.org/abs/2310.11454. + + Args: + r (`int`, *optional*, defaults to `256`): + VeRA parameter dimension ("rank"). Choose higher values than LoRA ranks here, since VeRA uses far fewer + parameters than LoRA (see Table 1). + target_modules (`Union[List[str], str]`): + The names of the modules to apply Vera to. Only linear layers are supported. + projection_prng_key (`int`): + Vera PRNG init key. Used for initialising vera_A and vera_B for new models or when loading a checkpoint + that did not include these projections. Defaults to `0`. + save_projection (`bool`): + Whether to save the vera_A / vera_B projections in the state dict alongside per layer lambda_b / lambda_d + weights. This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on + all system configurations. Defaults to `True`. + vera_dropout (`float`): + The dropout probability for Vera layers. + d_initial (`float`, *optional*, defaults to `0.1`): + Initial init value for `vera_lambda_d` vector used when initializing the VeRA parameters. 
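`BufferDict` is the buffer analogue of `nn.ParameterDict`; with `persistent=False` the stored tensors stay out of the state dict, which is what lets VeRA drop its shared projections from checkpoints when `save_projection=False`. A small sketch with an illustrative module name:

```python
import torch
import torch.nn as nn

class Probe(nn.Module):
    def __init__(self):
        super().__init__()
        # Construct empty and fill afterwards, mirroring how VeraModel uses BufferDict below.
        self.shared = BufferDict({}, persistent=False)
        self.shared["proj"] = torch.randn(4, 16)

probe = Probe()
print(list(probe.shared.keys()))             # ['proj']
print("shared.proj" in probe.state_dict())   # False: non-persistent buffers are skipped
```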
Small values + (<=0.1) are recommended (see Table 6c in the paper). + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type for Vera. Can be 'none', 'all' or 'vera_only'. If 'all' or 'vera_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + modules_to_save (`List[str]`): + List of modules apart from Vera layers to be set as trainable and saved in the final checkpoint. + init_weights (`bool`): + Whether to initialize the weights of the Vera layers with their default initialization. Don't change this + setting, except if you know exactly what you're doing. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the Vera transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the Vera + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + r: int = field(default=256, metadata={"help": "Vera attention dimension"}) + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with Vera." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " + "Only linear layers are supported." + ) + }, + ) + projection_prng_key: int = field( + default=0, + metadata={ + "help": ( + "Vera PRNG init key. Used for initialising vera_A and vera_B for new models or when loading a " + "checkpoint that did not include these projections." + ) + }, + ) + save_projection: bool = field( + default=True, + metadata={ + "help": ( + "Whether to save the vera_A / vera_B projections in the state dict alongside per layer lambda_b / " + "lambda_d weights. This will increase the size of the checkpoint, but guarantee that we can reload " + "the checkpoint on all system configurations." + ) + }, + ) + vera_dropout: float = field(default=0.0, metadata={"help": "Vera dropout"}) + d_initial: float = field(default=0.1, metadata={"help": "Initial init value for d vector."}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: str = field(default="none", metadata={"help": "Bias type for Vera. Can be 'none', 'all' or 'vera_only'"}) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": ( + "List of modules apart from Vera layers to be set as trainable and saved in the final checkpoint. For" + " example, in Sequence Classification or Token Classification tasks, the final layer" + " `classifier/score` are randomly initialized and as such need to be trainable and saved." + ) + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the Vera layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." 
+ ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": ( + "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers" + " indexes that are specified inside this list. If a single integer is passed, PEFT will transform only" + " the layer at this index." + ) + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer" + " pattern is not in the common layers pattern." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.VERA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + + if not self.save_projection: + warnings.warn( + "Specified to not save vera_A and vera_B within the state dictionary, instead they will be restored " + "using the PRNG key store in `config.projection_prng_key`. Consider setting `config.save_projection` " + "to `True` to guarantee restoring the checkpoint correctly on all system configurations." + ) diff --git a/peft/tuners/vera/layer.py b/peft/tuners/vera/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..934e2fdb5e49b8a01c59629e1683acfc884c045d --- /dev/null +++ b/peft/tuners/vera/layer.py @@ -0,0 +1,267 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .buffer_dict import BufferDict + + +class VeraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ("vera_lambda_b", "vera_lambda_d") + other_param_names = ("vera_A", "vera_B") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.vera_dropout = nn.ModuleDict({}) + + # For storing vector scale + self.vera_lambda_b = nn.ParameterDict({}) + self.vera_lambda_d = nn.ParameterDict({}) + + # Stores a reference to the vera_A/B BufferDict. 
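A typical configuration built from the fields above; the checkpoint and module names are only an example. `r` is deliberately larger than common LoRA ranks because only the two scaling vectors per layer are trained.

```python
from transformers import AutoModelForCausalLM
from peft import VeraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = VeraConfig(
    r=256,
    target_modules=["q_proj", "v_proj"],   # all targeted layers must share one weight shape
    vera_dropout=0.05,
    d_initial=0.1,
    save_projection=True,                  # False shrinks the checkpoint; vera_A/vera_B are rebuilt from the PRNG key
)
vera_model = get_peft_model(base, config)
vera_model.print_trainable_parameters()
```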
+ # Set to `None` otherwise to avoid computation with random weights + self.vera_A: Optional[BufferDict] = None + self.vera_B: Optional[BufferDict] = None + + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + @property + def merged(self) -> bool: + return bool(self.merged_adapters) + + def update_layer( + self, + adapter_name, + vera_A: BufferDict, + vera_B: BufferDict, + r, + vera_dropout, + init_weights, + d_initial: float = 0.1, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + if vera_dropout > 0.0: + vera_dropout_layer = nn.Dropout(p=vera_dropout) + else: + vera_dropout_layer = nn.Identity() + + self.vera_dropout.update(nn.ModuleDict({adapter_name: vera_dropout_layer})) + # Actual trainable parameters + self.vera_lambda_b[adapter_name] = nn.Parameter(torch.ones(self.out_features), requires_grad=True) + self.vera_lambda_d[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True) + + # non trainable references to vera_A/B buffers + self.vera_A = vera_A + self.vera_B = vera_B + if adapter_name not in vera_A: + # This means that this is not the first VeRA adapter. We have to add an entry in the dict for this adapter. + if len(self.vera_A) < 1: + raise ValueError( + "The `vera_A` and `vera_B` buffers are empty. This should not happen. Please report this issue." + ) + # we can take any of the existing adapter's parameters, as they should all be identical + vera_A_param = list(self.vera_A.values())[0] + vera_B_param = list(self.vera_B.values())[0] + self.vera_A[adapter_name] = vera_A_param + self.vera_B[adapter_name] = vera_B_param + + if init_weights: + self.reset_vera_parameters(adapter_name, d_initial=d_initial) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + + self.set_adapter(self.active_adapters) + + def reset_vera_parameters(self, adapter_name, d_initial: float = 0.1): + if adapter_name in self.vera_lambda_d.keys(): + with torch.no_grad(): + nn.init.zeros_(self.vera_lambda_d[adapter_name]).fill_(d_initial) + nn.init.zeros_(self.vera_lambda_b[adapter_name]) + + +class Linear(nn.Linear, VeraLayer): + # Vera implemented in a dense layer + def __init__( + self, + base_layer, + vera_A: BufferDict, + vera_B: BufferDict, + adapter_name: str, + r: int = 0, + vera_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_weights: bool = True, + d_initial: float = 0.1, + **kwargs, + ) -> None: + # this gets the init from nn.Linear's super perspective, i.e. 
nn.Module.__init__, which should always be called + super(nn.Linear, self).__init__() + VeraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer(adapter_name, vera_A, vera_B, r, vera_dropout, init_weights, d_initial=d_initial) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.vera_lambda_d.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.vera_lambda_d.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + vera_A = self.vera_A[adapter] + vera_B = self.vera_B[adapter] + + device = vera_B.device + dtype = vera_B.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. + cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + lambda_d = self.vera_lambda_d[adapter] + lambda_b = self.vera_lambda_b[adapter] + + if cast_to_fp32: + vera_A = vera_A.float() + vera_B = vera_B.float() + lambda_d = lambda_d.float() + lambda_b = lambda_b.float() + + lambda_b = lambda_b.unsqueeze(-1) + lambda_d = lambda_d.unsqueeze(-1) + output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + # TODO: why? 
+ self.vera_lambda_d[adapter].data = lambda_d.to(dtype) + self.vera_lambda_b[adapter].data = lambda_b.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.vera_lambda_d.keys(): + continue + + lambda_d = self.vera_lambda_d[active_adapter] + lambda_b = self.vera_lambda_b[active_adapter] + + vera_A = self.vera_A[active_adapter] + vera_B = self.vera_B[active_adapter] + + dropout = self.vera_dropout[active_adapter] + x = x.to(lambda_d.dtype) + result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B) + + result = result.to(previous_dtype) + return result diff --git a/peft/tuners/vera/model.py b/peft/tuners/vera/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2ecd1c9ab8d65242fa4fc0216518adf96ab5469c --- /dev/null +++ b/peft/tuners/vera/model.py @@ -0,0 +1,474 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import warnings +from dataclasses import asdict +from enum import Enum +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn.init import _calculate_correct_fan +from tqdm import tqdm +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from ..tuners_utils import _maybe_include_all_linear_layers +from .buffer_dict import BufferDict +from .config import VeraConfig +from .layer import Linear, VeraLayer + + +def _kaiming_init( + tensor_or_shape: Union[torch.Tensor, tuple[int, ...]], + generator: torch.Generator, +) -> torch.Tensor: + """ + Kaiming Uniform Initialisation adapted to accept a `torch.Generator` object for PRNG. + + Args: + tensor_or_shape (`Union[torch.Tensor, tuple[int, ...]]`): + Tensor to initialise, or shape of new tensor to create and then initialise. + generator: (`torch.Generator`): + Generator object that manages the state of the PRNG algorithm in use. + + Returns: + `torch.Tensor`: The initialised tensor. + """ + if isinstance(tensor_or_shape, tuple): + tensor = torch.empty(tensor_or_shape) + else: + tensor = tensor_or_shape + fan = _calculate_correct_fan(tensor, "fan_in") + gain = math.sqrt(2) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std + + with torch.no_grad(): + return tensor.uniform_(-bound, bound, generator=generator) + + +class VeraModel(BaseTuner): + """ + Creates Vector-based Random Matrix Adaptation (Vera) model from a pretrained transformers model. 
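The merge path and the forward path compute the same low-rank update, delta_W = diag(lambda_b) @ vera_B @ diag(lambda_d) @ vera_A, with only the two lambda vectors trained; and because `reset_vera_parameters` zero-initialises `vera_lambda_b`, the adapter contributes nothing until training starts. A standalone numeric check of that equivalence (shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

r, d_in, d_out = 4, 16, 8
vera_A, vera_B = torch.randn(r, d_in), torch.randn(d_out, r)
lambda_d, lambda_b = torch.full((r,), 0.1), torch.randn(d_out)
x = torch.randn(2, d_in)

# Merged form: explicit delta weight, as in get_delta_weight with fan_in_fan_out=False.
delta_w = (lambda_b.unsqueeze(-1) * vera_B) @ (lambda_d.unsqueeze(-1) * vera_A)
merged = x @ delta_w.T

# Forward form: two thin matmuls plus elementwise scaling, as in Linear.forward.
on_the_fly = lambda_b * F.linear(lambda_d * F.linear(x, vera_A), vera_B)
torch.testing.assert_close(merged, on_the_fly)

# With lambda_b at its zero init, the update vanishes and the layer behaves like the base layer.
assert torch.count_nonzero((torch.zeros(d_out).unsqueeze(-1) * vera_B) @ (lambda_d.unsqueeze(-1) * vera_A)) == 0
```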
+ + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`VeraConfig`]): The configuration of the Vera model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The Vera model. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import VeraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = VeraConfig(r=128) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`VeraConfig`]): The configuration of the Vera model. + """ + + prefix: str = "vera_lambda" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + def _find_first_dim(self, config) -> tuple[int, int]: + """ + Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension. + + This will be used for determining the size of the shared vera_A and vera_B matrices. + + This will throw an error if there are multiple layers of the same type with different shapes. + """ + model_config = getattr(self.model, "config", {"model_type": "custom"}) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + + peft_config = self._prepare_adapter_config(config, model_config) + peft_config = _maybe_include_all_linear_layers(peft_config, self.model) + + first_shape = None + for key, module in self.model.named_modules(): + if not self._check_target_module_exists(peft_config, key): + continue + + if isinstance(module, (nn.Linear, Conv1D)): + module_shape = tuple(module.weight.shape) + if isinstance(module, Conv1D): + module_shape = module_shape[::-1] + else: + continue + + if first_shape is None: + first_shape = module_shape + continue + + if module_shape != first_shape: + raise ValueError( + "Multiple target layers with different dimensions were specified. VeRA only supports a " + f"single dimension size. Expected shape {first_shape}, got {module_shape}." + ) + + if first_shape is None: + msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`." + raise ValueError(msg) + + return first_shape + + def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None: + first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config) + + # use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them. + self.vera_A = BufferDict({}, persistent=config.save_projection) + self.vera_B = BufferDict({}, persistent=config.save_projection) + + # deterministic init of vera_A and vera_B if we know the key + generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator) + vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator) + self.vera_A[adapter_name] = vera_A + self.vera_B[adapter_name] = vera_B + + def _pre_injection_hook(self, model: nn.Module, config: VeraConfig, adapter_name: str) -> None: + self._init_vera_A_vera_B(config, adapter_name) + + def _check_new_adapter_config(self, config: VeraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. 
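Because the shared projections come from a seeded CPU generator, they can always be regenerated from `projection_prng_key` alone, which is what makes `save_projection=False` workable on a fixed software stack. A quick determinism check using the `_kaiming_init` helper defined above:

```python
import torch

key = 0  # plays the role of config.projection_prng_key
a1 = _kaiming_init((256, 768), generator=torch.Generator(device="cpu").manual_seed(key))
a2 = _kaiming_init((256, 768), generator=torch.Generator(device="cpu").manual_seed(key))
assert torch.equal(a1, a2)   # identical vera_A for the same key and shape
```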
+ + """ + # the below todo is copied from LoRA + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + for existing_config in self.peft_config.values(): + if existing_config is config: + # skip the current config + continue + + if existing_config.projection_prng_key != config.projection_prng_key: + raise ValueError( + f"Vera PRNG initialisation key must be the same for all adapters. Got {config.projection_prng_key=} but " + f"previous config had {existing_config.projection_prng_key}." + ) + + save_project_unique_values = sorted({config.save_projection for config in self.peft_config.values()}) + if len(save_project_unique_values) > 1: + raise ValueError( + "VeRA projection weights must be saved for all adapters or none, but got multiple different values: " + f"{save_project_unique_values}" + ) + + @staticmethod + def _check_target_module_exists(vera_config, key): + return check_target_module_exists(vera_config, key) + + def _create_and_replace( + self, + vera_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + r = vera_config.r + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": r, + "vera_dropout": vera_config.vera_dropout, + "fan_in_fan_out": vera_config.fan_in_fan_out, + "init_weights": vera_config.init_weights, + } + kwargs["bias"] = bias + # TODO: add quantization support + + if isinstance(target, Linear): + target.update_layer( + adapter_name, + self.vera_A, + self.vera_B, + r, + vera_config.vera_dropout, + vera_config.init_weights, + d_initial=vera_config.d_initial, + ) + else: + new_module = self._create_new_module(vera_config, self.vera_A, self.vera_B, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if "vera_" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + 
p.requires_grad = True + elif bias == "vera_only": + for m in model.modules(): + if isinstance(m, VeraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwargs): + bias = kwargs.pop("bias", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + kwargs["is_target_conv_1d_layer"] = True + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`." + ) + new_module = Linear( + target, + vera_A, + vera_B, + adapter_name, + bias=bias, + d_initial=vera_config.d_initial, + **kwargs, + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, VeraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + # we cannot use self.prefix as we want to include non-trainable vera parameters + key_list = [key for key, _ in self.model.named_modules() if "vera" not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str): + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + # we cannot use self.prefix as we want to include non-trainable vera parameters + key_list = [key for key, _ in self.model.named_modules() if "vera" not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, VeraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapter[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ): + r""" + This method merges the Vera layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self): + """ + Gets back the base model by removing all the Vera modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/peft/utils/__init__.py b/peft/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3c56cc8e1e9595553b38d1d43927da8bfe11e4 --- /dev/null +++ b/peft/utils/__init__.py @@ -0,0 +1,53 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType +from .loftq_utils import replace_lora_weights_loftq +from .peft_types import PeftType, TaskType +from .other import ( + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + CONFIG_NAME, + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + INCLUDE_LINEAR_LAYERS_SHORTHAND, + _set_trainable, + bloom_model_postprocess_past_key_value, + prepare_model_for_kbit_training, + shift_tokens_right, + transpose, + _get_batch_size, + _get_submodules, + _set_adapter, + _freeze_adapter, + ModulesToSaveWrapper, + _prepare_prompt_learning_config, + _is_valid_match, + infer_device, + get_auto_gptq_quant_linear, + get_quantization_config, + id_tensor_storage, + cast_mixed_precision_params, +) +from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights diff --git a/peft/utils/__pycache__/__init__.cpython-310.pyc b/peft/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6be355c8b5f487c5b51e1b52907443c9d8ed3f Binary files /dev/null and b/peft/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/constants.cpython-310.pyc b/peft/utils/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c88b9505331e5fdc88bd0f94f3301c5ce07e4d8 Binary files /dev/null and b/peft/utils/__pycache__/constants.cpython-310.pyc 
differ diff --git a/peft/utils/__pycache__/integrations.cpython-310.pyc b/peft/utils/__pycache__/integrations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22f34689aa7dc1dc47b17de541b76a2c08c42cc3 Binary files /dev/null and b/peft/utils/__pycache__/integrations.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/loftq_utils.cpython-310.pyc b/peft/utils/__pycache__/loftq_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40add20f2681b0668f5bbcc8415e8d2046f8820a Binary files /dev/null and b/peft/utils/__pycache__/loftq_utils.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/merge_utils.cpython-310.pyc b/peft/utils/__pycache__/merge_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8be9fd55f3de55880a37922719d27d8f50fe02f Binary files /dev/null and b/peft/utils/__pycache__/merge_utils.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/other.cpython-310.pyc b/peft/utils/__pycache__/other.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7665a0977730ebff9edcb1d3bf5ed2f82aaf243c Binary files /dev/null and b/peft/utils/__pycache__/other.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/peft_types.cpython-310.pyc b/peft/utils/__pycache__/peft_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466ea91521aa922489526a91729886436cee5818 Binary files /dev/null and b/peft/utils/__pycache__/peft_types.cpython-310.pyc differ diff --git a/peft/utils/__pycache__/save_and_load.cpython-310.pyc b/peft/utils/__pycache__/save_and_load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7baa51dc5f10e4629c2bce8ecbd49eaf7b6b140 Binary files /dev/null and b/peft/utils/__pycache__/save_and_load.cpython-310.pyc differ diff --git a/peft/utils/constants.py b/peft/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..98df496275606dc33417c342ca39124c2131be37 --- /dev/null +++ b/peft/utils/constants.py @@ -0,0 +1,215 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
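+
+# This module collects the per-model-type default target-module mappings used by the
+# tuners (prefix-tuning post-processing, LN tuning, LoRA, IA3, AdaLoRA, VeRA) together
+# with common file-name constants such as WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME and
+# CONFIG_NAME.
+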
+import torch + + +# needed for prefix-tuning of bloom model +def bloom_model_postprocess_past_key_value(past_key_values): + past_key_values = torch.cat(past_key_values) + total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape + keys = past_key_values[: total_layers // 2] + keys = keys.transpose(2, 3).reshape( + total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens + ) + values = past_key_values[total_layers // 2 :] + values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) + + return tuple(zip(keys, values)) + + +# needed for prefix-tuning of StarCoder models +def starcoder_model_postprocess_past_key_value(past_key_values): + result = [] + for k in past_key_values: + k = k[:, :, 0] + k = k.permute([1, 2, 0, 3]) + k = k.reshape(*k.shape[:-2], -1) + result.append(k) + return tuple(result) + + +TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { + "bloom": bloom_model_postprocess_past_key_value, + "gpt_bigcode": starcoder_model_postprocess_past_key_value, +} + +TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING = { + "llama": ["input_layernorm", "post_attention_layernorm", "norm"], + "bloom": ["input_layernorm", "post_attention_layernorm", "ln_f"], + "llava": [ + "multi_modal_projector", + "input_layernorm", + "post_attention_layernorm", + "norm", + "embed_tokens", + "lm_head", + ], + "t5": ["layer_norm", "final_layer_norm"], + "mt5": ["layer_norm", "final_layer_norm"], + "bart": ["self_attn_layer_norm", "encoder_attn_layer_norm", "final_layer_norm"], + "gpt2": ["ln_1", "ln_2", "ln_f"], + "blip-2": ["layernorm", "LayerNorm", "final_layer_norm", "self_attn_layer_norm"], + "gptj": ["ln_1", "ln_f"], + "falcon": ["input_layernorm", "post_attention_layernorm", "ln_f"], + "mistral": ["input_layernorm", "post_attention_layernorm", "norm"], + "phi": ["input_layernorm", "final_layernorm"], + "gemma": ["input_layernorm", "post_attention_layernorm", "norm"], +} + +TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], + "gpt_bigcode": ["c_attn"], + "mpt": ["Wqkv"], + "RefinedWebModel": ["query_key_value"], + "RefinedWeb": ["query_key_value"], + "falcon": ["query_key_value"], + "btlm": ["c_proj", "c_attn"], + "codegen": ["qkv_proj"], + "mistral": ["q_proj", "v_proj"], + "mixtral": ["q_proj", "v_proj"], + "stablelm": ["q_proj", "v_proj"], + "phi": ["q_proj", "v_proj", "fc1", "fc2"], + "gemma": ["q_proj", "v_proj"], +} + +TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = { + "t5": ["k", "v", "wo"], + "mt5": ["k", "v", "wi_1"], + "gpt2": ["c_attn", "mlp.c_proj"], + "bloom": ["query_key_value", "mlp.dense_4h_to_h"], + "roberta": ["key", "value", "output.dense"], + "opt": ["q_proj", "k_proj", "fc2"], + "gptj": ["q_proj", "v_proj", "fc_out"], + "gpt_neox": ["query_key_value", "dense_4h_to_h"], + "gpt_neo": ["q_proj", "v_proj", "c_proj"], + "bart": ["q_proj", "v_proj", "fc2"], 
+ "gpt_bigcode": ["c_attn", "mlp.c_proj"], + "llama": ["k_proj", "v_proj", "down_proj"], + "mistral": ["k_proj", "v_proj", "down_proj"], + "mixtral": ["k_proj", "v_proj", "w2"], + "bert": ["key", "value", "output.dense"], + "deberta-v2": ["key_proj", "value_proj", "output.dense"], + "deberta": ["in_proj", "output.dense"], + "RefinedWebModel": ["query_key_value", "dense_4h_to_h"], + "RefinedWeb": ["query_key_value", "dense_4h_to_h"], + "falcon": ["query_key_value", "dense_4h_to_h"], + "phi": ["q_proj", "v_proj", "fc2"], + "gemma": ["q_proj", "v_proj", "down_proj"], +} + +TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = { + "t5": ["wo"], + "mt5": [], + "gpt2": ["mlp.c_proj"], + "bloom": ["mlp.dense_4h_to_h"], + "roberta": ["output.dense"], + "opt": ["fc2"], + "gptj": ["fc_out"], + "gpt_neox": ["dense_4h_to_h"], + "gpt_neo": ["c_proj"], + "bart": ["fc2"], + "gpt_bigcode": ["mlp.c_proj"], + "llama": ["down_proj"], + "mistral": ["down_proj"], + "mixtral": ["w2"], + "bert": ["output.dense"], + "deberta-v2": ["output.dense"], + "deberta": ["output.dense"], + "RefinedWeb": ["dense_4h_to_h"], + "RefinedWebModel": ["dense_4h_to_h"], + "falcon": ["dense_4h_to_h"], + "phi": ["fc2"], + "gemma": ["down_proj"], +} + +TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "k", "v", "o", "wi", "wo"], + "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], + "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "llama": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "key", "value", "dense"], + # "xlm-roberta": ["query", "value"], + # "electra": ["query", "value"], + "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], + "gpt_bigcode": ["c_attn"], + "deberta": ["in_proj"], + # "layoutlm": ["query", "value"], +} + +TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], + "gpt_bigcode": ["c_attn"], + "mpt": ["Wqkv"], + "RefinedWebModel": ["query_key_value"], + "RefinedWeb": ["query_key_value"], + "falcon": ["query_key_value"], + # "btlm": ["c_proj", "c_attn"], # tested, does not work because of different shapes + "codegen": ["qkv_proj"], + # "mistral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + # "mixtral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + "stablelm": ["q_proj", "v_proj"], + # "phi": ["q_proj", "v_proj", "fc1", "fc2"], # tested, does not work because of different shapes + "phi": ["q_proj", "v_proj"], + # "gemma": ["q_proj", "v_proj"], # tested, does not work because of different shapes +} + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" +CONFIG_NAME = "adapter_config.json" +EMBEDDING_LAYER_NAMES = 
["embed_tokens", "lm_head"] +INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear" +TOKENIZER_CONFIG_NAME = "tokenizer_config.json" diff --git a/peft/utils/integrations.py b/peft/utils/integrations.py new file mode 100644 index 0000000000000000000000000000000000000000..14b37e32e8863ff720427861bb2e434bb9cca22e --- /dev/null +++ b/peft/utils/integrations.py @@ -0,0 +1,103 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +import packaging.version +import torch +import transformers + + +@contextmanager +def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None): + """Call DeepSpeed GatheredParameters context manager if DeepSpeed is enabled, otherwise do nothing.""" + if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"): + from transformers.integrations import is_deepspeed_zero3_enabled + else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + + if not is_deepspeed_zero3_enabled(): + yield + return + + import deepspeed + + with deepspeed.zero.GatheredParameters(param, modifier_rank=modifier_rank, fwd_module=fwd_module): + yield + return + + +def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter: + """ + Helper function to dequantize a quantized weight. + + This function should be extended if more quantization schemes are added to the library. + + If the weight is not quantized, it will be returned as is. + """ + if hasattr(module, "W_q"): # For handling HQQ quantized weight + weight = module.dequantize() + return weight + + weight = module.weight + if not isinstance(weight, torch.nn.Parameter): + raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") + + cls_name = weight.__class__.__name__ + if cls_name not in ("Params4bit", "Int8Params"): + return weight + + quant_state = getattr(module, "state", None) + device = weight.device + is_cpu = device.type == torch.device("cpu").type + weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb + if is_cpu: + # dequantize_bnb_weight for 8bit moves the device in-place, thus we need to move it back to CPU if necessary + module.weight = module.weight.to(device) + return weight + + +def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None): + """Helper function to dequantize 4bit or 8bit bnb weights. + + Since dequantization is not supported on CPU, the weight will be temporarily moved to CUDA if necessary. 
+ """ + import bitsandbytes as bnb + + # BNB requires CUDA weights + device = weight.device + is_cpu = device.type == torch.device("cpu").type + if is_cpu: + weight = weight.to(torch.device("cuda")) + + cls_name = weight.__class__.__name__ + if cls_name == "Params4bit": + dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + if is_cpu: + dequantized = dequantized.to(device) + return dequantized + + if state.SCB is None: + state.SCB = weight.SCB + + im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) + im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im) + im, Sim = bnb.functional.transform(im, "col32") + if state.CxB is None: + state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) + out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) + dequantized = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + if is_cpu: + dequantized = dequantized.to(device) + return dequantized diff --git a/peft/utils/loftq_utils.py b/peft/utils/loftq_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8323485a70ecd016bfcfcbdff378a7749f8569e --- /dev/null +++ b/peft/utils/loftq_utils.py @@ -0,0 +1,410 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py +# Reference paper: https://arxiv.org/abs/2310.08659 + +from __future__ import annotations + +import logging +import os +from typing import Callable, Optional, Union + +import torch +from huggingface_hub import snapshot_download +from huggingface_hub.utils import LocalEntryNotFoundError +from safetensors import SafetensorError, safe_open +from transformers.utils import cached_file +from transformers.utils.hub import get_checkpoint_shard_files + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + + +class NFQuantizer: + def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_bits = num_bits + self.device = device + self.method = method + self.block_size = block_size + if self.method == "normal": + self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + elif self.method == "uniform": + self.norm_lookup_table = self.create_uniform_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + else: + raise NotImplementedError("Other quantization methods not supported yet.") + + @staticmethod + def create_uniform_map(symmetric=False, num_bits=4): + if symmetric: + # print("symmetric uniform quantization") + negative = torch.linspace(-1, 0, 2 ** (num_bits - 1)) + positive = torch.linspace(0, 1, 2 ** (num_bits - 1)) + table = torch.cat([negative, positive[1:]]) + else: + # print("asymmetric uniform quantization") + table = torch.linspace(-1, 1, 2**num_bits) + return table + + @staticmethod + def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2): + try: + from scipy.stats import norm + except ImportError: + raise ImportError("The required package 'scipy' is not installed. 
Please install it to continue.") + + variations = 2**num_bits + if symmetric: + v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist() + values = [] + for index in range(len(v) - 1): + values.append(0.5 * v[index] + 0.5 * v[index + 1]) + v = values + else: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist() + v2 = [0] + v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + return values + + def quantize_tensor(self, weight): + max_abs = torch.abs(weight).max() + weight_normed = weight / max_abs + + weight_normed_expanded = weight_normed.unsqueeze(-1) + + # Reshape L to have the same number of dimensions as X_expanded + L_reshaped = torch.tensor(self.norm_lookup_table).reshape(1, -1) + + # Calculate the absolute difference between X_expanded and L_reshaped + abs_diff = torch.abs(weight_normed_expanded - L_reshaped) + + # Find the index of the minimum absolute difference for each element + qweight = torch.argmin(abs_diff, dim=-1) + return qweight, max_abs + + def dequantize_tensor(self, qweight, max_abs): + qweight_flatten = qweight.flatten() + + weight_normed = self.norm_lookup_table[qweight_flatten] + weight = weight_normed * max_abs + + weight = weight.reshape(qweight.shape) + + return weight + + def quantize_block(self, weight): + if len(weight.shape) != 2: + raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.") + if weight.shape[0] * weight.shape[1] % self.block_size != 0: + raise ValueError( + f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) " + f"is not dividable by block size {self.block_size}." 
+ ) + + M, N = weight.shape + device = weight.device + + # Quantization + weight_flatten = weight.flatten() # (M*N, ) + weight_block = weight_flatten.reshape(-1, self.block_size) # (L, B), L = M * N / B + if self.method == "normal": + weight_max = weight_block.abs().max(dim=-1)[0] # (L, 1) + elif self.method == "uniform": + weight_max = weight_block.mean(dim=-1) + 2.5 * weight_block.std(dim=-1) + else: + raise NotImplementedError("Method not supported yet.") + weight_max = weight_max.unsqueeze(-1) + weight_divabs = weight_block / weight_max # (L, B) + weight_divabs = weight_divabs.unsqueeze(-1) # (L, B, 1) + L_reshaped = self.norm_lookup_table.reshape(1, -1) # (1, 2**K) + + abs_diff = torch.abs(weight_divabs - L_reshaped) # (L, B, 2**K) + qweight = torch.argmin(abs_diff, dim=-1) # (L, B) + + # Pack multiple k-bit into uint8 + qweight = qweight.reshape(-1, 8 // self.num_bits) + qweight_pack = torch.zeros((M * N // 8 * self.num_bits, 1), dtype=torch.uint8, device=device) + + # data format example: + # [1, 0, 3, 2] or [01, 00, 11, 10] -> [10110001], LIFO + for i in range(8 // self.num_bits): + qweight[:, i] = qweight[:, i] << i * self.num_bits + qweight_pack[:, 0] |= qweight[:, i] + + return qweight_pack, weight_max, weight.shape + + def dequantize_block(self, qweight, weight_max, weight_shape): + # unpack weight + device = qweight.device + weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device) + for i in range(8 // self.num_bits): + lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits # get the most right 2 bits + lookup_table_idx = lookup_table_idx.to(torch.long) + weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze() + qweight = qweight >> self.num_bits # right shift 2 bits of the original data + + weight_block = weight.reshape(-1, self.block_size) + weight = weight_block * weight_max + weight = weight.reshape(weight_shape) + + return weight + + +def _low_rank_decomposition(weight, reduced_rank=32): + """ + :param weight: The matrix to decompose, of shape (H, W) :param reduced_rank: the final rank :return: + """ + matrix_dimension = len(weight.size()) + if matrix_dimension != 2: + raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.") + + # Use SVD to decompose a matrix, default full_matrices is False to save parameters + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) + + L = U @ (torch.sqrt(torch.diag(S)[:, 0:reduced_rank])) + R = torch.sqrt(torch.diag(S)[0:reduced_rank, :]) @ Vh + + return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank} + + +@torch.no_grad() +def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1): + if is_bnb_available(): + import bitsandbytes as bnb + else: + raise ValueError("bitsandbytes is not available, please install it to use LoftQ.") + + if num_bits not in [2, 4, 8]: + raise ValueError("Only support 2, 4, 8 bits quantization") + if num_iter <= 0: + raise ValueError("Number of iterations must be greater than 0") + + out_feature, in_feature = weight.size() + device = weight.device + dtype = weight.dtype + + logging.info( + f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} " + f"| Num Iter: {num_iter} | Num Bits: {num_bits}" + ) + if not is_bnb_4bit_available() or num_bits in [2, 8]: + quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + compute_device = device + else: + compute_device = "cuda" + + weight = 
weight.to(device=compute_device, dtype=torch.float32) + res = weight.clone() + for i in range(num_iter): + torch.cuda.empty_cache() + # Quantization + if num_bits == 4 and is_bnb_4bit_available(): + qweight = bnb.nn.Params4bit( + res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4" + ).to(compute_device) + dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state) + else: + quantized_weight, max_abs, shape = quantizer.quantize_block(res) + dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape) + + res = weight - dequantized_weight + + # Decompose the residual by SVD + output = _low_rank_decomposition(res, reduced_rank=reduced_rank) + L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"] + res = weight - torch.mm(L, R) + + lora_A, lora_B = R, L + + return dequantized_weight.to(device=device, dtype=dtype), lora_A, lora_B + + +@torch.no_grad() +def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int): + import bitsandbytes as bnb + + if num_bits != 4: + raise ValueError("Only 4 bit quantization supported at the moment.") + if not is_bnb_4bit_available(): + raise ValueError("bitsandbytes 4bit quantization is not available.") + + compute_device = "cuda" + dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state) + + weight = weight.to(device=compute_device, dtype=torch.float32) + residual = weight - dequantized_weight + torch.cuda.empty_cache() + # Decompose the residualidual by SVD + output = _low_rank_decomposition(residual, reduced_rank=reduced_rank) + L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"] + return R, L + + +class _SafetensorLoader: + """ + Simple utility class that loads tensors with safetensors from a single file or sharded files. + + Takes care of file name normalization etc. + + """ + + def __init__(self, peft_model, model_path): + if model_path is None: + try: + model_path = snapshot_download(peft_model.base_model.config._name_or_path, local_files_only=True) + except AttributeError as exc: + raise ValueError( + "The provided model does not appear to be a transformers model. In this case, you must pass the " + "model_path to the safetensors file." + ) from exc + except LocalEntryNotFoundError as exc: + raise ValueError( + "The model.safetensors file must be present on disk, but it could not be found." + ) from exc + + suffix = "model.safetensors" + if not model_path.endswith(suffix): + model_path = os.path.join(model_path, suffix) + + self.model_path = model_path + self.base_model_prefix = getattr(peft_model.get_base_model(), "base_model_prefix", None) + self.prefix = "base_model.model." + self.is_sharded = False + self.weight_map = None + + if not os.path.exists(model_path): + # check if the file is sharded + par_dir = model_path.rpartition(os.path.sep)[0] + try: + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + par_dir, cached_file(par_dir, "model.safetensors.index.json") + ) + except OSError as exc: + raise FileNotFoundError( + f"Could not find file for {model_path}, ensure that there is a (sharded) safetensors file of the model." 
+ ) from exc + + self.is_sharded = True + # maps from 'model-X-of-Y.safetensors' to full file path + file_map = {k.rpartition(os.path.sep)[-1]: k for k in resolved_archive_file} + self.weight_map = {k: file_map[v] for k, v in sharded_metadata["weight_map"].items()} + + def get_tensor(self, name): + if not self.is_sharded: + file_path = self.model_path + else: + file_path = self.weight_map[name] + + with safe_open(file_path, framework="pt", device="cpu") as f: + try: + tensor = f.get_tensor(name) + except SafetensorError as exc: + # no matching key found, we probably need to remove the base model prefix + if self.base_model_prefix: + # remove 1 extra character for "." + name = name[len(self.base_model_prefix) + 1 :] + tensor = f.get_tensor(name) + else: + raise exc + return tensor + + +@torch.no_grad() +def replace_lora_weights_loftq( + peft_model, + model_path: Optional[str] = None, + adapter_name: str = "default", + callback: Optional[Callable[[torch.nn.Module, str], bool]] = None, +): + """ + Replace the LoRA weights of a model quantized with bitsandbytes, using the LoftQ technique. + + The replacement is done on the fly by loading in the non-quantized weights from a locally stored safetensors model + file and initializing the LoRA weights such that the quantization error between the original and quantized weights + is minimized. + + As lazy loading is not possible with pickle, normal PyTorch checkpoint files cannot be supported. + + Depending on the model size, calling this function may take some time to finish. + + Args: + peft_model (`PeftModel`): + The model to replace the weights of. Must be a quantized PEFT model with LoRA layers. + model_path (`Optional[str]`): + The path to the model safetensors file. If the model is a Hugging Face model, this will be inferred from + the model's config. Otherwise, it must be provided. + adapter_name (`str`): + The name of the adapter to replace the weights of. The default adapter name is "default". + callback (`Optional[Callable[[PeftModel, str], bool]]`): + A callback function that will be called after each module is replaced. The callback function should take + the model and the name of the current module as input and return a boolean indicating whether the + replacement should be kept. If the callback returns False, the replacement will be rolled back. This can be + very useful to confirm that the LoftQ initialization actually decreases the quantization error of the + model. As an example, this callback could generate logits for given input and compare it with the logits + from the original, non-quanitzed model with the same input, and only return `True` if there is an + improvement. As this is a greedy optimization, it's possible that calling this function multiple times + yields incremental improvements. + """ + if not is_bnb_4bit_available(): + raise ValueError("bitsandbytes must be installed and the model must be quantized in 4bits.") + + from peft.tuners.lora import Linear4bit + + # model_path = _check_model_path_loftq(model_path, peft_model) + prefix = "base_model.model." 
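+    # Illustrative sketch (not part of the original code): one way a caller can use the
+    # `callback` argument described above is to keep a replacement only when it lowers the
+    # error against reference logits. The names `original_model`, `inputs`, `reference_logits`
+    # and `state` are hypothetical and must be defined by the caller, e.g.:
+    #
+    #     reference_logits = original_model(**inputs).logits  # non-quantized reference
+    #     state = {"best_error": float("inf")}
+    #
+    #     def keep_if_better(model, module_name):
+    #         error = (model(**inputs).logits - reference_logits).abs().mean().item()
+    #         if error < state["best_error"]:
+    #             state["best_error"] = error
+    #             return True
+    #         return False
+    #
+    #     replace_lora_weights_loftq(peft_model, callback=keep_if_better)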
+ any_match = False + safetensor_loader = _SafetensorLoader(peft_model, model_path) + + # if too slow, consider adding tqdm as an option + for name, module in peft_model.named_modules(): + if not isinstance(module, Linear4bit): + continue + + if not name.startswith(prefix): + raise TypeError("The passed model does not appear to be a valid PeftModel") + + any_match = True + name = name[len(prefix) :] + tensor = safetensor_loader.get_tensor(name + ".weight") + + reduced_rank = module.r[adapter_name] + lora_A, lora_B = _loftq_init_new(module.weight, tensor, num_bits=4, reduced_rank=reduced_rank) + if not callback: + module.lora_A[adapter_name].weight.data = lora_A + module.lora_B[adapter_name].weight.data = lora_B + continue + + lora_A_before = module.lora_A[adapter_name].weight.data + lora_B_before = module.lora_B[adapter_name].weight.data + + module.lora_A[adapter_name].weight.data = lora_A + module.lora_B[adapter_name].weight.data = lora_B + should_replace = callback(peft_model, name) + if not should_replace: + # roll back + module.lora_A[adapter_name].weight.data = lora_A_before + module.lora_B[adapter_name].weight.data = lora_B_before + + del lora_A_before, lora_B_before + + if not any_match: + raise ValueError("No bnb LoRA module found on the model") diff --git a/peft/utils/merge_utils.py b/peft/utils/merge_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..900cb878e2d785e5e800caa58642f0532f4d574b --- /dev/null +++ b/peft/utils/merge_utils.py @@ -0,0 +1,268 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Literal + +import torch + + +def reshape_weight_task_tensors(task_tensors, weights): + """ + Reshapes `weights` to match the shape of `task_tensors` by unsqeezing in the remaining dimenions. + + Args: + task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`. + weights (`torch.Tensor`): The tensor to be reshaped. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim()) + weights = weights.view(new_shape) + return weights + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + + Returns: + `torch.Tensor`: The tensor with the pruned weights. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.numel()) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune random values based on the specified fraction `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. 
+ density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + + Returns: + `torch.Tensor`: The pruned tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + + Returns: + `torch.Tensor`: The pruned tensor. + """ + if density >= 1: + warnings.warn(f"The density {density} is greater than or equal to 1, no pruning will be performed.") + return tensor + elif density < 0: + raise ValueError(f"Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +) -> torch.Tensor: + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + + Returns: + `torch.Tensor`: The majority sign mask. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = tensor.sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor: + """ + Merge the task tensors using disjoint merge. + + Args: + task_tensors (`torch.Tensor`):The task tensors to merge. + majority_sign_mask (`torch.Tensor`):The mask of the majority sign across the task tensors. + + Returns: + `torch.Tensor`: The merged tensor. + """ + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) + + +def task_arithmetic(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + """ + Merge the task tensors using `task arithmetic`. + + Args: + task_tensors(`List[torch.Tensor]`):The task tensors to merge. + weights (`torch.Tensor`):The weights of the task tensors. + + Returns: + `torch.Tensor`: The merged tensor. 
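+
+    Example (illustrative): merging two constant tensors with equal weights gives their
+    elementwise weighted sum.
+
+        ```py
+        >>> import torch
+        >>> task_arithmetic([torch.ones(2, 2), 3 * torch.ones(2, 2)], torch.tensor([0.5, 0.5]))
+        tensor([[2., 2.],
+                [2., 2.]])
+        ```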
+ """ + task_tensors = torch.stack(task_tensors, dim=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def magnitude_prune(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor: + """ + Merge the task tensors using `task arithmetic`. + + Args: + task_tensors(`List[torch.Tensor]`):The task tensors to merge. + weights (`torch.Tensor`):The weights of the task tensors. + density (`float`): The fraction of values to preserve. Should be in [0,1]. + + Returns: + `torch.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def ties( + task_tensors: List[torch.Tensor], + weights: torch.Tensor, + density: float, + majority_sign_method: Literal["total", "frequency"] = "total", +) -> torch.Tensor: + """ + Merge the task tensors using `ties`. + + Args: + task_tensors(`List[torch.Tensor]`):The task tensors to merge. + weights (`torch.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + majority_sign_method (`str`): + The method to use to get the majority sign mask. Should be one of ["total", "frequency"]. + + Returns: + `torch.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + # Elect Sign + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=majority_sign_method) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + # Disjoint Merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +def dare_linear(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor: + """ + Merge the task tensors using `dare linear`. + + Args: + task_tensors(`List[torch.Tensor]`):The task tensors to merge. + weights (`torch.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + + Returns: + `torch.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def dare_ties( + task_tensors: List[torch.Tensor], + weights: torch.Tensor, + density: float, + majority_sign_method: Literal["total", "frequency"] = "total", +) -> torch.Tensor: + """ + Merge the task tensors using `dare ties`. + + Args: + task_tensors(`List[torch.Tensor]`):The task tensors to merge. + weights (`torch.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. 
+ majority_sign_method (`str`): + The method to use to get the majority sign mask. Should be one of ["total", "frequency"]. + + Returns: + `torch.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + # Elect Sign + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=majority_sign_method) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + # Disjoint Merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors diff --git a/peft/utils/other.py b/peft/utils/other.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5a299e85fc3712a9f415adc0ba4df1aece9538 --- /dev/null +++ b/peft/utils/other.py @@ -0,0 +1,616 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import inspect +import os +import warnings +from contextlib import nullcontext +from typing import Optional, Tuple + +import accelerate +import torch +from accelerate.hooks import add_hook_to_module, remove_hook_from_module +from accelerate.utils import is_npu_available, is_xpu_available +from huggingface_hub import file_exists +from huggingface_hub.utils import EntryNotFoundError, HFValidationError +from packaging import version +from safetensors.torch import storage_ptr, storage_size + +from ..import_utils import is_auto_gptq_available, is_torch_tpu_available +from .constants import ( + CONFIG_NAME, + EMBEDDING_LAYER_NAMES, + INCLUDE_LINEAR_LAYERS_SHORTHAND, + SAFETENSORS_WEIGHTS_NAME, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + WEIGHTS_NAME, + bloom_model_postprocess_past_key_value, + starcoder_model_postprocess_past_key_value, +) + + +mlu_available = False +if version.parse(accelerate.__version__) >= version.parse("0.29.0"): + from accelerate.utils import is_mlu_available + + mlu_available = is_mlu_available() + + +__all__ = [ + "CONFIG_NAME", + "EMBEDDING_LAYER_NAMES", + "SAFETENSORS_WEIGHTS_NAME", + "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", + "WEIGHTS_NAME", + "INCLUDE_LINEAR_LAYERS_SHORTHAND", + "bloom_model_postprocess_past_key_value", + 
"starcoder_model_postprocess_past_key_value", +] + + +# Get current device name based on available devices +def infer_device() -> str: + if torch.cuda.is_available(): + return "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + elif mlu_available: + return "mlu" + elif is_xpu_available(): + return "xpu" + elif is_npu_available(): + return "npu" + return "cpu" + + +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): + r""" + Note this method only works for `transformers` models. + + This method wraps the entire protocol for preparing a model before running a training. This includes: + 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm + head to fp32 + + Args: + model (`transformers.PreTrainedModel`): + The loaded model from `transformers` + use_gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): + Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of + `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method. + Note this is only available in the latest transformers versions (> 4.34.1). + """ + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) + is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" + is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" + is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" + is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + for name, param in model.named_parameters(): + # freeze base model's layers + param.requires_grad = False + + if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized and not is_hqq_quantized: + # cast all non INT8 parameters to fp32 + for param in model.parameters(): + if ( + (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) + ) and param.__class__.__name__ != "Params4bit": + param.data = param.data.to(torch.float32) + + if ( + loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized or is_hqq_quantized + ) and use_gradient_checkpointing: + # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack + if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]: + # For backward compatibility + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(model.gradient_checkpointing_enable).parameters + ) + + if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: + warnings.warn( + "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored." 
+ " if you want to use that feature, please upgrade to the latest version of transformers.", + FutureWarning, + ) + + gc_enable_kwargs = ( + {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} + ) + + # enable gradient checkpointing for memory efficiency + model.gradient_checkpointing_enable(**gc_enable_kwargs) + return model + + +# copied from transformers.models.bart.modeling_bart +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids + pad_token_id (`int`): The id of the `padding` token. + decoder_start_token_id (`int`): The id of the `start` token. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class ModulesToSaveWrapper(torch.nn.Module): + def __init__(self, module_to_save, adapter_name): + super().__init__() + self.original_module = module_to_save + self.modules_to_save = torch.nn.ModuleDict({}) + self._active_adapter = adapter_name + self._disable_adapters = False + self.update(adapter_name) + self.check_module() + + def check_module(self): + """Perform some sanity checks on the module to ensure that it works""" + # Try to anticipate some modules that users could try to target that would not work. + # Note: It's not possible to check hasattr(module, "forward"), since that returns True for ModuleDict and + # ModuleList, even though their forward methods cannot be called + forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList) + if isinstance(self.original_module, forbidden_classes): + cls_name = self.original_module.__class__.__name__ + raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}") + + @property + def disable_adapters(self) -> bool: + # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method + return self._disable_adapters + + @property + def active_adapter(self) -> str: + # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method + return self._active_adapter + + @property + def weight(self): + if self.active_adapter not in self.modules_to_save: + return self.original_module.weight + return self.modules_to_save[self.active_adapter].weight + + def update(self, adapter_name): + context_manager = nullcontext() + for _, param in self.original_module.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + import deepspeed + + context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0) + break + with context_manager: + self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) + + if hasattr(self.modules_to_save[adapter_name], "_hf_hook"): + old_hook = self.modules_to_save[adapter_name]._hf_hook + new_hook = self._create_new_hook(old_hook) + 
remove_hook_from_module(self.modules_to_save[adapter_name]) + add_hook_to_module(self.modules_to_save[adapter_name], new_hook) + + self.original_module.requires_grad_(False) + if adapter_name == self.active_adapter: + self.modules_to_save[adapter_name].requires_grad_(True) + + def _create_new_hook(self, old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + def forward(self, *args, **kwargs): + if self.disable_adapters or (self.active_adapter not in self.modules_to_save): + return self.original_module(*args, **kwargs) + return self.modules_to_save[self.active_adapter](*args, **kwargs) + + def enable_adapters(self, enabled: bool): + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. + + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if self._disable_adapters is not enabled: + # already in the desired state, do nothing + return + + if enabled: + self.original_module.requires_grad_(False) + self.modules_to_save[self.active_adapter].requires_grad_(True) + self._disable_adapters = False + else: + self.original_module.requires_grad_(True) + self.modules_to_save.requires_grad_(False) + self._disable_adapters = True + + def set_adapter(self, adapter_name: str): + """Set the active adapter + + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... 
param.requires_grad = False + ``` + + Args: + adapter_name (str): The name of the adapter to set as active + """ + if adapter_name not in self.modules_to_save: + raise ValueError(f"Adapter {adapter_name} not found in {self.modules_to_save.keys()}") + + self.modules_to_save[self.active_adapter].requires_grad_(False) + self.modules_to_save[adapter_name].requires_grad_(True) + self._active_adapter = adapter_name + + +def _get_submodules(model, key): + parent = model.get_submodule(".".join(key.split(".")[:-1])) + target_name = key.split(".")[-1] + target = model.get_submodule(key) + return parent, target, target_name + + +def _freeze_adapter(model, adapter_name): + for n, p in model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + +def _set_trainable(model, adapter_name): + key_list = [key for key, _ in model.named_modules()] + for key in key_list: + target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) + if target_module_found: + parent, target, target_name = _get_submodules(model, key) + if isinstance(target, ModulesToSaveWrapper): + target.update(adapter_name) + target.set_adapter(target.active_adapter) + else: + new_module = ModulesToSaveWrapper(target, adapter_name) + new_module.set_adapter(adapter_name) + setattr(parent, target_name, new_module) + + +def _set_adapter(model, adapter_name): + def check_adapter_name(adapter_name): + if isinstance(adapter_name, str): + return adapter_name + + # adapter_name is a list of str + if len(adapter_name) > 1: + raise ValueError("Only one adapter can be set at a time for modules_to_save") + elif len(adapter_name) == 0: + raise ValueError("Please specify at least one adapter to set") + adapter_name = adapter_name[0] + return adapter_name + + for module in model.modules(): + if isinstance(module, ModulesToSaveWrapper): + # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care + adapter_name = check_adapter_name(adapter_name) + + # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this + # module + if adapter_name in module.modules_to_save: + module.set_adapter(adapter_name) + else: + module.enable_adapters(False) + + +def _prepare_prompt_learning_config(peft_config, model_config): + if peft_config.num_layers is None: + if "num_hidden_layers" in model_config: + num_layers = model_config["num_hidden_layers"] + elif "num_layers" in model_config: + num_layers = model_config["num_layers"] + elif "n_layer" in model_config: + num_layers = model_config["n_layer"] + else: + raise ValueError("Please specify `num_layers` in `peft_config`") + peft_config.num_layers = num_layers + + if peft_config.token_dim is None: + if "hidden_size" in model_config: + token_dim = model_config["hidden_size"] + elif "n_embd" in model_config: + token_dim = model_config["n_embd"] + elif "d_model" in model_config: + token_dim = model_config["d_model"] + else: + raise ValueError("Please specify `token_dim` in `peft_config`") + peft_config.token_dim = token_dim + + if peft_config.num_attention_heads is None: + if "num_attention_heads" in model_config: + num_attention_heads = model_config["num_attention_heads"] + elif "n_head" in model_config: + num_attention_heads = model_config["n_head"] + elif "num_heads" in model_config: + num_attention_heads = model_config["num_heads"] + elif "encoder_attention_heads" in model_config: + num_attention_heads = model_config["encoder_attention_heads"] + else: + raise ValueError("Please specify 
`num_attention_heads` in `peft_config`") + peft_config.num_attention_heads = num_attention_heads + + if getattr(peft_config, "encoder_hidden_size", None) is None: + setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) + + return peft_config + + +def fsdp_auto_wrap_policy(model): + import functools + import os + + from accelerate import FullyShardedDataParallelPlugin + + if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"): + get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name + else: + from accelerate.utils.dataclasses import get_module_class_from_name + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + default_transformer_cls_names_to_wrap = ( + ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else "" + ) + transformer_cls_names_to_wrap = os.environ.get( + "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap + ).split(",") + transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding} + for layer_class in transformer_cls_names_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=transformer_cls_to_wrap, + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy + + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + + +def _is_valid_match(key: str, target_key: str): + """ + Helper function to match module names target_key and key. Makes sure that either the key is exactly the target_key + or the target_key is a submodule of key + """ + if key.endswith(target_key): + if len(key) > len(target_key): + return key.endswith("." + target_key) # must be a sub module + return True + return False + + +def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int: + """Get the batch size based on either input_ids or input_embeds + + Raises an ValueError if both are None. 
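For clarity, the following illustrates how the suffix-matching helper `_is_valid_match` defined above behaves; the module paths are made up and the snippet is purely illustrative, assuming the helper is in scope.

```python
# `q_proj` is the target key being searched for.
assert _is_valid_match("model.layers.0.self_attn.q_proj", "q_proj")          # proper submodule match
assert _is_valid_match("q_proj", "q_proj")                                   # exact match
assert not _is_valid_match("model.layers.0.self_attn.my_q_proj", "q_proj")   # suffix only, no "." boundary
```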
+ + """ + if (input_ids is None) and (inputs_embeds is None): + raise ValueError("You have to provide either input_ids or inputs_embeds") + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + return batch_size + + +def get_quantization_config(model: torch.nn.Module, method: str): + """ + Get the quantization config of the related quantization method + """ + if ( + hasattr(model, "config") + and hasattr(model.config, "quantization_config") + and (getattr(model, "quantization_method", None) == method) + ): + return model.config.quantization_config + return None + + +def get_auto_gptq_quant_linear(gptq_quantization_config): + """ + Get the right AutoGPTQQuantLinear class based on the quantization config file + """ + if gptq_quantization_config is not None and is_auto_gptq_available(): + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + + desc_act = gptq_quantization_config.desc_act + group_size = gptq_quantization_config.group_size + bits = gptq_quantization_config.bits + if hasattr(gptq_quantization_config, "use_exllama"): + use_exllama = gptq_quantization_config.use_exllama + else: + use_exllama = not gptq_quantization_config.disable_exllama + if hasattr(gptq_quantization_config, "exllama_config"): + exllama_version = gptq_quantization_config.exllama_config["version"] + else: + exllama_version = 1 + AutoGPTQQuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=desc_act, + group_size=group_size, + bits=bits, + disable_exllama=not (use_exllama and exllama_version == 1), + disable_exllamav2=not (use_exllama and exllama_version == 2), + ) + return AutoGPTQQuantLinear + return None + + +def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + + This method is the exact same copy of + https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added + it here manually to avoid import issue with old versions of transformers. + """ + if tensor.device.type == "xla" and is_torch_tpu_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. + # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return tensor.device, unique_id, storage_size(tensor) + + +def cast_mixed_precision_params(model, dtype): + """ + Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or + `torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full + precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for + non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using + automatic mixed-precision training. + + Args: + model (`torch.nn.Module`): + The model to cast the non-trainable parameters of. 
+ dtype (`torch.dtype`): + The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or + `torch.bfloat16` as per the mixed-precision training you are performing. + """ + for p in model.parameters(): + if not p.requires_grad: + p.data = p.to(dtype) + else: + p.data = p.to(torch.float32) + + +def str_to_bool(value: str) -> int: + """ + Converts a string representation of truth to `True` (1) or `False` (0). + + True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + # same as function as in accelerate.utils, which replaces the deprecated distutils.util.strtobool + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {value}") + + +def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]: + """Check if a file exists on HF Hub, if check was not successful returns None instead of erroring. + + Respect offline mode if set. + + """ + exists: Optional[bool] = None + if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")): + # user set offline mode, cannot check + return exists + + try: + exists = file_exists(repo_id, filename, **kwargs) + except (HFValidationError, EntryNotFoundError): + # error, exists stays None + pass + except Exception as e: + warnings.warn( + f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup" + f" for the file {filename} in {repo_id}." + ) + + return exists diff --git a/peft/utils/peft_types.py b/peft/utils/peft_types.py new file mode 100644 index 0000000000000000000000000000000000000000..382521d61978e0dba7cff5dd51830939d41e1b9c --- /dev/null +++ b/peft/utils/peft_types.py @@ -0,0 +1,79 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class PeftType(str, enum.Enum): + """ + Enum class for the different types of adapters in PEFT. + + Supported PEFT types: + - PROMPT_TUNING + - MULTITASK_PROMPT_TUNING + - P_TUNING + - PREFIX_TUNING + - LORA + - ADALORA + - BOFT + - ADAPTION_PROMPT + - IA3 + - LOHA + - LOKR + - OFT + - POLY + - LN_TUNING + """ + + PROMPT_TUNING = "PROMPT_TUNING" + MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" + ADALORA = "ADALORA" + BOFT = "BOFT" + ADAPTION_PROMPT = "ADAPTION_PROMPT" + IA3 = "IA3" + LOHA = "LOHA" + LOKR = "LOKR" + OFT = "OFT" + POLY = "POLY" + LN_TUNING = "LN_TUNING" + VERA = "VERA" + + +class TaskType(str, enum.Enum): + """ + Enum class for the different types of tasks supported by PEFT. + + Overview of the supported task types: + - SEQ_CLS: Text classification. 
+ - SEQ_2_SEQ_LM: Sequence-to-sequence language modeling. + - CAUSAL_LM: Causal language modeling. + - TOKEN_CLS: Token classification. + - QUESTION_ANS: Question answering. + - FEATURE_EXTRACTION: Feature extraction. Provides the hidden states which can be used as embeddings or features + for downstream tasks. + """ + + SEQ_CLS = "SEQ_CLS" + SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" + CAUSAL_LM = "CAUSAL_LM" + TOKEN_CLS = "TOKEN_CLS" + QUESTION_ANS = "QUESTION_ANS" + FEATURE_EXTRACTION = "FEATURE_EXTRACTION" diff --git a/peft/utils/save_and_load.py b/peft/utils/save_and_load.py new file mode 100644 index 0000000000000000000000000000000000000000..dc77d0f660e51bbe21d24cf424928373310c4b09 --- /dev/null +++ b/peft/utils/save_and_load.py @@ -0,0 +1,448 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import warnings +from typing import Optional + +import torch +from huggingface_hub import file_exists, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +from .other import ( + EMBEDDING_LAYER_NAMES, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + check_file_exists_on_hf_hub, + infer_device, +) +from .peft_types import PeftType + + +def has_valid_embedding_base_layer(layer): + """Check if the layer has an embedding base layer""" + return hasattr(layer, "base_layer") and isinstance(layer.base_layer, (torch.nn.Linear, torch.nn.Embedding)) + + +def get_embedding_layer_name(model, layer, is_embedding_in_target_modules): + """Get the name of the embedding module for a given layer.""" + for name, module in model.named_modules(): + if (not is_embedding_in_target_modules and module == layer) or module == getattr(layer, "base_layer", None): + return name + return None + + +def get_peft_model_state_dict( + model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto" +): + """ + Get the state dict of the Peft model. + + Args: + model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, + the model should be the underlying model/unwrapped model (i.e. model.module). + state_dict (`dict`, *optional*, defaults to `None`): + The state dict of the model. If not provided, the state dict of the passed model will be used. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter whose state dict should be returned. + unwrap_compiled (`bool`, *optional*, defaults to `False`): + Whether to unwrap the model if torch.compile was used. + save_embedding_layers (`Union[bool, str]`, , *optional*, defaults to `auto`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding + layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. Based on it + sets the boolean flag. This only works for 🤗 transformers models. 
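As a hedged sketch of how this function is typically used (not part of the module itself), the snippet below extracts and persists only the adapter weights. It assumes `peft_model` is an already-constructed `PeftModel` with a "default" LoRA adapter, and the output path is a placeholder; this roughly mirrors what `PeftModel.save_pretrained` does with safetensors serialization.

```python
from safetensors.torch import save_file

# `peft_model` is assumed to exist already (e.g. built with get_peft_model elsewhere).
adapter_state = get_peft_model_state_dict(peft_model, adapter_name="default")
print(len(adapter_state), "adapter tensors")

# Persist only the adapter weights; the directory is a placeholder.
save_file(adapter_state, "outputs/adapter_model.safetensors")
```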
+ """ + if unwrap_compiled: + model = getattr(model, "_orig_mod", model) + + config = model.peft_config[adapter_name] + if state_dict is None: + state_dict = model.state_dict() + if config.peft_type in (PeftType.LORA, PeftType.ADALORA): + # to_return = lora_state_dict(model, bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` + # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP + bias = config.bias + if bias == "none": + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} + elif bias == "all": + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + for k in state_dict: + if "lora_" in k: + to_return[k] = state_dict[k] + bias_name = k.split("lora_")[0] + "bias" + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} + if config.peft_type == PeftType.ADALORA: + rank_pattern = config.rank_pattern + if rank_pattern is not None: + rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} + config.rank_pattern = rank_pattern + to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) + + elif config.peft_type == PeftType.BOFT: + bias = config.bias + if bias == "none": + to_return = {k: state_dict[k] for k in state_dict if "boft_" in k} + elif bias == "all": + to_return = {k: state_dict[k] for k in state_dict if "boft_" in k or "bias" in k} + elif bias == "boft_only": + to_return = {} + for k in state_dict: + if "boft_" in k: + to_return[k] = state_dict[k] + bias_name = k.split("boft_")[0] + "bias" + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + + elif config.peft_type == PeftType.LOHA: + to_return = {k: state_dict[k] for k in state_dict if "hada_" in k} + + elif config.peft_type == PeftType.LOKR: + to_return = {k: state_dict[k] for k in state_dict if "lokr_" in k} + + elif config.peft_type == PeftType.ADAPTION_PROMPT: + to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} + + elif config.is_prompt_learning: + to_return = {} + if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + to_return["prefix_task_cols"] = model.prompt_encoder[adapter_name].prefix_task_cols + to_return["prefix_task_rows"] = model.prompt_encoder[adapter_name].prefix_task_rows + prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight + else: + if config.inference_mode: + prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight + else: + prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) + to_return["prompt_embeddings"] = prompt_embeddings + + elif config.peft_type == PeftType.IA3: + to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} + + elif config.peft_type == PeftType.OFT: + to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} + + elif config.peft_type == PeftType.POLY: + to_return = {k: state_dict[k] for k in state_dict if "poly_" in k} + + elif config.peft_type == PeftType.LN_TUNING: + to_return = {k: state_dict[k] for k in state_dict if "ln_tuning_" in k} + + elif config.peft_type == PeftType.VERA: + to_return = {k: state_dict[k] for k in state_dict if "vera_lambda_" in k} + if config.save_projection: + # TODO: adding vera_A 
and vera_B to `self.get_base_layer` would + # make name to match here difficult to predict. + if f"base_model.vera_A.{adapter_name}" not in state_dict: + raise ValueError( + "Model was initialised to not save vera_A and vera_B but config now specifies to save projection!" + " Set `config.save_projection` to `False`." + ) + to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] + to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name] + + else: + raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") + + if getattr(model, "modules_to_save", None) is not None: + for key, value in state_dict.items(): + if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): + to_return[key.replace("modules_to_save.", "")] = value + + # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary + is_embedding_in_target_modules = False + if ( + save_embedding_layers == "auto" + and hasattr(config, "target_modules") + and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES) + ): + warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.") + save_embedding_layers = is_embedding_in_target_modules = True + elif save_embedding_layers == "auto": + vocab_size = getattr(getattr(model, "config", None), "vocab_size", None) + model_id = getattr(config, "base_model_name_or_path", None) + + # For some models e.g. diffusers the text config file is stored in a subfolder + # we need to make sure we can download that config. + has_remote_config = False + + # ensure that this check is not performed in HF offline mode, see #1452 + if model_id is not None: + exists = check_file_exists_on_hf_hub(model_id, "config.json") + if exists is None: + # check failed, could not determine if it exists or not + warnings.warn( + f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified." + ) + has_remote_config = False + else: + has_remote_config = exists + + # check if the vocab size of the base model is different from the vocab size of the finetuned model + if ( + vocab_size + and model_id + and has_remote_config + and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size) + ): + warnings.warn( + "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning." 
+ ) + save_embedding_layers = True + else: + save_embedding_layers = False + + if save_embedding_layers and hasattr(model, "get_input_embeddings"): + for layer in [model.get_input_embeddings(), model.get_output_embeddings()]: + if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer): + # support from version >= 0.6.2 + embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules) + if embedding_module_name: + to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k}) + elif save_embedding_layers: + warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.") + + to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} + return to_return + + +def _find_mismatched_keys( + model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False +) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]: + if not ignore_mismatched_sizes: + return peft_model_state_dict, [] + + mismatched = [] + state_dict = model.state_dict() + for key, tensor in peft_model_state_dict.items(): + if key not in state_dict: + continue + + # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864 + if (state_dict[key].shape[-1] == 1) and (state_dict[key].numel() * 2 == tensor.numel()): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size + # differences. Without matching with module type or paramter type it seems like a practical way to detect + # valid 4bit weights. + continue + + if state_dict[key].shape != tensor.shape: + mismatched.append((key, tensor.shape, state_dict[key].shape)) + + for key, _, _ in mismatched: + del peft_model_state_dict[key] + + return peft_model_state_dict, mismatched + + +def set_peft_model_state_dict( + model, peft_model_state_dict, adapter_name="default", ignore_mismatched_sizes: bool = False +): + """ + Set the state dict of the Peft model. + + Args: + model ([`PeftModel`]): + The Peft model. + peft_model_state_dict (`dict`): + The state dict of the Peft model. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter whose state dict should be set. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether to ignore mismatched in the state dict. 
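To make the intended load path concrete, here is a hedged round-trip sketch, not part of the module: it assumes `peft_model` is an already-constructed `PeftModel`, that an adapter was previously saved to the placeholder directory below, and it uses `load_peft_weights` and `infer_device`, both available in this file.

```python
# Placeholder directory; `peft_model` is assumed to exist already.
adapters_weights = load_peft_weights("outputs/adapter_dir", device=infer_device())
load_result = set_peft_model_state_dict(peft_model, adapters_weights, adapter_name="default")

# For a matching checkpoint, unexpected_keys should be empty.
print(load_result.unexpected_keys)
```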
+ """ + config = model.peft_config[adapter_name] + state_dict = {} + if getattr(model, "modules_to_save", None) is not None: + for key, value in peft_model_state_dict.items(): + if any(module_name in key for module_name in model.modules_to_save): + for module_name in model.modules_to_save: + if module_name in key: + key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") + break + state_dict[key] = value + else: + state_dict = peft_model_state_dict + + if config.peft_type in ( + PeftType.LORA, + PeftType.LOHA, + PeftType.LOKR, + PeftType.ADALORA, + PeftType.IA3, + PeftType.OFT, + PeftType.POLY, + PeftType.LN_TUNING, + PeftType.BOFT, + PeftType.VERA, + ): + peft_model_state_dict = {} + parameter_prefix = { + PeftType.IA3: "ia3_", + PeftType.LORA: "lora_", + PeftType.ADALORA: "lora_", + PeftType.LOHA: "hada_", + PeftType.LOKR: "lokr_", + PeftType.OFT: "oft_", + PeftType.POLY: "poly_", + PeftType.BOFT: "boft_", + PeftType.LN_TUNING: "ln_tuning_", + PeftType.VERA: "vera_lambda_", + }[config.peft_type] + for k, v in state_dict.items(): + if parameter_prefix in k: + suffix = k.split(parameter_prefix)[1] + if "." in suffix: + suffix_to_replace = ".".join(suffix.split(".")[1:]) + k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") + else: + k = f"{k}.{adapter_name}" + peft_model_state_dict[k] = v + else: + peft_model_state_dict[k] = v + + if config.peft_type == PeftType.ADALORA: + rank_pattern = config.rank_pattern + if rank_pattern is not None: + model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) + elif config.peft_type == PeftType.VERA: + if config.save_projection and "base_model.vera_A" not in peft_model_state_dict: + raise ValueError( + "Specified to load vera_A and vera_B from state dictionary however they were not present!" + ) + elif not config.save_projection and "base_model.vera_A" in peft_model_state_dict: + warnings.warn( + "Specified to not load vera_A and vera_B from state dictionary however they are present in state" + " dictionary! Consider using them to ensure checkpoint loading is correct on all platforms using" + " `peft_config.save_projection = True`" + ) + elif not config.save_projection: # and no vera_A in state dictionary + warnings.warn( + "Specified to not load vera_A and vera_B from state dictionary. This means we will be relying on" + " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may" + " not be accurate on all system configurations." 
+ ) + elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: + peft_model_state_dict = state_dict + else: + raise NotImplementedError + + peft_model_state_dict, mismatched_keys = _find_mismatched_keys( + model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes + ) + load_result = model.load_state_dict(peft_model_state_dict, strict=False) + if config.is_prompt_learning: + model.prompt_encoder[adapter_name].embedding.load_state_dict( + {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True + ) + + if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + model.prompt_encoder[adapter_name].load_state_dict(peft_model_state_dict, strict=False) + + if mismatched_keys: + # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039 + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + msg = ( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint " + f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}." + ) + warnings.warn(msg) + return load_result + + +def load_peft_weights(model_id: str, device: Optional[str] = None, **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when loading from the HuggingFace Hub. + """ + path = ( + os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) + if hf_hub_download_kwargs.get("subfolder", None) is not None + else model_id + ) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + token = hf_hub_download_kwargs.get("token", None) + if token is None: + token = hf_hub_download_kwargs.get("use_auth_token", None) + + hub_filename = ( + os.path.join(hf_hub_download_kwargs["subfolder"], SAFETENSORS_WEIGHTS_NAME) + if hf_hub_download_kwargs.get("subfolder", None) is not None + else SAFETENSORS_WEIGHTS_NAME + ) + has_remote_safetensors_file = file_exists( + repo_id=model_id, + filename=hub_filename, + revision=hf_hub_download_kwargs.get("revision", None), + repo_type=hf_hub_download_kwargs.get("repo_type", None), + token=token, + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( + f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}." 
+ ) + + if use_safetensors: + if hasattr(torch.backends, "mps") and (device == torch.device("mps")): + adapters_weights = safe_load_file(filename, device="cpu") + else: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, map_location=torch.device(device)) + + return adapters_weights diff --git a/pipelines/__pycache__/pipeline_imagecoductor.cpython-310.pyc b/pipelines/__pycache__/pipeline_imagecoductor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfa7a0da04273f420859de00b9956270fa57e1e6 Binary files /dev/null and b/pipelines/__pycache__/pipeline_imagecoductor.cpython-310.pyc differ diff --git a/pipelines/pipeline_imagecoductor.py b/pipelines/pipeline_imagecoductor.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ba3ca3f65ba610b35454ab848195ab94b49e87 --- /dev/null +++ b/pipelines/pipeline_imagecoductor.py @@ -0,0 +1,559 @@ +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers.configuration_utils import FrozenDict +from diffusers.loaders import IPAdapterMixin +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging +from einops import rearrange +from packaging import version +from tqdm import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from modules.unet import UNet3DConditionFlowModel +from modules.flow_controlnet import FlowControlNetModel +from modules.image_controlnet import ImageControlNetModel + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ImageConductorPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class ImageConductorPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionFlowModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + image_controlnet: Union[ImageControlNetModel, None] = None, + flow_controlnet: Union[FlowControlNetModel, None] = None, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_controlnet=image_controlnet, + flow_controlnet=flow_controlnet, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
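Because `_encode_prompt` stacks the embeddings as `[unconditional, conditional]`, a single UNet forward pass yields a prediction that can be split and recombined with the guidance weight. The snippet below only illustrates that arithmetic with a placeholder tensor; the shapes and variable names are not the pipeline's own.

```python
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 16, 32, 48)       # placeholder: [uncond, cond] stacked on dim 0
noise_uncond, noise_text = noise_pred.chunk(2)
guided = noise_uncond + guidance_scale * (noise_text - noise_uncond)
print(guided.shape)                               # torch.Size([1, 4, 16, 32, 48])
```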
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = shape + # shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + obj_guidance_scale: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + + # support controlnet + controlnet_images: torch.FloatTensor = None, + controlnet_image_index: list = [0], + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_flows: torch.FloatTensor = None, + + eval_mode: bool = False, + + control_mode: Optional[str] = "object", + + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * 
self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # do_camera_free_guidance = obj_guidance_scale >= 1.0 ## todo: camera-free guidance + + + # Encode input prompt, [uncond_embeddings, cond_embeddings] for CFG + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + obj_latents = copy.deepcopy(latents) + cam_latents = copy.deepcopy(latents) + + # Prepare extra step kwargs. 
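The numbers below only illustrate the shape arithmetic performed by `prepare_latents`; 4 latent channels and a VAE scale factor of 8 are the usual Stable-Diffusion values and are assumptions here, not read from any checkpoint.

```python
batch_size, num_channels_latents, video_length = 1, 4, 16   # 4 latent channels is the typical SD UNet value
height, width, vae_scale_factor = 256, 384, 8                # vae_scale_factor = 2 ** (len(block_out_channels) - 1)

shape = (batch_size, num_channels_latents, video_length,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 4, 16, 32, 48); this tensor is then scaled by scheduler.init_noise_sigma
```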
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + # expand the latents if we are doing classifier free guidance + latent_obj_model_input = torch.cat([obj_latents] * 2) if do_classifier_free_guidance else obj_latents + latent_obj_model_input = self.scheduler.scale_model_input(latent_obj_model_input, t) + + latent_cam_model_input = torch.cat([cam_latents] * 2) if do_classifier_free_guidance else cam_latents + latent_cam_model_input = self.scheduler.scale_model_input(latent_cam_model_input, t) + + img_down_block_additional_residuals = img_mid_block_additional_residuals = None + flow_down_block_additional_residuals_obj = flow_mid_block_additional_residual_obj = None + + flow_down_block_additional_residuals_cam = flow_mid_block_additional_residual_cam = None + + + if control_mode == "object": + controlnet_noisy_latents = latent_obj_model_input + elif control_mode == "camera": + controlnet_noisy_latents = latent_cam_model_input + + + if (getattr(self, "image_controlnet", None) != None) and (controlnet_images != None): + assert controlnet_images.dim() == 5 + + controlnet_prompt_embeds = text_embeddings + + # concat mask and images condition + controlnet_images_ref = controlnet_images.to(latents.device) + + controlnet_images_shape = list(controlnet_images.shape) + controlnet_images_shape[2] = video_length + controlnet_images = torch.zeros(controlnet_images_shape).to(latents.device) + + controlnet_images_mask_shape = list(controlnet_images.shape) + controlnet_images_mask_shape[1] = 1 + controlnet_images_mask = torch.zeros(controlnet_images_mask_shape).to(latents.device) + + assert controlnet_images.shape[2] >= len(controlnet_image_index) + controlnet_images[:,:,controlnet_image_index] = controlnet_images_ref[:,:,:len(controlnet_image_index)] + controlnet_images_mask[:,:,controlnet_image_index] = 1 + + # fp16 accelerate training in deepspeed + if not eval_mode: + controlnet_noisy_latents = controlnet_noisy_latents.half() + controlnet_prompt_embeds = controlnet_prompt_embeds.half() + controlnet_images = controlnet_images.half() + controlnet_images_mask = controlnet_images_mask.half() + controlnet_flows = controlnet_flows.half() + text_embeddings = text_embeddings.half() + + img_down_block_additional_residuals, img_mid_block_additional_residuals = self.image_controlnet( + controlnet_noisy_latents, t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_images, + conditioning_mask=controlnet_images_mask, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=False, return_dict=False, + ) + + + if (getattr(self, "flow_controlnet", None) != None) and (controlnet_flows != None): + # Get the flow controlnet + # self.flow_controlnet.disable_adapters() + if control_mode == "camera": + self.flow_controlnet.enable_adapters() + self.flow_controlnet.set_adapter(["cmcm_weights"]) + flow_down_block_additional_residuals_cam, flow_mid_block_additional_residual_cam = self.flow_controlnet( + controlnet_noisy_latents, t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_flows, # bcthw + conditioning_scale=controlnet_conditioning_scale, + guess_mode=False, return_dict=False, + ) + + ## camera transition latent outputs + noise_pred_cam = self.unet( + controlnet_noisy_latents, t, + encoder_hidden_states=text_embeddings, + 
down_block_additional_residuals = img_down_block_additional_residuals, + mid_block_additional_residual = img_mid_block_additional_residuals, + flow_down_block_additional_residuals = flow_down_block_additional_residuals_cam, + flow_mid_block_additional_residual = flow_mid_block_additional_residual_cam, + ).sample.to(dtype=latents_dtype) + + # perform CFG for camera + if do_classifier_free_guidance: + noise_pred_cam_uncond, noise_pred_cam_text = noise_pred_cam.chunk(2) + noise_pred_cam = noise_pred_cam_uncond + guidance_scale * (noise_pred_cam_text - noise_pred_cam_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + cam_latents = self.scheduler.step(noise_pred_cam, t, cam_latents, **extra_step_kwargs).prev_sample + + + if control_mode == "object": + self.flow_controlnet.enable_adapters() + self.flow_controlnet.set_adapter(["omcm_weights"]) + flow_down_block_additional_residuals_obj, flow_mid_block_additional_residual_obj = self.flow_controlnet( + controlnet_noisy_latents, t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_flows, # bcthw + conditioning_scale=controlnet_conditioning_scale, + guess_mode=False, return_dict=False, + ) + + ## object movement latent outputs + noise_pred_obj = self.unet( + controlnet_noisy_latents, t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals = img_down_block_additional_residuals, + mid_block_additional_residual = img_mid_block_additional_residuals, + flow_down_block_additional_residuals = flow_down_block_additional_residuals_obj, + flow_mid_block_additional_residual = flow_mid_block_additional_residual_obj, + ).sample.to(dtype=latents_dtype) + + # perform CFG for object + if do_classifier_free_guidance: + noise_pred_obj_uncond, noise_pred_obj_text = noise_pred_obj.chunk(2) + noise_pred_obj = noise_pred_obj_uncond + guidance_scale * (noise_pred_obj_text - noise_pred_obj_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + obj_latents = self.scheduler.step(noise_pred_obj, t, obj_latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, cam_latents, obj_latents) + + + # Post-processning + if control_mode == "object": + video_obj = self.decode_latents(obj_latents) + videos = video_obj + elif control_mode == "camera": + video_cam = self.decode_latents(cam_latents) + videos = video_cam + + + # Convert to tensor + if output_type == "tensor": + videos = torch.from_numpy(videos) + + if not return_dict: + return videos + + return ImageConductorPipelineOutput(videos=videos) + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..96f07763ae9984a05417e21fb90d7e9a6b87bd13 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +accelerate==0.32.1 +auto_gptq==0.7.1 +bitsandbytes==0.43.1 +deepspeed==0.14.4 +diffusers==0.29.2 +einops==0.8.0 +flow_vis==0.1 +gradio==3.50.2 +huggingface_hub==0.23.4 +imageio==2.34.2 +matplotlib==3.9.1 +numpy==2.0.0 +omegaconf==2.3.0 +packaging==24.1 +Pillow==10.4.0 +PyYAML==6.0.1 +Requests==2.32.3 +safetensors==0.4.3 +scipy==1.14.0 +torch==2.3.1 +torch_xla==2.3.0 +torchvision==0.18.1 +tqdm==4.65.0 +transformers==4.42.3 +imageio-ffmpeg +av2 +peft \ No newline at end of file diff --git a/utils/__pycache__/convert_from_ckpt.cpython-310.pyc b/utils/__pycache__/convert_from_ckpt.cpython-310.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..35988a572bebdcec2da2830e79889488975df453 Binary files /dev/null and b/utils/__pycache__/convert_from_ckpt.cpython-310.pyc differ diff --git a/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc b/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2536faa654441e074aa6a9204301c62d838c1a62 Binary files /dev/null and b/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc differ diff --git a/utils/__pycache__/gradio_utils.cpython-310.pyc b/utils/__pycache__/gradio_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5be6755314c372cc0c0bb37bb65ad18b55d27bab Binary files /dev/null and b/utils/__pycache__/gradio_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/lora_utils.cpython-310.pyc b/utils/__pycache__/lora_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e4a21a4c4815e943ad916d86d9aef9522422373 Binary files /dev/null and b/utils/__pycache__/lora_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/utils.cpython-310.pyc b/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3097d1c3d5ecd5a0c7091c30feea10450d844c55 Binary files /dev/null and b/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/utils/__pycache__/visualizer.cpython-310.pyc b/utils/__pycache__/visualizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d25e81cafb1edbab61053bbf2f483cdfe26f3398 Binary files /dev/null and b/utils/__pycache__/visualizer.cpython-310.pyc differ diff --git a/utils/convert_from_ckpt.py b/utils/convert_from_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..23e4b7979636d700988026de187c4e9681c83248 --- /dev/null +++ b/utils/convert_from_ckpt.py @@ -0,0 +1,959 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from io import BytesIO +from typing import Optional + +import requests +import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers.models import ( + AutoencoderKL, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils.import_utils import BACKENDS_MAPPING + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
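+    Example: shave_segments("a.b.c.d", n_shave_prefix_segments=1) -> "b.c.d";
+    shave_segments("a.b.c.d", n_shave_prefix_segments=-1) -> "a.b.c".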
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
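+    When `controlnet` is True the config is read from `control_stage_config`, and the
+    decoder-side keys (`out_channels`, `up_block_types`) are omitted.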
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + unet_params = original_config.model.params.unet_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + } + + if not controlnet: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, 
new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
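+ # keep only keys under "first_stage_model." and strip that prefix so they match
+ # the standalone diffusers AutoencoderKL layout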
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint): + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key 
in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + text_model.load_state_dict(text_model_dict, strict=False) + + return text_model + + +textenc_conversion_lst = [ + ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), + ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint(checkpoint): + text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) + else: + d_model 
= 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + continue + if key in textenc_conversion_map: + text_model_dict[textenc_conversion_map[key]] = checkpoint[key] + if key.startswith("cond_stage_model.model.transformer."): + new_key = key[len("cond_stage_model.model.transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
+ """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + controlnet_model = ControlNetModel(**ctrlnet_config) + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + ) + + controlnet_model.load_state_dict(converted_ctrl_checkpoint) + + return controlnet_model diff --git a/utils/convert_lora_safetensor_to_diffusers.py b/utils/convert_lora_safetensor_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..549336a62011dfa23d74452d64f5b9eec8a6b81b --- /dev/null +++ b/utils/convert_lora_safetensor_to_diffusers.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Changes were made to this source code by Yuwei Guo. +""" Conversion script for the LoRA's safetensors checkpoints. """ + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionPipeline + + +def load_diffusers_lora(pipeline, state_dict, alpha=1.0): + # directly update weight in diffusers model + for key in state_dict: + # only process lora down key + if "up." 
in key: continue + + up_key = key.replace(".down.", ".up.") + model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") + model_key = model_key.replace("to_out.", "to_out.0.") + layer_infos = model_key.split(".")[:-1] + + curr_layer = pipeline.unet + while len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + curr_layer = curr_layer.__getattr__(temp_name) + + weight_down = state_dict[key] + weight_up = state_dict[up_key] + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + return pipeline + + +def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): + # load base model + # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + # state_dict = load_file(checkpoint_path) + + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
+ ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/utils/gradio_utils.py b/utils/gradio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..531d20cb683732d5d55eb9771ff51317d527e7e9 --- /dev/null +++ b/utils/gradio_utils.py @@ -0,0 +1,238 @@ +# -*- coding:utf-8 -*- +import os +import sys +import shutil +from tqdm import tqdm +import yaml + +import random +import importlib +from PIL import Image +import imageio + +import numpy as np +import cv2 +import torch +from torchvision import utils + +from scipy.interpolate import PchipInterpolator + +def split_filename(filename): + absname = os.path.abspath(filename) + dirname, basename = os.path.split(absname) + split_tmp = basename.rsplit('.', maxsplit=1) + if len(split_tmp) == 2: + rootname, extname = split_tmp + elif len(split_tmp) == 1: + rootname = split_tmp[0] + extname = None + else: + raise ValueError("programming error!") + return dirname, rootname, extname + + +def data2file(data, filename, type=None, override=False, printable=False, **kwargs): + dirname, rootname, extname = split_filename(filename) + print_did_not_save_flag = True + if type: + extname = type + if not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + + if not os.path.exists(filename) or override: + if extname in ['jpg', 'png', 'jpeg']: + utils.save_image(data, filename, **kwargs) + elif extname == 'gif': + imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0) + elif extname == 'txt': + if kwargs is None: + kwargs = {} + max_step = kwargs.get('max_step') + if max_step is None: + max_step = np.Infinity + + with open(filename, 'w', encoding='utf-8') as f: + for i, e in enumerate(data): + if i < max_step: + f.write(str(e) + '\n') + else: + break + else: + raise ValueError('Do not support this type') + if printable: print('Saved data to %s' % os.path.abspath(filename)) + else: + if print_did_not_save_flag: print( + 'Did not save data to %s because file exists and override is False' % os.path.abspath( + filename)) + + +def file2data(filename, type=None, printable=True, **kwargs): + dirname, rootname, extname = split_filename(filename) + print_load_flag = True + if type: + extname = type + + if extname in ['pth', 'ckpt', 'bin']: + data = torch.load(filename, map_location=kwargs.get('map_location')) + if "state_dict" in data.keys(): + data = 
data["state_dict"] + data = {k.replace("_forward_module.", ""):v for k,v in data.items()} + elif extname == 'txt': + top = kwargs.get('top', None) + with open(filename, encoding='utf-8') as f: + if top: + data = [f.readline() for _ in range(top)] + else: + data = [e for e in f.read().split('\n') if e] + elif extname == 'yaml': + with open(filename, 'r') as f: + data = yaml.load(f) + + else: + raise ValueError('type can only support h5, npy, json, txt') + if printable: + if print_load_flag: + print('Loaded data from %s' % os.path.abspath(filename)) + return data + + +def ensure_dirname(dirname, override=False): + if os.path.exists(dirname) and override: + print('Removing dirname: %s' % os.path.abspath(dirname)) + try: + shutil.rmtree(dirname) + except OSError as e: + raise ValueError('Failed to delete %s because %s' % (dirname, e)) + + if not os.path.exists(dirname): + print('Making dirname: %s' % os.path.abspath(dirname)) + os.makedirs(dirname, exist_ok=True) + + +def import_filename(filename): + spec = importlib.util.spec_from_file_location("mymodule", filename) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def adaptively_load_state_dict(target, state_dict): + target_dict = target.state_dict() + + try: + common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()} + # unmatch_dict = {k: v for k, v in state_dict.items() if k not in target_dict or v.size() != target_dict[k].size()} + except Exception as e: + print('load error %s', e) + common_dict = {k: v for k, v in state_dict.items() if k in target_dict} + + if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \ + target.state_dict()['param_groups'][0]['params']: + print('Detected mismatch params, auto adapte state_dict to current') + common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params'] + target_dict.update(common_dict) + target.load_state_dict(target_dict) + + missing_keys = [k for k in target_dict.keys() if k not in common_dict] + unexpected_keys = [k for k in state_dict.keys() if k not in common_dict] + + if len(unexpected_keys) != 0: + print( + f"Some weights of state_dict were not used in target: {unexpected_keys}" + ) + if len(missing_keys) != 0: + print( + f"Some weights of state_dict are missing used in target {missing_keys}" + ) + if len(unexpected_keys) == 0 and len(missing_keys) == 0: + print("Strictly Loaded state_dict.") + + +def set_seed(seed=42): + random.seed(seed) + os.environ['PYHTONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def image2pil(filename): + return Image.open(filename) + + +def image2arr(filename): + pil = image2pil(filename) + return pil2arr(pil) + + +def pil2arr(pil): + if isinstance(pil, list): + arr = np.array( + [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil]) + else: + arr = np.array(pil) + return arr + + +def arr2pil(arr): + if arr.ndim == 3: + return Image.fromarray(arr.astype('uint8'), 'RGB') + elif arr.ndim == 4: + return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)] + else: + raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim) + + +def interpolate_trajectory(points, n_points): + x = [point[0] for point in points] + y = [point[1] for point in points] + + t = np.linspace(0, 1, len(points)) + + fx = 
PchipInterpolator(t, x) + fy = PchipInterpolator(t, y) + + new_t = np.linspace(0, 1, n_points) + + new_x = fx(new_t) + new_y = fy(new_t) + new_points = list(zip(new_x, new_y)) + + return new_points + +def visualize_drag(background_image_path, splited_tracks, drag_mode, width, height, model_length): + if drag_mode=='object': + color = (255, 0, 0, 255) + elif drag_mode=='camera': + color = (0, 0, 255, 255) + background_image = Image.open(background_image_path).convert('RGBA') + background_image = background_image.resize((width, height)) + w, h = background_image.size + transparent_background = np.array(background_image) + transparent_background[:, :, -1] = 128 + transparent_background = Image.fromarray(transparent_background) + + # Create a transparent layer with the same size as the background image + transparent_layer = np.zeros((h, w, 4)) + for splited_track in splited_tracks: + if len(splited_track) > 1: + splited_track = interpolate_trajectory(splited_track, model_length) + splited_track = splited_track[:model_length] + for i in range(len(splited_track)-1): + start_point = (int(splited_track[i][0]), int(splited_track[i][1])) + end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) + vx = end_point[0] - start_point[0] + vy = end_point[1] - start_point[1] + arrow_length = np.sqrt(vx**2 + vy**2) + if i == len(splited_track)-2: + cv2.arrowedLine(transparent_layer, start_point, end_point, color, 2, tipLength=8 / arrow_length) + else: + cv2.line(transparent_layer, start_point, end_point, color, 2) + else: + cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, color, -1) + + transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) + trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) + return trajectory_map, transparent_layer \ No newline at end of file diff --git a/utils/lora_utils.py b/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dce8c4a254339a20ae3439dcf86b0f30f7a7b92f --- /dev/null +++ b/utils/lora_utils.py @@ -0,0 +1,33 @@ +from peft import LoraConfig + +def add_LoRA_to_controlnet(lora_rank, controlnet): + controlnet_down_blocks = [f"controlnet_down_blocks.{i}" for i in range(12)] + target_modules = ['to_q', 'to_k', 'to_v','to_out.0', 'conv', 'proj', + 'proj_in', 'proj_out', 'time_emb_proj', + 'linear_1', 'linear_2', 'ff.net.2', 'conv_shortcut', + 'controlnet_cond_embedding', + 'conv_in', 'conv1', 'conv2', + 'controlnet_mid_block', + ] + target_modules = target_modules + controlnet_down_blocks + + # object motion control module + omcm_lora_config = LoraConfig( + r=lora_rank, + use_dora=False, + lora_alpha=lora_rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + + # camera motion control module + cmcm_lora_config = LoraConfig( + r=lora_rank, + use_dora=False, + lora_alpha=lora_rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + + controlnet.add_adapter(omcm_lora_config, "omcm_weights") + controlnet.add_adapter(cmcm_lora_config, "cmcm_weights") \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bd7943339945dde75812d0be5a4d2748d2f4c2 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,365 @@ +import os +import imageio +import importlib + +from omegaconf import OmegaConf +from typing import Union +from safetensors import safe_open +from tqdm import tqdm + + +import numpy as np +import torch +import torchvision +import 
torch.distributed as dist + +from scipy.interpolate import PchipInterpolator +from einops import rearrange + +from utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint +from utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora +from modules.flow_controlnet import FlowControlNetModel +from modules.image_controlnet import ImageControlNetModel + +def zero_rank_print(s): + if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=2, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps, loop=0) + + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents + + +def load_weights( + animation_pipeline, + # motion module + motion_module_path = "", + motion_module_lora_configs = [], + # domain adapter 
+ adapter_lora_path = "", + adapter_lora_scale = 1.0, + # image layers + dreambooth_model_path = "", + lora_model_path = "", + lora_alpha = 0.8, +): + # motion module + unet_state_dict = {} + if motion_module_path != "": + print(f"load motion module from {motion_module_path}") + motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") + motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict + unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) + unet_state_dict.pop("animatediff_config", "") + + missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) + assert len(unexpected) == 0 + del unet_state_dict + + # base model + if dreambooth_model_path != "": + print(f"load dreambooth model from {dreambooth_model_path}") + if dreambooth_model_path.endswith(".safetensors"): + dreambooth_state_dict = {} + with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + dreambooth_state_dict[key] = f.get_tensor(key) + elif dreambooth_model_path.endswith(".ckpt"): + dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") + + # 1. vae + converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) + for key in list(converted_vae_checkpoint.keys()): + if 'mid_block' in key: + if 'key' in key: + new_key = key.replace('key', 'to_k') + elif 'query' in key: + new_key = key.replace('query', 'to_q') + elif 'value' in key: + new_key = key.replace('value', 'to_v') + elif 'proj_attn' in key: + new_key = key.replace('proj_attn', 'to_out.0') + else: new_key=False + if new_key: + converted_vae_checkpoint[new_key] = converted_vae_checkpoint[key] + del converted_vae_checkpoint[key] + m, u = animation_pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=False) + print(f"dreambooth vae: {u}") + # 2. unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) + m,u = animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) + # 3. 
text_model + animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) + del dreambooth_state_dict + + # lora layers + if lora_model_path != "": + print(f"load lora model from {lora_model_path}") + assert lora_model_path.endswith(".safetensors") + lora_state_dict = {} + with safe_open(lora_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + lora_state_dict[key] = f.get_tensor(key) + + animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) + del lora_state_dict + + # domain adapter lora + if adapter_lora_path != "": + print(f"load domain lora from {adapter_lora_path}") + domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") + domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict + domain_lora_state_dict.pop("animatediff_config", "") + + animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) + + # motion module lora + for motion_module_lora_config in motion_module_lora_configs: + path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] + print(f"load motion LoRA from {path}") + motion_lora_state_dict = torch.load(path, map_location="cpu") + motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict + motion_lora_state_dict.pop("animatediff_config", "") + + animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) + + return animation_pipeline + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def load_checkpoint(model_file, model): + + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + + state_dict = torch.load(model_file, map_location="cpu") + global_step = state_dict['global_step'] if "global_step" in state_dict else 0 + new_state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + new_state_dict = {k.replace('module.', '') : v for k, v in new_state_dict.items()} + m, u = model.load_state_dict(new_state_dict, strict=False) + return model, global_step, m, u, new_state_dict + + +def load_model(model, model_path): + if model_path != "": + print(f"init model from checkpoint: {model_path}") + model_ckpt = torch.load(model_path, map_location="cuda") + if "global_step" in model_ckpt: print(f"global_step: {model_ckpt['global_step']}") + state_dict = model_ckpt["state_dict"] if "state_dict" in model_ckpt else model_ckpt + m, u = model.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + +def interpolate_trajectory(points, n_points): + x = [point[0] for point in points] + y = [point[1] for point in points] + + t = np.linspace(0, 1, len(points)) + + fx = PchipInterpolator(t, x) + fy = PchipInterpolator(t, y) + + new_t = np.linspace(0, 1, n_points) + + new_x = fx(new_t) + new_y = fy(new_t) + 
new_points = list(zip(new_x, new_y))
+
+    return new_points
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+    """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+    Args:
+        kernel_size (int):
+        sig_x (float):
+        sig_y (float):
+        theta (float): Radian measurement.
+        grid (ndarray, optional): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size. Default: None
+        isotropic (bool):
+    Returns:
+        kernel (ndarray): normalized kernel.
+    """
+    if grid is None:
+        grid, _, _ = mesh_grid(kernel_size)
+    if isotropic:
+        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+    else:
+        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+    kernel = pdf2(sigma_matrix, grid)
+    kernel = kernel / np.sum(kernel)
+    return kernel
+
+
+def mesh_grid(kernel_size):
+    """Generate the mesh grid, centering at zero.
+    Args:
+        kernel_size (int):
+    Returns:
+        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+        xx (ndarray): with the shape (kernel_size, kernel_size)
+        yy (ndarray): with the shape (kernel_size, kernel_size)
+    """
+    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+    xx, yy = np.meshgrid(ax, ax)
+    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
+                    yy.reshape(kernel_size * kernel_size, 1))).reshape(kernel_size, kernel_size, 2)
+    return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+    """Calculate PDF of the bivariate Gaussian distribution.
+    Args:
+        sigma_matrix (ndarray): with the shape (2, 2)
+        grid (ndarray): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size.
+    Returns:
+        kernel (ndarray): un-normalized kernel.
+    """
+    inverse_sigma = np.linalg.inv(sigma_matrix)
+    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+    return kernel
+
+
+def sigma_matrix2(sig_x, sig_y, theta):
+    """Calculate the rotated sigma matrix (two dimensional matrix).
+    Args:
+        sig_x (float):
+        sig_y (float):
+        theta (float): Radian measurement.
+    Returns:
+        ndarray: Rotated sigma matrix.
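+            Computed as U(theta) @ diag(sig_x**2, sig_y**2) @ U(theta).T,
+            where U(theta) is the 2-D rotation matrix for angle theta.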
+ """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def create_image_controlnet(controlnet_config, unet, controlnet_path=""): + # load controlnet model + controlnet = None + + unet.config.num_attention_heads = 8 + unet.config.projection_class_embeddings_input_dim = None + + controlnet_config = OmegaConf.load(controlnet_config) + controlnet = ImageControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})) + + if controlnet_path != "": + print(f"loading controlnet checkpoint from {controlnet_path} ...") + controlnet_state_dict = torch.load(controlnet_path, map_location="cuda") + if "global_step" in controlnet_state_dict: print(f"global_step: {controlnet_state_dict['global_step']}") + controlnet_state_dict = controlnet_state_dict["state_dict"] if "state_dict" in controlnet_state_dict else controlnet_state_dict + controlnet_state_dict.pop("animatediff_config", "") + controlnet.load_state_dict(controlnet_state_dict) + + return controlnet + + +def create_flow_controlnet(controlnet_config, unet, controlnet_path=""): + # load controlnet model + controlnet = None + + unet.config.num_attention_heads = 8 + unet.config.projection_class_embeddings_input_dim = None + + controlnet_config = OmegaConf.load(controlnet_config) + controlnet = FlowControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})) + + if controlnet_path != "": + print(f"loading controlnet checkpoint from {controlnet_path} ...") + controlnet_state_dict = torch.load(controlnet_path, map_location="cuda") + controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict + controlnet_state_dict.pop("animatediff_config", "") + controlnet.load_state_dict(controlnet_state_dict) + + return controlnet diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d6aab40ec0bc90def2d9d35f236cba475ec82d --- /dev/null +++ b/utils/visualizer.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
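+
+# Helpers for drawing point-track trajectories on video frames (Visualizer) and
+# for converting optical-flow fields into Middlebury color-coded images
+# (vis_flow_to_video / flow_to_image).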
+import os +import numpy as np +import imageio +import torch + +from matplotlib import cm +import torch.nn.functional as F +import torchvision.transforms as transforms +import matplotlib.pyplot as plt +from PIL import Image, ImageDraw + + +def read_video_from_path(path): + try: + reader = imageio.get_reader(path) + except Exception as e: + print("Error opening video file: ", e) + return None + frames = [] + for i, im in enumerate(reader): + frames.append(np.array(im)) + return np.stack(frames) + + +def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True): + # Create a draw object + draw = ImageDraw.Draw(rgb) + # Calculate the bounding box of the circle + left_up_point = (coord[0] - radius, coord[1] - radius) + right_down_point = (coord[0] + radius, coord[1] + radius) + # Draw the circle + draw.ellipse( + [left_up_point, right_down_point], + fill=tuple(color) if visible else None, + outline=tuple(color), + ) + return rgb + + +def draw_line(rgb, coord_y, coord_x, color, linewidth): + draw = ImageDraw.Draw(rgb) + draw.line( + (coord_y[0], coord_y[1], coord_x[0], coord_x[1]), + fill=tuple(color), + width=linewidth, + ) + return rgb + + +def add_weighted(rgb, alpha, original, beta, gamma): + return (rgb * alpha + original * beta + gamma).astype("uint8") + + +class Visualizer: + def __init__( + self, + save_dir: str = "./results", + grayscale: bool = False, + pad_value: int = 0, + fps: int = 8, + mode: str = "rainbow", # 'cool', 'optical_flow' + linewidth: int = 2, + show_first_frame: int = 1, + tracks_leave_trace: int = 0, # -1 for infinite + ): + self.mode = mode + self.save_dir = save_dir + if mode == "rainbow": + self.color_map = cm.get_cmap("gist_rainbow") + elif mode == "cool": + self.color_map = cm.get_cmap(mode) + self.show_first_frame = show_first_frame + self.grayscale = grayscale + self.tracks_leave_trace = tracks_leave_trace + self.pad_value = pad_value + self.linewidth = linewidth + self.fps = fps + + def visualize( + self, + video: torch.Tensor, # (B,T,C,H,W) + tracks: torch.Tensor, # (B,T,N,2) + visibility: torch.Tensor = None, # (B, T, N, 1) bool + gt_tracks: torch.Tensor = None, # (B,T,N,2) + segm_mask: torch.Tensor = None, # (B,1,H,W) + filename: str = "video", + writer=None, # tensorboard Summary Writer, used for visualization during training + step: int = 0, + query_frame: int = 0, + save_video: bool = True, + compensate_for_camera_motion: bool = False, + ): + if compensate_for_camera_motion: + assert segm_mask is not None + if segm_mask is not None: + coords = tracks[0, query_frame].round().long() + segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() + + video = F.pad( + video, + (self.pad_value, self.pad_value, self.pad_value, self.pad_value), + "constant", + 255, + ) + tracks = tracks + self.pad_value + + if self.grayscale: + transform = transforms.Grayscale() + video = transform(video) + video = video.repeat(1, 1, 3, 1, 1) + + res_video = self.draw_tracks_on_video( + video=video, + tracks=tracks, + visibility=visibility, + segm_mask=segm_mask, + gt_tracks=gt_tracks, + query_frame=query_frame, + compensate_for_camera_motion=compensate_for_camera_motion, + ) + if save_video: + self.save_video(res_video, filename=filename, writer=writer, step=step) + return res_video + + def save_video(self, video, filename, writer=None, step=0): + if writer is not None: + writer.add_video( + filename, + video.to(torch.uint8), + global_step=step, + fps=self.fps, + ) + else: + os.makedirs(self.save_dir, exist_ok=True) + wide_list = list(video.unbind(1)) + 
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] + + # Prepare the video file path + save_path = os.path.join(self.save_dir, f"{filename}.mp4") + + # Create a writer object + video_writer = imageio.get_writer(save_path, fps=self.fps) + + # Write frames to the video file + for frame in wide_list: + video_writer.append_data(frame) + + video_writer.close() + + # print(f"Video saved to {save_path}") + + def draw_tracks_on_video( + self, + video: torch.Tensor, + tracks: torch.Tensor, + visibility: torch.Tensor = None, + segm_mask: torch.Tensor = None, + gt_tracks=None, + query_frame: int = 0, + compensate_for_camera_motion=False, + ): + B, T, C, H, W = video.shape + _, _, N, D = tracks.shape + + assert D == 2 + assert C == 3 + video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C + tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2 + if gt_tracks is not None: + gt_tracks = gt_tracks[0].detach().cpu().numpy() + + res_video = [] + + # process input video + for rgb in video: + res_video.append(rgb.copy()) + vector_colors = np.zeros((T, N, 3)) + + if self.mode == "optical_flow": + import flow_vis + + vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) + elif segm_mask is None: + if self.mode == "rainbow": + y_min, y_max = ( + tracks[query_frame, :, 1].min(), + tracks[query_frame, :, 1].max(), + ) + norm = plt.Normalize(y_min, y_max) + for n in range(N): + color = self.color_map(norm(tracks[query_frame, n, 1])) + color = np.array(color[:3])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + else: + # color changes with time + for t in range(T): + color = np.array(self.color_map(t / T)[:3])[None] * 255 + vector_colors[t] = np.repeat(color, N, axis=0) + else: + if self.mode == "rainbow": + vector_colors[:, segm_mask <= 0, :] = 255 + + y_min, y_max = ( + tracks[0, segm_mask > 0, 1].min(), + tracks[0, segm_mask > 0, 1].max(), + ) + norm = plt.Normalize(y_min, y_max) + for n in range(N): + if segm_mask[n] > 0: + color = self.color_map(norm(tracks[0, n, 1])) + color = np.array(color[:3])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + + else: + # color changes with segm class + segm_mask = segm_mask.cpu() + color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) + color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 + color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 + vector_colors = np.repeat(color[None], T, axis=0) + + # draw tracks + if self.tracks_leave_trace != 0: + for t in range(query_frame + 1, T): + first_ind = ( + max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 + ) + curr_tracks = tracks[first_ind : t + 1] + curr_colors = vector_colors[first_ind : t + 1] + if compensate_for_camera_motion: + diff = ( + tracks[first_ind : t + 1, segm_mask <= 0] + - tracks[t : t + 1, segm_mask <= 0] + ).mean(1)[:, None] + + curr_tracks = curr_tracks - diff + curr_tracks = curr_tracks[:, segm_mask > 0] + curr_colors = curr_colors[:, segm_mask > 0] + + res_video[t] = self._draw_pred_tracks( + res_video[t], + curr_tracks, + curr_colors, + ) + if gt_tracks is not None: + res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1]) + + # draw points + for t in range(query_frame, T): + img = Image.fromarray(np.uint8(res_video[t])) + for i in range(N): + coord = (tracks[t, i, 0], tracks[t, i, 1]) + visibile = True + if visibility is not None: + visibile = visibility[0, t, i] + if coord[0] != 0 and coord[1] != 0: + if not 
compensate_for_camera_motion or (
+                        compensate_for_camera_motion and segm_mask[i] > 0
+                    ):
+                        img = draw_circle(
+                            img,
+                            coord=coord,
+                            radius=int(self.linewidth * 2),
+                            color=vector_colors[t, i].astype(int),
+                            visible=visibile,
+                        )
+            res_video[t] = np.array(img)
+
+        #  construct the final rgb sequence
+        if self.show_first_frame > 0:
+            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
+        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+    def _draw_pred_tracks(
+        self,
+        rgb: np.ndarray,  # H x W x 3
+        tracks: np.ndarray,  # T x 2
+        vector_colors: np.ndarray,
+        alpha: float = 0.5,
+    ):
+        T, N, _ = tracks.shape
+        rgb = Image.fromarray(np.uint8(rgb))
+        for s in range(T - 1):
+            vector_color = vector_colors[s]
+            original = rgb.copy()
+            alpha = (s / T) ** 2
+            for i in range(N):
+                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+                if coord_y[0] != 0 and coord_y[1] != 0:
+                    rgb = draw_line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        vector_color[i].astype(int),
+                        self.linewidth,
+                    )
+            if self.tracks_leave_trace > 0:
+                rgb = Image.fromarray(
+                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
+                )
+        rgb = np.array(rgb)
+        return rgb
+
+    def _draw_gt_tracks(
+        self,
+        rgb: np.ndarray,  # H x W x 3,
+        gt_tracks: np.ndarray,  # T x 2
+    ):
+        T, N, _ = gt_tracks.shape
+        color = np.array((211, 0, 0))
+        rgb = Image.fromarray(np.uint8(rgb))
+        for t in range(T):
+            for i in range(N):
+                gt_track = gt_tracks[t][i]  # use a separate name so the gt_tracks array is not overwritten
+                #  draw a red cross
+                if gt_track[0] > 0 and gt_track[1] > 0:
+                    length = self.linewidth * 3
+                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
+                    rgb = draw_line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color,
+                        self.linewidth,
+                    )
+                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
+                    rgb = draw_line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color,
+                        self.linewidth,
+                    )
+        rgb = np.array(rgb)
+        return rgb
+
+
+########## optical flow visualization ##########
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+
+def vis_flow_to_video(optical_flow, num_frames):
+    '''
+    optical_flow: T-1 x H x W x C
+    '''
+    video = []
+    for i in range(1, num_frames):
+        flow_img = flow_to_image(optical_flow[i])
+        flow_img = torch.Tensor(flow_img)  # H x W x 3
+        video.append(flow_img)
+    video = torch.stack(video, dim=0)  # T-1 x H x W x 3
+    return video
+
+
+# from https://github.com/gengshan-y/VCN
+def flow_to_image(flow):
+    """
+    Convert flow into middlebury color code image
+    :param flow: optical flow map
+    :return: optical flow image in middlebury color
+    """
+    u = flow[:, :, 0]
+    v = flow[:, :, 1]
+
+    maxu = -999.
+    maxv = -999.
+    minu = 999.
+    minv = 999.
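+    # maxu/maxv/minu/minv track the raw flow extrema; the normalization below
+    # relies only on the maximum flow magnitude (maxrad)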
+
+    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+    u[idxUnknow] = 0
+    v[idxUnknow] = 0
+
+    maxu = max(maxu, np.max(u))
+    minu = min(minu, np.min(u))
+
+    maxv = max(maxv, np.max(v))
+    minv = min(minv, np.min(v))
+
+    rad = np.sqrt(u ** 2 + v ** 2)
+    maxrad = max(-1, np.max(rad))
+
+    u = u / (maxrad + np.finfo(float).eps)
+    v = v / (maxrad + np.finfo(float).eps)
+
+    img = compute_color(u, v)
+
+    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+    img[idx] = 0
+
+    return np.uint8(img)
+
+
+def compute_color(u, v):
+    """
+    compute optical flow color map
+    :param u: optical flow horizontal map
+    :param v: optical flow vertical map
+    :return: optical flow in color code
+    """
+    [h, w] = u.shape
+    img = np.zeros([h, w, 3])
+    nanIdx = np.isnan(u) | np.isnan(v)
+    u[nanIdx] = 0
+    v[nanIdx] = 0
+
+    colorwheel = make_color_wheel()
+    ncols = np.size(colorwheel, 0)
+
+    rad = np.sqrt(u ** 2 + v ** 2)
+
+    a = np.arctan2(-v, -u) / np.pi
+
+    fk = (a + 1) / 2 * (ncols - 1) + 1
+
+    k0 = np.floor(fk).astype(int)
+
+    k1 = k0 + 1
+    k1[k1 == ncols + 1] = 1
+    f = fk - k0
+
+    for i in range(0, np.size(colorwheel, 1)):
+        tmp = colorwheel[:, i]
+        col0 = tmp[k0 - 1] / 255
+        col1 = tmp[k1 - 1] / 255
+        col = (1 - f) * col0 + f * col1
+
+        idx = rad <= 1
+        col[idx] = 1 - rad[idx] * (1 - col[idx])  # smaller flow -> brighter color, so static or slow-moving regions stand out in the visualization
+        notidx = np.logical_not(idx)
+
+        col[notidx] *= 0.75  # larger flow -> darker color
+        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+
+    return img
+
+
+def make_color_wheel():
+    """
+    Generate color wheel according to the Middlebury color code
+    :return: Color wheel
+    """
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+
+    colorwheel = np.zeros([ncols, 3])
+
+    col = 0
+
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+    col += RY
+
+    # YG
+    colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+    colorwheel[col:col + YG, 1] = 255
+    col += YG
+
+    # GC
+    colorwheel[col:col + GC, 1] = 255
+    colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+    col += GC
+
+    # CB
+    colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+    colorwheel[col:col + CB, 2] = 255
+    col += CB
+
+    # BM
+    colorwheel[col:col + BM, 2] = 255
+    colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+    col += BM
+
+    # MR
+    colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+    colorwheel[col:col + MR, 0] = 255
+
+    return colorwheel
\ No newline at end of file