---
library_name: diffusers
pipeline_tag: text-to-image
---
## Model Details
### Model Description
This model is fine-tuned from [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 110,000 image-text pairs from the MIMIC dataset using the SVDiff [1] parameter-efficient fine-tuning (PEFT) method. Under this strategy, only the singular values of the U-Net's weight matrices are fine-tuned (through learned shifts to those singular values, called spectral shifts), while everything else stays frozen.
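To make the idea concrete, here is a minimal NumPy sketch of an SVDiff-style linear layer. It is illustrative only and is not the svdiff-pytorch code used by this model; the class name `SpectralShiftLinear` is hypothetical. The pretrained weight is decomposed once as $W = U\,\mathrm{diag}(\sigma)\,V^\top$, and only a per-singular-value shift $\delta$ would be trained, giving $W_\delta = U\,\mathrm{diag}(\mathrm{ReLU}(\sigma + \delta))\,V^\top$.

```python
import numpy as np


class SpectralShiftLinear:
    """Hypothetical SVDiff-style linear layer (sketch, not the library code).

    U, sigma and Vh come from a one-off SVD of the frozen pretrained weight;
    only `delta` (one scalar per singular value) would be trained.
    """

    def __init__(self, weight: np.ndarray):
        # Thin SVD: U (out, k), sigma (k,), Vh (k, in) with k = min(out, in)
        self.U, self.sigma, self.Vh = np.linalg.svd(weight, full_matrices=False)
        self.delta = np.zeros_like(self.sigma)  # spectral shifts, zero-initialised

    def effective_weight(self) -> np.ndarray:
        # ReLU keeps the shifted singular values non-negative
        shifted = np.maximum(self.sigma + self.delta, 0.0)
        return self.U @ np.diag(shifted) @ self.Vh

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, in_features) -> (batch, out_features)
        return x @ self.effective_weight().T

    def num_trainable(self) -> int:
        return self.delta.size


rng = np.random.default_rng(0)
W = rng.standard_normal((320, 768))
layer = SpectralShiftLinear(W)
# With delta = 0 the layer reproduces the pretrained weight (up to float error)
print(np.allclose(layer.effective_weight(), W))  # True
print(layer.num_trainable(), "trainable vs", W.size, "frozen")  # 320 trainable vs 245760 frozen
```

For an $m \times n$ weight matrix this trains only $\min(m, n)$ values instead of $m \cdot n$, which is where the parameter savings come from.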
- **Developed by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Shared by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Model type:** Stable Diffusion fine-tuned using Parameter-Efficient Fine-Tuning (PEFT)
- **Finetuned from model:** [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
### Model Sources
- **Paper:** [Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity](https://arxiv.org/abs/2305.08252)
- **Demo:** [MIMIC-SD-PEFT-Demo](https://huggingface.co/spaces/raman07/MIMIC-SD-Demo-Memory-Optimized?logs=container)
## Direct Use
This model can be directly used to generate realistic medical images from text prompts.
## How to Get Started with the Model
```python
import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open
from safetensors.torch import load_file

# Assumption: the SVDiff U-Net class is provided by the svdiff-pytorch package;
# adjust this import to match the SVDiff implementation you have installed.
from svdiff_pytorch import UNet2DConditionModelForSVDiff


# Define the SVDiff U-Net loading function
def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)
    # Load pre-trained weights; the spectral shifts ("delta" parameters) start at zero
    param_device = "cpu"
    torch_dtype = kwargs.get("torch_dtype")
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )
    # Only pass `dtype` if this accelerate version's set_module_tensor_to_device accepts it
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    # Move the params from the meta device to the CPU
    for param_name, param in state_dict.items():
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)
    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # Download from the Hugging Face Hub
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(
                spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs
            )
        assert os.path.exists(spectral_shifts_ckpt)
        # Overwrite the zero-initialised spectral shifts with the fine-tuned values
        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")
    if torch_dtype is not None:
        model = model.to(torch_dtype)
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate Dropout modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model


# Load the base Stable Diffusion v1.5 pipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Spectral shifts from this repository; assumes the repo has been cloned and the
# script runs from its root, so that unet/spectral_shifts.safetensors exists
spectral_shifts_path = os.path.join("unet", "spectral_shifts.safetensors")

# Swap in the SVDiff U-Net carrying the fine-tuned spectral shifts
pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=spectral_shifts_path,
    subfolder="unet",
)
# Let each SVDiff layer compute its singular value decomposition from the loaded weights
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()
# Load the adapted U-Net (spectral-shift) weights; strict=False because the
# file holds only the spectral-shift parameters
state_dict = load_file(spectral_shifts_path)
pipe.unet.load_state_dict(state_dict, strict=False)
pipe.to("cuda:0")

# Generate images with text prompts
TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75
result_image = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)
result_pil_image = result_image["images"][0]
```
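The `guidance_scale` passed to the pipeline controls classifier-free guidance: at each denoising step the U-Net predicts noise both with and without the text prompt, and the two predictions are combined (diffusers enables this when the scale is above 1). The sketch below shows only that combination step on toy NumPy arrays standing in for the noise predictions; `classifier_free_guidance` is a hypothetical helper, not the pipeline's internal code.

```python
import numpy as np


def classifier_free_guidance(noise_uncond: np.ndarray, noise_cond: np.ndarray, guidance_scale: float) -> np.ndarray:
    """Combine unconditional and text-conditioned noise predictions.

    guidance_scale = 1 returns the conditional prediction unchanged; larger
    values push further in the direction the text prompt suggests.
    """
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


# Toy stand-ins for the two U-Net outputs
eps_uncond = np.zeros(2)
eps_cond = np.ones(2)
print(classifier_free_guidance(eps_uncond, eps_cond, 4.0))  # [4. 4.]
```

Higher scales generally make samples follow the prompt more closely at some cost to diversity, which is why the value is exposed as a tunable parameter.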
## Training Details
### Training Data
This model has been fine-tuned on 110K image-text pairs from the MIMIC dataset.
### Training Procedure
The training procedure is described in detail in Section 4.3 of the [paper](https://arxiv.org/abs/2305.08252).
#### Metrics
This model has been evaluated using the Fréchet Inception Distance (FID) on the MIMIC dataset.
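For reference, FID compares the mean and covariance of image features from real and generated sets: $\mathrm{FID} = \lVert\mu_r - \mu_g\rVert^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}\big)$. Below is a minimal NumPy sketch of that formula, assuming the feature vectors (normally Inception-v3 pooling features) have already been extracted; it does not reproduce the paper's evaluation setup, and `frechet_distance` is a hypothetical helper. Instead of a matrix square root (common implementations use `scipy.linalg.sqrtm`), it uses the fact that $\mathrm{Tr}((\Sigma_r\Sigma_g)^{1/2})$ equals the sum of square roots of the eigenvalues of $\Sigma_r\Sigma_g$.

```python
import numpy as np


def frechet_distance(feats_real: np.ndarray, feats_gen: np.ndarray) -> float:
    """FID between two feature sets, each of shape (n_samples, dim)."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    # atleast_2d keeps the 1-feature case a (1, 1) matrix
    cov_r = np.atleast_2d(np.cov(feats_real, rowvar=False))
    cov_g = np.atleast_2d(np.cov(feats_gen, rowvar=False))
    # Eigenvalues of a product of PSD matrices are real and non-negative;
    # clip tiny negative/imaginary parts caused by floating-point error.
    eigvals = np.linalg.eigvals(cov_r @ cov_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
real = rng.standard_normal((1000, 16))
print(frechet_distance(real, real))        # close to 0 (identical statistics)
print(frechet_distance(real, real + 1.0))  # close to 16 (mean shifted by 1 in each of 16 dims)
```

Lower values mean the generated images' feature statistics are closer to those of the real images.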
### Results
| Fine-Tuning Strategy | FID Score (lower is better) |
|------------------------|-----------|
| Full FT | 58.74 |
| Attention | 52.41 |
| Bias | 20.81 |
| Norm | 29.84 |
| Bias+Norm+Attention | 35.93 |
| LoRA | 439.65 |
| SV-Diff | 23.59 |
| DiffFit | 42.50 |
## Environmental Impact
Parameter-efficient fine-tuning potentially causes **less** harm to the environment, since only a small fraction of the model's parameters is fine-tuned. This substantially reduces compute and hardware requirements.
## Citation
**BibTeX:**
@article{dutt2023parameter,
title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
journal={arXiv preprint arXiv:2305.08252},
year={2023}
}
**APA:**
Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.
## Model Card Authors
Raman Dutt ([Twitter](https://twitter.com/RamanDutt4), [LinkedIn](https://www.linkedin.com/in/raman-dutt/), [Email](mailto:s2198939@ed.ac.uk))
## References
1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).