---
library_name: diffusers
pipeline_tag: text-to-image
---

## Model Details

### Model Description

This model is fine-tuned from [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 110,000 image-text pairs from the MIMIC dataset using the SVDiff [1] parameter-efficient fine-tuning (PEFT) method. Under this strategy, only the singular values of the weight matrices in the U-Net are fine-tuned; everything else stays frozen.

- **Developed by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Shared by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Model type:** Stable Diffusion fine-tuned using Parameter-Efficient Fine-Tuning
- **Finetuned from model:** [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)

### Model Sources

- **Paper:** [Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity](https://arxiv.org/abs/2305.08252)
- **Demo:** [MIMIC-SD-PEFT-Demo](https://huggingface.co/spaces/raman07/MIMIC-SD-Demo-Memory-Optimized?logs=container)

## Direct Use

This model can be used directly to generate realistic medical images from text prompts.

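As a rough illustration of the fine-tuning strategy named in the Model Description, the sketch below shows what "fine-tuning only the singular values" can look like for a single weight matrix. It is a minimal NumPy sketch, not the training code behind this model: the helper name `apply_spectral_shift`, the toy matrix shape, and the use of NumPy rather than PyTorch are assumptions made for illustration. The ReLU on the shifted singular values follows the formulation in the SVDiff paper [1].

```python
import numpy as np


def apply_spectral_shift(weight, delta):
    """Rebuild a weight matrix after shifting its singular values by `delta`.

    Hypothetical helper for illustration: U and V^T stay frozen,
    and only `delta` would be trained.
    """
    u, sigma, vt = np.linalg.svd(weight, full_matrices=False)
    # ReLU keeps the shifted singular values non-negative
    shifted = np.maximum(sigma + delta, 0.0)
    # Equivalent to U @ diag(shifted) @ V^T
    return (u * shifted) @ vt


rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4))  # stand-in for one frozen U-Net weight matrix
delta = np.zeros(4)                   # trainable shifts, initialised to zero

# With zero shifts, the original weight matrix is recovered
zero_shift_ok = np.allclose(apply_spectral_shift(weight, delta), weight)

# Trainable parameter count: one shift per singular value instead of one per entry
full_params = weight.size
svdiff_params = delta.size
```

For this toy 8×4 matrix, the shift vector has 4 entries against 32 entries in the full matrix, which is why the approach counts as parameter-efficient.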
## How to Get Started with the Model

```python
import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open

# Assumption: the SVDiff U-Net class is provided by the svdiff-pytorch package
from svdiff_pytorch import UNet2DConditionModelForSVDiff


#### Defining loading function
def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    print(pretrained_model_name_or_path)
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    # Build the SVDiff U-Net on the meta device (no memory allocated yet)
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)

    # load pre-trained weights
    param_device = "cpu"
    torch_dtype = kwargs.get("torch_dtype")
    # Spectral shifts ("delta" parameters) start at zero
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)

    # move the params from meta device to cpu
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )

    # Older accelerate versions do not accept a `dtype` argument
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    for param_name, param in state_dict.items():
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)

    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # download from hub
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(
                spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs
            )
        assert os.path.exists(spectral_shifts_ckpt)

        # Overwrite the zero-initialized shifts with the fine-tuned ones
        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")

    if "torch_dtype" in kwargs:
        model = model.to(kwargs["torch_dtype"])
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model


# Assumption: the base pipeline is the model this card was fine-tuned from
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Load the adapted U-Net; expects a local copy of unet/spectral_shifts.safetensors
pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=os.path.join("unet", "spectral_shifts.safetensors"),
    subfolder="unet",
)
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()

# The spectral shifts are already loaded by load_unet_for_svdiff,
# so no separate load_state_dict call is needed here.
pipe.to("cuda:0")

# Generate images with text prompts
TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75

result_image = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)

result_pil_image = result_image.images[0]
```

## Training Details

### Training Data

This model was fine-tuned on 110K image-text pairs from the MIMIC dataset.

### Training Procedure

The training procedure is described in detail in Section 4.3 of this [paper](https://arxiv.org/abs/2305.08252).

#### Metrics

This model was evaluated using the Fréchet Inception Distance (FID) score on the MIMIC dataset. Lower FID is better.

### Results

| Fine-Tuning Strategy | FID Score |
|----------------------|-----------|
| Full FT              | 58.74     |
| Attention            | 52.41     |
| Bias                 | 20.81     |
| Norm                 | 29.84     |
| Bias+Norm+Attention  | 35.93     |
| LoRA                 | 439.65    |
| SV-Diff              | 23.59     |
| DiffFit              | 42.50     |

## Environmental Impact

Parameter-Efficient Fine-Tuning potentially causes **less** harm to the environment because it fine-tunes significantly fewer parameters, which results in much lower compute and hardware requirements.

## Citation

**BibTeX:**

@article{dutt2023parameter,
  title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
  author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
  journal={arXiv preprint arXiv:2305.08252},
  year={2023}
}

**APA:**

Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.

## Model Card Authors

Raman Dutt

[Twitter](https://twitter.com/RamanDutt4)
[LinkedIn](https://www.linkedin.com/in/raman-dutt/)
[Email](mailto:s2198939@ed.ac.uk)

## References

1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).