---
library_name: diffusers
pipeline_tag: text-to-image
---

## Model Details

### Model Description

This model is fine-tuned from [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 110,000 image-text pairs from the MIMIC dataset using the SVDiff [1] parameter-efficient fine-tuning (PEFT) method. Under this strategy, only the singular values of the weight matrices in the U-Net are fine-tuned; everything else stays frozen.

- **Developed by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Shared by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Model type:** Stable Diffusion fine-tuned using Parameter-Efficient Fine-Tuning (PEFT)
- **Finetuned from model:** [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
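
The fine-tuning strategy described above can be illustrated on a single linear layer. The following is a minimal sketch, not the training code used for this model: the class name `SpectralShiftLinear` and the toy layer sizes are hypothetical, and the ReLU on the shifted singular values follows the formulation in the SVDiff paper [1]. It shows that the weight is decomposed once, that only a per-singular-value shift `delta` is trainable, and that a zero shift reproduces the original layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpectralShiftLinear(nn.Module):
    """Hypothetical wrapper: a frozen linear layer with trainable singular-value shifts."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        # Decompose the frozen pre-trained weight once: W = U @ diag(S) @ Vh
        U, S, Vh = torch.linalg.svd(linear.weight.detach(), full_matrices=False)
        self.register_buffer("U", U)
        self.register_buffer("S", S)
        self.register_buffer("Vh", Vh)
        # The bias is kept but frozen
        self.bias = (
            None
            if linear.bias is None
            else nn.Parameter(linear.bias.detach().clone(), requires_grad=False)
        )
        # The only trainable parameter: one shift per singular value, initialised to zero
        self.delta = nn.Parameter(torch.zeros_like(S))

    def forward(self, x):
        # ReLU keeps the shifted singular values non-negative
        weight = self.U @ torch.diag(F.relu(self.S + self.delta)) @ self.Vh
        return F.linear(x, weight, self.bias)


torch.manual_seed(0)
base = nn.Linear(8, 4)
layer = SpectralShiftLinear(base)
x = torch.randn(2, 8)

# Names of the parameters an optimizer would update
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
```

For this 4x8 toy layer, only the 4 entries of `delta` would be optimised, compared with 36 weights and biases under full fine-tuning.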

### Model Sources


- **Paper:** [Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity](https://arxiv.org/abs/2305.08252)
- **Demo:** [MIMIC-SD-PEFT-Demo](https://huggingface.co/spaces/raman07/MIMIC-SD-Demo-Memory-Optimized?logs=container)

## Direct Use

This model can be directly used to generate realistic medical images from text prompts. 


## How to Get Started with the Model

```python
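import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from safetensors import safe_open
# Assumption: the SVDiff U-Net class is provided by the svdiff-pytorch package
from svdiff_pytorch import UNet2DConditionModelForSVDiff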


#### Defining loading function

def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    
    print(pretrained_model_name_or_path)
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)
    # load pre-trained weights
    param_device = "cpu"
    torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)
    # move the params from meta device to cpu
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )

    for param_name, param in state_dict.items():
        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)
    
    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # download from hub
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
        assert os.path.exists(spectral_shifts_ckpt)

        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                # spectral_shifts_weights[key] = f.get_tensor(key)
                accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")
    if "torch_dtype" in kwargs:
        model = model.to(kwargs["torch_dtype"])
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model

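# Build the base pipeline, then swap in the SVDiff U-Net
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=os.path.join("unet", "spectral_shifts.safetensors"),
    subfolder="unet",
)
# Run the SVD on each SVDiff layer now that the pre-trained weights are in place
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()

# The spectral shifts were already loaded inside load_unet_for_svdiff,
# so no separate load_state_dict call is needed here
pipe.to("cuda:0")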

# Generate images with text prompts

TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75

result_image = pipe(
        prompt=TEXT_PROMPT,
        height=224,
        width=224,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=INFERENCE_STEPS,
    )

result_pil_image = result_image["images"][0]
```


## Training Details

### Training Data

This model has been fine-tuned on 110K image-text pairs from the MIMIC dataset. 

### Training Procedure

The training procedure is described in detail in Section 4.3 of this [paper](https://arxiv.org/abs/2305.08252).

#### Metrics

This model has been evaluated using the Fréchet Inception Distance (FID) on the MIMIC dataset. Lower FID is better.
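
For readers unfamiliar with the metric, the Fréchet distance between two Gaussians fitted to feature vectors can be sketched as follows. This is an illustration under stated assumptions, not the evaluation code behind the reported scores: the function names are hypothetical, and random vectors stand in for the Inception features that a real FID computation extracts from real and generated images.

```python
import numpy as np


def feature_statistics(features):
    """Mean vector and covariance matrix of an (n_samples, n_features) array."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))."""
    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2, which are non-negative for covariance matrices
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)


# Stand-in features (assumption): random vectors instead of Inception activations
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 6))
fake = real + 2.0  # same spread, every feature mean shifted by 2

fid_same = frechet_distance(*feature_statistics(real), *feature_statistics(real))
fid_shifted = frechet_distance(*feature_statistics(real), *feature_statistics(fake))
```

Identical feature sets give a distance of approximately zero, while shifting every one of the 6 feature means by 2 leaves the covariance unchanged and gives approximately 6 × 2² = 24.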

### Results

| Fine-Tuning Strategy   | FID Score |
|------------------------|-----------|
| Full FT                | 58.74     |
| Attention              | 52.41     |
| Bias                   | 20.81     |
| Norm                   | 29.84     |
| Bias+Norm+Attention    | 35.93     |
| LoRA                   | 439.65    |
| SV-Diff                | 23.59     |
| DiffFit                | 42.50     |


## Environmental Impact

Parameter-Efficient Fine-Tuning potentially causes **less** harm to the environment because far fewer of the model's parameters are fine-tuned, which lowers the compute and hardware requirements.
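
The reduction in trainable parameters can be checked directly in PyTorch. The sketch below is illustrative only: `trainable_fraction` is a hypothetical helper, and the toy network with only its biases left trainable stands in for a real model under a PEFT strategy such as the "Bias" row in the results table.

```python
import torch.nn as nn


def trainable_fraction(model: nn.Module) -> float:
    """Share of a model's parameters that an optimizer would update."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable / total


# Toy stand-in model (assumption): freeze everything except the bias terms
toy = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
for name, param in toy.named_parameters():
    param.requires_grad = name.endswith("bias")

fraction = trainable_fraction(toy)
```

Here 74 of the 4,810 parameters remain trainable, roughly 1.5% of the model.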

## Citation


**BibTeX:**

@article{dutt2023parameter,
  title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
  author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
  journal={arXiv preprint arXiv:2305.08252},
  year={2023}
}

**APA:**  
Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.

## Model Card Authors

Raman Dutt  
[Twitter](https://twitter.com/RamanDutt4)  
[LinkedIn](https://www.linkedin.com/in/raman-dutt/)  
[Email](mailto:s2198939@ed.ac.uk)

## References

1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).