hysts HF staff commited on
Commit
c0a7c3c
1 Parent(s): 8b7a3d1
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
1
+ *.whl filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,6 +1,5 @@
1
- training_data/
2
  experiments/
3
- wandb/
4
 
5
 
6
  # Byte-compiled / optimized / DLL files
1
+ checkpoints/
2
  experiments/
 
3
 
4
 
5
  # Byte-compiled / optimized / DLL files
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "Tune-A-Video"]
2
+ path = Tune-A-Video
3
+ url = https://github.com/showlab/Tune-A-Video
.pre-commit-config.yaml CHANGED
@@ -1,4 +1,4 @@
1
- exclude: train_dreambooth_lora.py
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
  rev: v4.2.0
@@ -21,7 +21,7 @@ repos:
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
1
+ exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
  rev: v4.2.0
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
Dockerfile ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+ COPY --chown=1000 wheel/xformers-0.0.16+bc08bbc.d20230130-cp310-cp310-linux_x86_64.whl /tmp/xformers-0.0.16+bc08bbc.d20230130-cp310-cp310-linux_x86_64.whl
48
+ RUN pip install --no-cache-dir -U /tmp/xformers-0.0.16+bc08bbc.d20230130-cp310-cp310-linux_x86_64.whl
49
+
50
+ COPY --chown=1000 . ${HOME}/app
51
+ RUN cd Tune-A-Video && patch -p1 < ../patch
52
+ ENV PYTHONPATH=${HOME}/app \
53
+ PYTHONUNBUFFERED=1 \
54
+ GRADIO_ALLOW_FLAGGING=never \
55
+ GRADIO_NUM_PORTS=1 \
56
+ GRADIO_SERVER_NAME=0.0.0.0 \
57
+ GRADIO_THEME=huggingface \
58
+ SYSTEM=spaces
59
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,9 @@
1
  ---
2
- title: LoRA DreamBooth Training UI
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.16.2
8
- python_version: 3.10.9
9
- app_file: app.py
10
  pinned: false
11
  license: mit
12
  duplicated_from: lora-library/LoRA-DreamBooth-Training-UI
1
  ---
2
+ title: Tune-A-Video Training UI
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
6
+ sdk: docker
 
 
 
7
  pinned: false
8
  license: mit
9
  duplicated_from: lora-library/LoRA-DreamBooth-Training-UI
Tune-A-Video ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
app.py CHANGED
@@ -13,11 +13,11 @@ from app_upload import create_upload_demo
13
  from inference import InferencePipeline
14
  from trainer import Trainer
15
 
16
- TITLE = '# LoRA DreamBooth Training UI'
17
 
18
- ORIGINAL_SPACE_ID = 'lora-library/LoRA-DreamBooth-Training-UI'
19
  SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
- SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
21
 
22
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
23
  '''
@@ -29,7 +29,7 @@ else:
29
  CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
30
  <center>
31
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
- "T4 small" is sufficient to run this demo.
33
  </center>
34
  '''
35
 
13
  from inference import InferencePipeline
14
  from trainer import Trainer
15
 
16
+ TITLE = '# Tune-A-Video Training UI'
17
 
18
+ ORIGINAL_SPACE_ID = 'hysts/Tune-A-Video-Training-UI'
19
  SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private A100 GPU.
21
 
22
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
23
  '''
29
  CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
30
  <center>
31
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
+ "A100 large" is required to run this demo.
33
  </center>
34
  '''
35
 
app_inference.py CHANGED
@@ -7,18 +7,13 @@ import enum
7
  import gradio as gr
8
  from huggingface_hub import HfApi
9
 
 
10
  from inference import InferencePipeline
11
  from utils import find_exp_dirs
12
 
13
- SAMPLE_MODEL_IDS = [
14
- 'patrickvonplaten/lora_dreambooth_dog_example',
15
- 'sayakpaul/sd-model-finetuned-lora-t4',
16
- ]
17
-
18
 
19
  class ModelSource(enum.Enum):
20
- SAMPLE = 'Sample'
21
- HUB_LIB = 'Hub (lora-library)'
22
  LOCAL = 'Local'
23
 
24
 
@@ -26,47 +21,41 @@ class InferenceUtil:
26
  def __init__(self, hf_token: str | None):
27
  self.hf_token = hf_token
28
 
29
- @staticmethod
30
- def load_sample_lora_model_list():
31
- return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
-
33
- def load_hub_lora_model_list(self) -> dict:
34
  api = HfApi(token=self.hf_token)
35
  choices = [
36
- info.modelId for info in api.list_models(author='lora-library')
 
37
  ]
38
  return gr.update(choices=choices,
39
  value=choices[0] if choices else None)
40
 
41
  @staticmethod
42
- def load_local_lora_model_list() -> dict:
43
  choices = find_exp_dirs()
44
  return gr.update(choices=choices,
45
  value=choices[0] if choices else None)
46
 
47
- def reload_lora_model_list(self, model_source: str) -> dict:
48
- if model_source == ModelSource.SAMPLE.value:
49
- return self.load_sample_lora_model_list()
50
- elif model_source == ModelSource.HUB_LIB.value:
51
- return self.load_hub_lora_model_list()
52
  elif model_source == ModelSource.LOCAL.value:
53
- return self.load_local_lora_model_list()
54
  else:
55
  raise ValueError
56
 
57
- def load_model_info(self, lora_model_id: str) -> tuple[str, str]:
58
  try:
59
- card = InferencePipeline.get_model_card(lora_model_id,
60
- self.hf_token)
61
  except Exception:
62
  return '', ''
63
  base_model = getattr(card.data, 'base_model', '')
64
- instance_prompt = getattr(card.data, 'instance_prompt', '')
65
- return base_model, instance_prompt
66
 
67
- def reload_lora_model_list_and_update_model_info(
68
  self, model_source: str) -> tuple[dict, str, str]:
69
- model_list_update = self.reload_lora_model_list(model_source)
70
  model_list = model_list_update['choices']
71
  model_info = self.load_model_info(model_list[0] if model_list else '')
72
  return model_list_update, *model_info
@@ -83,30 +72,34 @@ def create_inference_demo(pipe: InferencePipeline,
83
  model_source = gr.Radio(
84
  label='Model Source',
85
  choices=[_.value for _ in ModelSource],
86
- value=ModelSource.SAMPLE.value)
87
  reload_button = gr.Button('Reload Model List')
88
- lora_model_id = gr.Dropdown(label='LoRA Model ID',
89
- choices=SAMPLE_MODEL_IDS,
90
- value=SAMPLE_MODEL_IDS[0])
91
  with gr.Accordion(
92
  label=
93
- 'Model info (Base model and instance prompt used for training)',
94
  open=False):
95
  with gr.Row():
96
  base_model_used_for_training = gr.Text(
97
  label='Base model', interactive=False)
98
- instance_prompt_used_for_training = gr.Text(
99
- label='Instance prompt', interactive=False)
100
  prompt = gr.Textbox(
101
  label='Prompt',
102
  max_lines=1,
103
- placeholder='Example: "A picture of a sks dog in a bucket"'
104
- )
105
- alpha = gr.Slider(label='LoRA alpha',
106
- minimum=0,
107
- maximum=2,
108
- step=0.05,
109
- value=1)
 
 
 
 
110
  seed = gr.Slider(label='Seed',
111
  minimum=0,
112
  maximum=100000,
@@ -117,7 +110,7 @@ def create_inference_demo(pipe: InferencePipeline,
117
  minimum=0,
118
  maximum=100,
119
  step=1,
120
- value=25)
121
  guidance_scale = gr.Slider(label='CFG Scale',
122
  minimum=0,
123
  maximum=50,
@@ -130,34 +123,33 @@ def create_inference_demo(pipe: InferencePipeline,
130
  - After training, you can press "Reload Model List" button to load your trained model names.
131
  ''')
132
  with gr.Column():
133
- result = gr.Image(label='Result')
134
-
135
- model_source.change(
136
- fn=app.reload_lora_model_list_and_update_model_info,
137
- inputs=model_source,
138
- outputs=[
139
- lora_model_id,
140
- base_model_used_for_training,
141
- instance_prompt_used_for_training,
142
- ])
143
- reload_button.click(
144
- fn=app.reload_lora_model_list_and_update_model_info,
145
- inputs=model_source,
146
- outputs=[
147
- lora_model_id,
148
- base_model_used_for_training,
149
- instance_prompt_used_for_training,
150
- ])
151
- lora_model_id.change(fn=app.load_model_info,
152
- inputs=lora_model_id,
153
- outputs=[
154
- base_model_used_for_training,
155
- instance_prompt_used_for_training,
156
- ])
157
  inputs = [
158
- lora_model_id,
159
  prompt,
160
- alpha,
 
161
  seed,
162
  num_steps,
163
  guidance_scale,
7
  import gradio as gr
8
  from huggingface_hub import HfApi
9
 
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
  from inference import InferencePipeline
12
  from utils import find_exp_dirs
13
 
 
 
 
 
 
14
 
15
  class ModelSource(enum.Enum):
16
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
 
17
  LOCAL = 'Local'
18
 
19
 
21
  def __init__(self, hf_token: str | None):
22
  self.hf_token = hf_token
23
 
24
+ def load_hub_model_list(self) -> dict:
 
 
 
 
25
  api = HfApi(token=self.hf_token)
26
  choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
  ]
30
  return gr.update(choices=choices,
31
  value=choices[0] if choices else None)
32
 
33
  @staticmethod
34
+ def load_local_model_list() -> dict:
35
  choices = find_exp_dirs()
36
  return gr.update(choices=choices,
37
  value=choices[0] if choices else None)
38
 
39
+ def reload_model_list(self, model_source: str) -> dict:
40
+ if model_source == ModelSource.HUB_LIB.value:
41
+ return self.load_hub_model_list()
 
 
42
  elif model_source == ModelSource.LOCAL.value:
43
+ return self.load_local_model_list()
44
  else:
45
  raise ValueError
46
 
47
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
48
  try:
49
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
 
50
  except Exception:
51
  return '', ''
52
  base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
+ return base_model, training_prompt
55
 
56
+ def reload_model_list_and_update_model_info(
57
  self, model_source: str) -> tuple[dict, str, str]:
58
+ model_list_update = self.reload_model_list(model_source)
59
  model_list = model_list_update['choices']
60
  model_info = self.load_model_info(model_list[0] if model_list else '')
61
  return model_list_update, *model_info
72
  model_source = gr.Radio(
73
  label='Model Source',
74
  choices=[_.value for _ in ModelSource],
75
+ value=ModelSource.HUB_LIB.value)
76
  reload_button = gr.Button('Reload Model List')
77
+ model_id = gr.Dropdown(label='Model ID',
78
+ choices=None,
79
+ value=None)
80
  with gr.Accordion(
81
  label=
82
+ 'Model info (Base model and prompt used for training)',
83
  open=False):
84
  with gr.Row():
85
  base_model_used_for_training = gr.Text(
86
  label='Base model', interactive=False)
87
+ prompt_used_for_training = gr.Text(
88
+ label='Training prompt', interactive=False)
89
  prompt = gr.Textbox(
90
  label='Prompt',
91
  max_lines=1,
92
+ placeholder='Example: "A panda is surfing"')
93
+ video_length = gr.Slider(label='Video length',
94
+ minimum=4,
95
+ maximum=12,
96
+ step=1,
97
+ value=8)
98
+ fps = gr.Slider(label='FPS',
99
+ minimum=1,
100
+ maximum=12,
101
+ step=1,
102
+ value=1)
103
  seed = gr.Slider(label='Seed',
104
  minimum=0,
105
  maximum=100000,
110
  minimum=0,
111
  maximum=100,
112
  step=1,
113
+ value=50)
114
  guidance_scale = gr.Slider(label='CFG Scale',
115
  minimum=0,
116
  maximum=50,
123
  - After training, you can press "Reload Model List" button to load your trained model names.
124
  ''')
125
  with gr.Column():
126
+ result = gr.Video(label='Result')
127
+
128
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
129
+ inputs=model_source,
130
+ outputs=[
131
+ model_id,
132
+ base_model_used_for_training,
133
+ prompt_used_for_training,
134
+ ])
135
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
136
+ inputs=model_source,
137
+ outputs=[
138
+ model_id,
139
+ base_model_used_for_training,
140
+ prompt_used_for_training,
141
+ ])
142
+ model_id.change(fn=app.load_model_info,
143
+ inputs=model_id,
144
+ outputs=[
145
+ base_model_used_for_training,
146
+ prompt_used_for_training,
147
+ ])
 
 
148
  inputs = [
149
+ model_id,
150
  prompt,
151
+ video_length,
152
+ fps,
153
  seed,
154
  num_steps,
155
  guidance_scale,
app_training.py CHANGED
@@ -6,7 +6,7 @@ import os
6
 
7
  import gradio as gr
8
 
9
- from constants import UploadTarget
10
  from inference import InferencePipeline
11
  from trainer import Trainer
12
 
@@ -18,12 +18,13 @@ def create_training_demo(trainer: Trainer,
18
  with gr.Column():
19
  with gr.Box():
20
  gr.Markdown('Training Data')
21
- instance_images = gr.Files(label='Instance images')
22
- instance_prompt = gr.Textbox(label='Instance prompt',
23
- max_lines=1)
 
 
24
  gr.Markdown('''
25
- - Upload images of the style you are planning on training on.
26
- - For an instance prompt, use a unique, made up word to avoid collisions.
27
  ''')
28
  with gr.Box():
29
  gr.Markdown('Output Model')
@@ -46,25 +47,26 @@ def create_training_demo(trainer: Trainer,
46
  upload_to = gr.Radio(
47
  label='Upload to',
48
  choices=[_.value for _ in UploadTarget],
49
- value=UploadTarget.LORA_LIBRARY.value)
50
- gr.Markdown('''
51
- - By default, trained models will be uploaded to [LoRA Library](https://huggingface.co/lora-library) (see [this example model](https://huggingface.co/lora-library/lora-dreambooth-sample-dog)).
52
- - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
  ''')
54
 
55
  with gr.Box():
56
  gr.Markdown('Training Parameters')
57
  with gr.Row():
58
- base_model = gr.Text(
59
- label='Base Model',
60
- value='stabilityai/stable-diffusion-2-1-base',
61
- max_lines=1)
62
  resolution = gr.Dropdown(choices=['512', '768'],
63
  value='512',
64
- label='Resolution')
 
65
  num_training_steps = gr.Number(
66
- label='Number of Training Steps', value=1000, precision=0)
67
- learning_rate = gr.Number(label='Learning Rate', value=0.0001)
 
68
  gradient_accumulation = gr.Number(
69
  label='Number of Gradient Accumulation',
70
  value=1,
@@ -75,25 +77,20 @@ def create_training_demo(trainer: Trainer,
75
  step=1,
76
  value=0)
77
  fp16 = gr.Checkbox(label='FP16', value=True)
78
- use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
79
  checkpointing_steps = gr.Number(label='Checkpointing Steps',
80
- value=100,
81
  precision=0)
82
- use_wandb = gr.Checkbox(label='Use W&B',
83
- value=False,
84
- interactive=bool(
85
- os.getenv('WANDB_API_KEY')))
86
  validation_epochs = gr.Number(label='Validation Epochs',
87
  value=100,
88
  precision=0)
89
  gr.Markdown('''
90
  - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
91
  - It takes a few minutes to download the base model first.
92
- - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
 
93
  - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
94
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
95
- - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
96
- - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
97
  ''')
98
 
99
  remove_gpu_after_training = gr.Checkbox(
@@ -111,8 +108,8 @@ def create_training_demo(trainer: Trainer,
111
  run_button.click(fn=pipe.clear)
112
  run_button.click(fn=trainer.run,
113
  inputs=[
114
- instance_images,
115
- instance_prompt,
116
  output_model_name,
117
  delete_existing_model,
118
  validation_prompt,
@@ -125,7 +122,6 @@ def create_training_demo(trainer: Trainer,
125
  fp16,
126
  use_8bit_adam,
127
  checkpointing_steps,
128
- use_wandb,
129
  validation_epochs,
130
  upload_to_hub,
131
  use_private_repo,
6
 
7
  import gradio as gr
8
 
9
+ from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
  from inference import InferencePipeline
11
  from trainer import Trainer
12
 
18
  with gr.Column():
19
  with gr.Box():
20
  gr.Markdown('Training Data')
21
+ training_video = gr.File(label='Training video')
22
+ training_prompt = gr.Textbox(
23
+ label='Training prompt',
24
+ max_lines=1,
25
+ placeholder='A man is surfing')
26
  gr.Markdown('''
27
+ - Upload a video and write a prompt describing the video.
 
28
  ''')
29
  with gr.Box():
30
  gr.Markdown('Output Model')
47
  upload_to = gr.Radio(
48
  label='Upload to',
49
  choices=[_.value for _ in UploadTarget],
50
+ value=UploadTarget.MODEL_LIBRARY.value)
51
+ gr.Markdown(f'''
52
+ - By default, trained models will be uploaded to [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (see [this example model](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{SAMPLE_MODEL_REPO})).
53
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{{your_username}}/{{model_name}}.
54
  ''')
55
 
56
  with gr.Box():
57
  gr.Markdown('Training Parameters')
58
  with gr.Row():
59
+ base_model = gr.Text(label='Base Model',
60
+ value='CompVis/stable-diffusion-v1-4',
61
+ max_lines=1)
 
62
  resolution = gr.Dropdown(choices=['512', '768'],
63
  value='512',
64
+ label='Resolution',
65
+ visible=False)
66
  num_training_steps = gr.Number(
67
+ label='Number of Training Steps', value=300, precision=0)
68
+ learning_rate = gr.Number(label='Learning Rate',
69
+ value=0.000035)
70
  gradient_accumulation = gr.Number(
71
  label='Number of Gradient Accumulation',
72
  value=1,
77
  step=1,
78
  value=0)
79
  fp16 = gr.Checkbox(label='FP16', value=True)
80
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
81
  checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
+ value=1000,
83
  precision=0)
 
 
 
 
84
  validation_epochs = gr.Number(label='Validation Epochs',
85
  value=100,
86
  precision=0)
87
  gr.Markdown('''
88
  - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
89
  - It takes a few minutes to download the base model first.
90
+ - It will take about 4 minutes to train for 300 steps with an A100 GPU.
91
+ - It takes a few minutes to upload your trained model.
92
  - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
93
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
 
 
94
  ''')
95
 
96
  remove_gpu_after_training = gr.Checkbox(
108
  run_button.click(fn=pipe.clear)
109
  run_button.click(fn=trainer.run,
110
  inputs=[
111
+ training_video,
112
+ training_prompt,
113
  output_model_name,
114
  delete_existing_model,
115
  validation_prompt,
122
  fp16,
123
  use_8bit_adam,
124
  checkpointing_steps,
 
125
  validation_epochs,
126
  upload_to_hub,
127
  use_private_repo,
app_upload.py CHANGED
@@ -7,13 +7,13 @@ import pathlib
7
  import gradio as gr
8
  import slugify
9
 
10
- from constants import UploadTarget
11
  from uploader import Uploader
12
  from utils import find_exp_dirs
13
 
14
 
15
- class LoRAModelUploader(Uploader):
16
- def upload_lora_model(
17
  self,
18
  folder_path: str,
19
  repo_name: str,
@@ -29,8 +29,8 @@ class LoRAModelUploader(Uploader):
29
 
30
  if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
  organization = ''
32
- elif upload_to == UploadTarget.LORA_LIBRARY.value:
33
- organization = 'lora-library'
34
  else:
35
  raise ValueError
36
 
@@ -41,14 +41,14 @@ class LoRAModelUploader(Uploader):
41
  delete_existing_repo=delete_existing_repo)
42
 
43
 
44
- def load_local_lora_model_list() -> dict:
45
- choices = find_exp_dirs(ignore_repo=True)
46
  return gr.update(choices=choices, value=choices[0] if choices else None)
47
 
48
 
49
  def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
- uploader = LoRAModelUploader(hf_token)
51
- model_dirs = find_exp_dirs(ignore_repo=True)
52
 
53
  with gr.Blocks() as demo:
54
  with gr.Box():
@@ -66,20 +66,20 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
66
  label='Delete existing repo of the same name', value=False)
67
  upload_to = gr.Radio(label='Upload to',
68
  choices=[_.value for _ in UploadTarget],
69
- value=UploadTarget.LORA_LIBRARY.value)
70
  model_name = gr.Textbox(label='Model Name')
71
  upload_button = gr.Button('Upload')
72
- gr.Markdown('''
73
- - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [LoRA Concepts Library](https://huggingface.co/lora-library) (i.e. https://huggingface.co/lora-library/{model_name}).
74
  ''')
75
  with gr.Box():
76
  gr.Markdown('Output message')
77
  output_message = gr.Markdown()
78
 
79
- reload_button.click(fn=load_local_lora_model_list,
80
  inputs=None,
81
  outputs=model_dir)
82
- upload_button.click(fn=uploader.upload_lora_model,
83
  inputs=[
84
  model_dir,
85
  model_name,
7
  import gradio as gr
8
  import slugify
9
 
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
  from uploader import Uploader
12
  from utils import find_exp_dirs
13
 
14
 
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
  self,
18
  folder_path: str,
19
  repo_name: str,
29
 
30
  if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
  organization = ''
32
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
33
+ organization = MODEL_LIBRARY_ORG_NAME
34
  else:
35
  raise ValueError
36
 
41
  delete_existing_repo=delete_existing_repo)
42
 
43
 
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs()
46
  return gr.update(choices=choices, value=choices[0] if choices else None)
47
 
48
 
49
  def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs()
52
 
53
  with gr.Blocks() as demo:
54
  with gr.Box():
66
  label='Delete existing repo of the same name', value=False)
67
  upload_to = gr.Radio(label='Upload to',
68
  choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.MODEL_LIBRARY.value)
70
  model_name = gr.Textbox(label='Model Name')
71
  upload_button = gr.Button('Upload')
72
+ gr.Markdown(f'''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
74
  ''')
75
  with gr.Box():
76
  gr.Markdown('Output message')
77
  output_message = gr.Markdown()
78
 
79
+ reload_button.click(fn=load_local_model_list,
80
  inputs=None,
81
  outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
  inputs=[
84
  model_dir,
85
  model_name,
constants.py CHANGED
@@ -3,4 +3,8 @@ import enum
3
 
4
  class UploadTarget(enum.Enum):
5
  PERSONAL_PROFILE = 'Personal Profile'
6
- LORA_LIBRARY = 'LoRA Library'
 
 
 
 
3
 
4
  class UploadTarget(enum.Enum):
5
  PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
+
8
+
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
inference.py CHANGED
@@ -2,13 +2,21 @@ from __future__ import annotations
2
 
3
  import gc
4
  import pathlib
 
 
5
 
6
  import gradio as gr
 
7
  import PIL.Image
8
  import torch
9
- from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
  from huggingface_hub import ModelCard
11
 
 
 
 
 
 
12
 
13
  class InferencePipeline:
14
  def __init__(self, hf_token: str | None = None):
@@ -16,20 +24,18 @@ class InferencePipeline:
16
  self.pipe = None
17
  self.device = torch.device(
18
  'cuda:0' if torch.cuda.is_available() else 'cpu')
19
- self.lora_model_id = None
20
- self.base_model_id = None
21
 
22
  def clear(self) -> None:
23
- self.lora_model_id = None
24
- self.base_model_id = None
25
  del self.pipe
26
  self.pipe = None
27
  torch.cuda.empty_cache()
28
  gc.collect()
29
 
30
  @staticmethod
31
- def check_if_model_is_local(lora_model_id: str) -> bool:
32
- return pathlib.Path(lora_model_id).exists()
33
 
34
  @staticmethod
35
  def get_model_card(model_id: str,
@@ -41,39 +47,30 @@ class InferencePipeline:
41
  return ModelCard.load(card_path, token=hf_token)
42
 
43
  @staticmethod
44
- def get_base_model_info(lora_model_id: str,
45
- hf_token: str | None = None) -> str:
46
- card = InferencePipeline.get_model_card(lora_model_id, hf_token)
47
  return card.data.base_model
48
 
49
- def load_pipe(self, lora_model_id: str) -> None:
50
- if lora_model_id == self.lora_model_id:
51
  return
52
- base_model_id = self.get_base_model_info(lora_model_id, self.hf_token)
53
- if base_model_id != self.base_model_id:
54
- if self.device.type == 'cpu':
55
- pipe = DiffusionPipeline.from_pretrained(
56
- base_model_id, use_auth_token=self.hf_token)
57
- else:
58
- pipe = DiffusionPipeline.from_pretrained(
59
- base_model_id,
60
- torch_dtype=torch.float16,
61
- use_auth_token=self.hf_token)
62
- pipe = pipe.to(self.device)
63
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(
64
- pipe.scheduler.config)
65
- self.pipe = pipe
66
- self.pipe.unet.load_attn_procs( # type: ignore
67
- lora_model_id, use_auth_token=self.hf_token)
68
-
69
- self.lora_model_id = lora_model_id # type: ignore
70
- self.base_model_id = base_model_id # type: ignore
71
 
72
  def run(
73
  self,
74
- lora_model_id: str,
75
  prompt: str,
76
- lora_scale: float,
 
77
  seed: int,
78
  n_steps: int,
79
  guidance_scale: float,
@@ -81,14 +78,26 @@ class InferencePipeline:
81
  if not torch.cuda.is_available():
82
  raise gr.Error('CUDA is not available.')
83
 
84
- self.load_pipe(lora_model_id)
85
 
86
  generator = torch.Generator(device=self.device).manual_seed(seed)
87
  out = self.pipe(
88
  prompt,
 
 
 
89
  num_inference_steps=n_steps,
90
  guidance_scale=guidance_scale,
91
  generator=generator,
92
- cross_attention_kwargs={'scale': lora_scale},
93
  ) # type: ignore
94
- return out.images[0]
 
 
 
 
 
 
 
 
 
 
2
 
3
  import gc
4
  import pathlib
5
+ import sys
6
+ import tempfile
7
 
8
  import gradio as gr
9
+ import imageio
10
  import PIL.Image
11
  import torch
12
+ from einops import rearrange
13
  from huggingface_hub import ModelCard
14
 
15
+ sys.path.append('Tune-A-Video')
16
+
17
+ from tuneavideo.models.unet import UNet3DConditionModel
18
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
19
+
20
 
21
  class InferencePipeline:
22
  def __init__(self, hf_token: str | None = None):
24
  self.pipe = None
25
  self.device = torch.device(
26
  'cuda:0' if torch.cuda.is_available() else 'cpu')
27
+ self.model_id = None
 
28
 
29
  def clear(self) -> None:
30
+ self.model_id = None
 
31
  del self.pipe
32
  self.pipe = None
33
  torch.cuda.empty_cache()
34
  gc.collect()
35
 
36
  @staticmethod
37
+ def check_if_model_is_local(model_id: str) -> bool:
38
+ return pathlib.Path(model_id).exists()
39
 
40
  @staticmethod
41
  def get_model_card(model_id: str,
47
  return ModelCard.load(card_path, token=hf_token)
48
 
49
  @staticmethod
50
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
51
+ card = InferencePipeline.get_model_card(model_id, hf_token)
 
52
  return card.data.base_model
53
 
54
+ def load_pipe(self, model_id: str) -> None:
55
+ if model_id == self.model_id:
56
  return
57
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
58
+ unet = UNet3DConditionModel.from_pretrained(model_id,
59
+ subfolder='unet',
60
+ torch_dtype=torch.float16)
61
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
62
+ unet=unet,
63
+ torch_dtype=torch.float16)
64
+ pipe = pipe.to(self.device)
65
+ self.pipe = pipe
66
+ self.model_id = model_id # type: ignore
 
 
 
 
 
 
 
 
 
67
 
68
  def run(
69
  self,
70
+ model_id: str,
71
  prompt: str,
72
+ video_length: int,
73
+ fps: int,
74
  seed: int,
75
  n_steps: int,
76
  guidance_scale: float,
78
  if not torch.cuda.is_available():
79
  raise gr.Error('CUDA is not available.')
80
 
81
+ self.load_pipe(model_id)
82
 
83
  generator = torch.Generator(device=self.device).manual_seed(seed)
84
  out = self.pipe(
85
  prompt,
86
+ video_length=video_length,
87
+ width=512,
88
+ height=512,
89
  num_inference_steps=n_steps,
90
  guidance_scale=guidance_scale,
91
  generator=generator,
 
92
  ) # type: ignore
93
+
94
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
95
+ frames = (frames * 255).to(torch.uint8).numpy()
96
+
97
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
98
+ writer = imageio.get_writer(out_file.name, fps=fps)
99
+ for frame in frames:
100
+ writer.append_data(frame)
101
+ writer.close()
102
+
103
+ return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
requirements.txt CHANGED
@@ -1,14 +1,18 @@
1
  accelerate==0.15.0
2
- bitsandbytes==0.36.0.post2
3
- datasets==2.8.0
4
- git+https://github.com/huggingface/diffusers@31be42209ddfdb69d9640a777b32e9b5c6259bf0#egg=diffusers
 
5
  ftfy==6.1.1
6
  gradio==3.16.2
7
  huggingface-hub==0.12.0
 
 
 
8
  Pillow==9.4.0
9
  python-slugify==7.0.0
10
  tensorboard==2.11.2
11
  torch==1.13.1
12
  torchvision==0.14.1
13
  transformers==4.26.0
14
- wandb==0.13.9
1
  accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ decord==0.6.0
4
+ diffusers[torch]==0.11.1
5
+ einops==0.6.0
6
  ftfy==6.1.1
7
  gradio==3.16.2
8
  huggingface-hub==0.12.0
9
+ imageio==2.25.0
10
+ imageio-ffmpeg==0.4.8
11
+ omegaconf==2.3.0
12
  Pillow==9.4.0
13
  python-slugify==7.0.0
14
  tensorboard==2.11.2
15
  torch==1.13.1
16
  torchvision==0.14.1
17
  transformers==4.26.0
18
+ triton==2.0.0.dev20221202
train_dreambooth_lora.py DELETED
@@ -1,1026 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- #
4
- # This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
5
- # The original license is as below:
6
- #
7
- # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
-
20
- import argparse
21
- import hashlib
22
- import logging
23
- import math
24
- import os
25
- import warnings
26
- from pathlib import Path
27
- from typing import Optional
28
-
29
- import numpy as np
30
- import torch
31
- import torch.nn.functional as F
32
- import torch.utils.checkpoint
33
- from torch.utils.data import Dataset
34
-
35
- import datasets
36
- import diffusers
37
- import transformers
38
- from accelerate import Accelerator
39
- from accelerate.logging import get_logger
40
- from accelerate.utils import set_seed
41
- from diffusers import (
42
- AutoencoderKL,
43
- DDPMScheduler,
44
- DiffusionPipeline,
45
- DPMSolverMultistepScheduler,
46
- UNet2DConditionModel,
47
- )
48
- from diffusers.loaders import AttnProcsLayers
49
- from diffusers.models.cross_attention import LoRACrossAttnProcessor
50
- from diffusers.optimization import get_scheduler
51
- from diffusers.utils import check_min_version, is_wandb_available
52
- from diffusers.utils.import_utils import is_xformers_available
53
- from huggingface_hub import HfFolder, Repository, create_repo, whoami
54
- from PIL import Image
55
- from torchvision import transforms
56
- from tqdm.auto import tqdm
57
- from transformers import AutoTokenizer, PretrainedConfig
58
-
59
-
60
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
61
- check_min_version("0.12.0.dev0")
62
-
63
- logger = get_logger(__name__)
64
-
65
-
66
- def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
67
- img_str = ""
68
- for i, image in enumerate(images):
69
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
70
- img_str += f"![img_{i}](./image_{i}.png)\n"
71
-
72
- yaml = f"""
73
- ---
74
- license: creativeml-openrail-m
75
- base_model: {base_model}
76
- tags:
77
- - stable-diffusion
78
- - stable-diffusion-diffusers
79
- - text-to-image
80
- - diffusers
81
- - lora
82
- inference: true
83
- ---
84
- """
85
- model_card = f"""
86
- # LoRA DreamBooth - {repo_name}
87
-
88
- These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
89
- {img_str}
90
- """
91
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
92
- f.write(yaml + model_card)
93
-
94
-
95
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
96
- text_encoder_config = PretrainedConfig.from_pretrained(
97
- pretrained_model_name_or_path,
98
- subfolder="text_encoder",
99
- revision=revision,
100
- )
101
- model_class = text_encoder_config.architectures[0]
102
-
103
- if model_class == "CLIPTextModel":
104
- from transformers import CLIPTextModel
105
-
106
- return CLIPTextModel
107
- elif model_class == "RobertaSeriesModelWithTransformation":
108
- from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
109
-
110
- return RobertaSeriesModelWithTransformation
111
- else:
112
- raise ValueError(f"{model_class} is not supported.")
113
-
114
-
115
- def parse_args(input_args=None):
116
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
117
- parser.add_argument(
118
- "--pretrained_model_name_or_path",
119
- type=str,
120
- default=None,
121
- required=True,
122
- help="Path to pretrained model or model identifier from huggingface.co/models.",
123
- )
124
- parser.add_argument(
125
- "--revision",
126
- type=str,
127
- default=None,
128
- required=False,
129
- help="Revision of pretrained model identifier from huggingface.co/models.",
130
- )
131
- parser.add_argument(
132
- "--tokenizer_name",
133
- type=str,
134
- default=None,
135
- help="Pretrained tokenizer name or path if not the same as model_name",
136
- )
137
- parser.add_argument(
138
- "--instance_data_dir",
139
- type=str,
140
- default=None,
141
- required=True,
142
- help="A folder containing the training data of instance images.",
143
- )
144
- parser.add_argument(
145
- "--class_data_dir",
146
- type=str,
147
- default=None,
148
- required=False,
149
- help="A folder containing the training data of class images.",
150
- )
151
- parser.add_argument(
152
- "--instance_prompt",
153
- type=str,
154
- default=None,
155
- required=True,
156
- help="The prompt with identifier specifying the instance",
157
- )
158
- parser.add_argument(
159
- "--class_prompt",
160
- type=str,
161
- default=None,
162
- help="The prompt to specify images in the same class as provided instance images.",
163
- )
164
- parser.add_argument(
165
- "--validation_prompt",
166
- type=str,
167
- default=None,
168
- help="A prompt that is used during validation to verify that the model is learning.",
169
- )
170
- parser.add_argument(
171
- "--num_validation_images",
172
- type=int,
173
- default=4,
174
- help="Number of images that should be generated during validation with `validation_prompt`.",
175
- )
176
- parser.add_argument(
177
- "--validation_epochs",
178
- type=int,
179
- default=50,
180
- help=(
181
- "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
182
- " `args.validation_prompt` multiple times: `args.num_validation_images`."
183
- ),
184
- )
185
- parser.add_argument(
186
- "--with_prior_preservation",
187
- default=False,
188
- action="store_true",
189
- help="Flag to add prior preservation loss.",
190
- )
191
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
192
- parser.add_argument(
193
- "--num_class_images",
194
- type=int,
195
- default=100,
196
- help=(
197
- "Minimal class images for prior preservation loss. If there are not enough images already present in"
198
- " class_data_dir, additional images will be sampled with class_prompt."
199
- ),
200
- )
201
- parser.add_argument(
202
- "--output_dir",
203
- type=str,
204
- default="lora-dreambooth-model",
205
- help="The output directory where the model predictions and checkpoints will be written.",
206
- )
207
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
208
- parser.add_argument(
209
- "--resolution",
210
- type=int,
211
- default=512,
212
- help=(
213
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
214
- " resolution"
215
- ),
216
- )
217
- parser.add_argument(
218
- "--center_crop",
219
- default=False,
220
- action="store_true",
221
- help=(
222
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
223
- " cropped. The images will be resized to the resolution first before cropping."
224
- ),
225
- )
226
- parser.add_argument(
227
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
228
- )
229
- parser.add_argument(
230
- "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
231
- )
232
- parser.add_argument("--num_train_epochs", type=int, default=1)
233
- parser.add_argument(
234
- "--max_train_steps",
235
- type=int,
236
- default=None,
237
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
238
- )
239
- parser.add_argument(
240
- "--checkpointing_steps",
241
- type=int,
242
- default=500,
243
- help=(
244
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
245
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
246
- " training using `--resume_from_checkpoint`."
247
- ),
248
- )
249
- parser.add_argument(
250
- "--resume_from_checkpoint",
251
- type=str,
252
- default=None,
253
- help=(
254
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
255
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
256
- ),
257
- )
258
- parser.add_argument(
259
- "--gradient_accumulation_steps",
260
- type=int,
261
- default=1,
262
- help="Number of updates steps to accumulate before performing a backward/update pass.",
263
- )
264
- parser.add_argument(
265
- "--gradient_checkpointing",
266
- action="store_true",
267
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
268
- )
269
- parser.add_argument(
270
- "--learning_rate",
271
- type=float,
272
- default=5e-4,
273
- help="Initial learning rate (after the potential warmup period) to use.",
274
- )
275
- parser.add_argument(
276
- "--scale_lr",
277
- action="store_true",
278
- default=False,
279
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
280
- )
281
- parser.add_argument(
282
- "--lr_scheduler",
283
- type=str,
284
- default="constant",
285
- help=(
286
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
287
- ' "constant", "constant_with_warmup"]'
288
- ),
289
- )
290
- parser.add_argument(
291
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
292
- )
293
- parser.add_argument(
294
- "--lr_num_cycles",
295
- type=int,
296
- default=1,
297
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
298
- )
299
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
300
- parser.add_argument(
301
- "--dataloader_num_workers",
302
- type=int,
303
- default=0,
304
- help=(
305
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
306
- ),
307
- )
308
- parser.add_argument(
309
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
310
- )
311
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
312
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
313
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
314
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
315
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
316
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
317
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
318
- parser.add_argument(
319
- "--hub_model_id",
320
- type=str,
321
- default=None,
322
- help="The name of the repository to keep in sync with the local `output_dir`.",
323
- )
324
- parser.add_argument(
325
- "--logging_dir",
326
- type=str,
327
- default="logs",
328
- help=(
329
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
330
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
331
- ),
332
- )
333
- parser.add_argument(
334
- "--allow_tf32",
335
- action="store_true",
336
- help=(
337
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
338
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
339
- ),
340
- )
341
- parser.add_argument(
342
- "--report_to",
343
- type=str,
344
- default="tensorboard",
345
- help=(
346
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
347
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
348
- ),
349
- )
350
- parser.add_argument(
351
- "--mixed_precision",
352
- type=str,
353
- default=None,
354
- choices=["no", "fp16", "bf16"],
355
- help=(
356
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
357
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
358
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
359
- ),
360
- )
361
- parser.add_argument(
362
- "--prior_generation_precision",
363
- type=str,
364
- default=None,
365
- choices=["no", "fp32", "fp16", "bf16"],
366
- help=(
367
- "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
368
- " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
369
- ),
370
- )
371
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
372
- parser.add_argument(
373
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
374
- )
375
-
376
- if input_args is not None:
377
- args = parser.parse_args(input_args)
378
- else:
379
- args = parser.parse_args()
380
-
381
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
382
- if env_local_rank != -1 and env_local_rank != args.local_rank:
383
- args.local_rank = env_local_rank
384
-
385
- if args.with_prior_preservation:
386
- if args.class_data_dir is None:
387
- raise ValueError("You must specify a data directory for class images.")
388
- if args.class_prompt is None:
389
- raise ValueError("You must specify prompt for class images.")
390
- else:
391
- # logger is not available yet
392
- if args.class_data_dir is not None:
393
- warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
394
- if args.class_prompt is not None:
395
- warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
396
-
397
- return args
398
-
399
-
400
- class DreamBoothDataset(Dataset):
401
- """
402
- A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
403
- It pre-processes the images and the tokenizes prompts.
404
- """
405
-
406
- def __init__(
407
- self,
408
- instance_data_root,
409
- instance_prompt,
410
- tokenizer,
411
- class_data_root=None,
412
- class_prompt=None,
413
- size=512,
414
- center_crop=False,
415
- ):
416
- self.size = size
417
- self.center_crop = center_crop
418
- self.tokenizer = tokenizer
419
-
420
- self.instance_data_root = Path(instance_data_root)
421
- if not self.instance_data_root.exists():
422
- raise ValueError("Instance images root doesn't exists.")
423
-
424
- self.instance_images_path = list(Path(instance_data_root).iterdir())
425
- self.num_instance_images = len(self.instance_images_path)
426
- self.instance_prompt = instance_prompt
427
- self._length = self.num_instance_images
428
-
429
- if class_data_root is not None:
430
- self.class_data_root = Path(class_data_root)
431
- self.class_data_root.mkdir(parents=True, exist_ok=True)
432
- self.class_images_path = list(self.class_data_root.iterdir())
433
- self.num_class_images = len(self.class_images_path)
434
- self._length = max(self.num_class_images, self.num_instance_images)
435
- self.class_prompt = class_prompt
436
- else:
437
- self.class_data_root = None
438
-
439
- self.image_transforms = transforms.Compose(
440
- [
441
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
442
- transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
443
- transforms.ToTensor(),
444
- transforms.Normalize([0.5], [0.5]),
445
- ]
446
- )
447
-
448
- def __len__(self):
449
- return self._length
450
-
451
- def __getitem__(self, index):
452
- example = {}
453
- instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
454
- if not instance_image.mode == "RGB":
455
- instance_image = instance_image.convert("RGB")
456
- example["instance_images"] = self.image_transforms(instance_image)
457
- example["instance_prompt_ids"] = self.tokenizer(
458
- self.instance_prompt,
459
- truncation=True,
460
- padding="max_length",
461
- max_length=self.tokenizer.model_max_length,
462
- return_tensors="pt",
463
- ).input_ids
464
-
465
- if self.class_data_root:
466
- class_image = Image.open(self.class_images_path[index % self.num_class_images])
467
- if not class_image.mode == "RGB":
468
- class_image = class_image.convert("RGB")
469
- example["class_images"] = self.image_transforms(class_image)
470
- example["class_prompt_ids"] = self.tokenizer(
471
- self.class_prompt,
472
- truncation=True,
473
- padding="max_length",
474
- max_length=self.tokenizer.model_max_length,
475
- return_tensors="pt",
476
- ).input_ids
477
-
478
- return example
479
-
480
-
481
- def collate_fn(examples, with_prior_preservation=False):
482
- input_ids = [example["instance_prompt_ids"] for example in examples]
483
- pixel_values = [example["instance_images"] for example in examples]
484
-
485
- # Concat class and instance examples for prior preservation.
486
- # We do this to avoid doing two forward passes.
487
- if with_prior_preservation:
488
- input_ids += [example["class_prompt_ids"] for example in examples]
489
- pixel_values += [example["class_images"] for example in examples]
490
-
491
- pixel_values = torch.stack(pixel_values)
492
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
493
-
494
- input_ids = torch.cat(input_ids, dim=0)
495
-
496
- batch = {
497
- "input_ids": input_ids,
498
- "pixel_values": pixel_values,
499
- }
500
- return batch
501
-
502
-
503
- class PromptDataset(Dataset):
504
- "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
505
-
506
- def __init__(self, prompt, num_samples):
507
- self.prompt = prompt
508
- self.num_samples = num_samples
509
-
510
- def __len__(self):
511
- return self.num_samples
512
-
513
- def __getitem__(self, index):
514
- example = {}
515
- example["prompt"] = self.prompt
516
- example["index"] = index
517
- return example
518
-
519
-
520
- def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
521
- if token is None:
522
- token = HfFolder.get_token()
523
- if organization is None:
524
- username = whoami(token)["name"]
525
- return f"{username}/{model_id}"
526
- else:
527
- return f"{organization}/{model_id}"
528
-
529
-
530
- def main(args):
531
- logging_dir = Path(args.output_dir, args.logging_dir)
532
-
533
- accelerator = Accelerator(
534
- gradient_accumulation_steps=args.gradient_accumulation_steps,
535
- mixed_precision=args.mixed_precision,
536
- log_with=args.report_to,
537
- logging_dir=logging_dir,
538
- )
539
-
540
- if args.report_to == "wandb":
541
- if not is_wandb_available():
542
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
543
- import wandb
544
-
545
- # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
546
- # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
547
- # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
548
- # Make one log on every process with the configuration for debugging.
549
- logging.basicConfig(
550
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
551
- datefmt="%m/%d/%Y %H:%M:%S",
552
- level=logging.INFO,
553
- )
554
- logger.info(accelerator.state, main_process_only=False)
555
- if accelerator.is_local_main_process:
556
- datasets.utils.logging.set_verbosity_warning()
557
- transformers.utils.logging.set_verbosity_warning()
558
- diffusers.utils.logging.set_verbosity_info()
559
- else:
560
- datasets.utils.logging.set_verbosity_error()
561
- transformers.utils.logging.set_verbosity_error()
562
- diffusers.utils.logging.set_verbosity_error()
563
-
564
- # If passed along, set the training seed now.
565
- if args.seed is not None:
566
- set_seed(args.seed)
567
-
568
- # Generate class images if prior preservation is enabled.
569
- if args.with_prior_preservation:
570
- class_images_dir = Path(args.class_data_dir)
571
- if not class_images_dir.exists():
572
- class_images_dir.mkdir(parents=True)
573
- cur_class_images = len(list(class_images_dir.iterdir()))
574
-
575
- if cur_class_images < args.num_class_images:
576
- torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
577
- if args.prior_generation_precision == "fp32":
578
- torch_dtype = torch.float32
579
- elif args.prior_generation_precision == "fp16":
580
- torch_dtype = torch.float16
581
- elif args.prior_generation_precision == "bf16":
582
- torch_dtype = torch.bfloat16
583
- pipeline = DiffusionPipeline.from_pretrained(
584
- args.pretrained_model_name_or_path,
585
- torch_dtype=torch_dtype,
586
- safety_checker=None,
587
- revision=args.revision,
588
- )
589
- pipeline.set_progress_bar_config(disable=True)
590
-
591
- num_new_images = args.num_class_images - cur_class_images
592
- logger.info(f"Number of class images to sample: {num_new_images}.")
593
-
594
- sample_dataset = PromptDataset(args.class_prompt, num_new_images)
595
- sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
596
-
597
- sample_dataloader = accelerator.prepare(sample_dataloader)
598
- pipeline.to(accelerator.device)
599
-
600
- for example in tqdm(
601
- sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
602
- ):
603
- images = pipeline(example["prompt"]).images
604
-
605
- for i, image in enumerate(images):
606
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
607
- image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
608
- image.save(image_filename)
609
-
610
- del pipeline
611
- if torch.cuda.is_available():
612
- torch.cuda.empty_cache()
613
-
614
- # Handle the repository creation
615
- if accelerator.is_main_process:
616
- if args.push_to_hub:
617
- if args.hub_model_id is None:
618
- repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
619
- else:
620
- repo_name = args.hub_model_id
621
-
622
- create_repo(repo_name, exist_ok=True, token=args.hub_token)
623
- repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
624
-
625
- with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
626
- if "step_*" not in gitignore:
627
- gitignore.write("step_*\n")
628
- if "epoch_*" not in gitignore:
629
- gitignore.write("epoch_*\n")
630
- elif args.output_dir is not None:
631
- os.makedirs(args.output_dir, exist_ok=True)
632
-
633
- # Load the tokenizer
634
- if args.tokenizer_name:
635
- tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
636
- elif args.pretrained_model_name_or_path:
637
- tokenizer = AutoTokenizer.from_pretrained(
638
- args.pretrained_model_name_or_path,
639
- subfolder="tokenizer",
640
- revision=args.revision,
641
- use_fast=False,
642
- )
643
-
644
- # import correct text encoder class
645
- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
646
-
647
- # Load scheduler and models
648
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
649
- text_encoder = text_encoder_cls.from_pretrained(
650
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
651
- )
652
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
653
- unet = UNet2DConditionModel.from_pretrained(
654
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
655
- )
656
-
657
- # We only train the additional adapter LoRA layers
658
- vae.requires_grad_(False)
659
- text_encoder.requires_grad_(False)
660
- unet.requires_grad_(False)
661
-
662
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
663
- # as these models are only used for inference, keeping weights in full precision is not required.
664
- weight_dtype = torch.float32
665
- if accelerator.mixed_precision == "fp16":
666
- weight_dtype = torch.float16
667
- elif accelerator.mixed_precision == "bf16":
668
- weight_dtype = torch.bfloat16
669
-
670
- # Move unet, vae and text_encoder to device and cast to weight_dtype
671
- unet.to(accelerator.device, dtype=weight_dtype)
672
- vae.to(accelerator.device, dtype=weight_dtype)
673
- text_encoder.to(accelerator.device, dtype=weight_dtype)
674
-
675
- if args.enable_xformers_memory_efficient_attention:
676
- if is_xformers_available():
677
- unet.enable_xformers_memory_efficient_attention()
678
- else:
679
- raise ValueError("xformers is not available. Make sure it is installed correctly")
680
-
681
- # now we will add new LoRA weights to the attention layers
682
- # It's important to realize here how many attention weights will be added and of which sizes
683
- # The sizes of the attention layers consist only of two different variables:
684
- # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
685
- # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
686
-
687
- # Let's first see how many attention processors we will have to set.
688
- # For Stable Diffusion, it should be equal to:
689
- # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
690
- # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
691
- # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
692
- # => 32 layers
693
-
694
- # Set correct lora layers
695
- lora_attn_procs = {}
696
- for name in unet.attn_processors.keys():
697
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
698
- if name.startswith("mid_block"):
699
- hidden_size = unet.config.block_out_channels[-1]
700
- elif name.startswith("up_blocks"):
701
- block_id = int(name[len("up_blocks.")])
702
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
703
- elif name.startswith("down_blocks"):
704
- block_id = int(name[len("down_blocks.")])
705
- hidden_size = unet.config.block_out_channels[block_id]
706
-
707
- lora_attn_procs[name] = LoRACrossAttnProcessor(
708
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
709
- )
710
-
711
- unet.set_attn_processor(lora_attn_procs)
712
- lora_layers = AttnProcsLayers(unet.attn_processors)
713
-
714
- accelerator.register_for_checkpointing(lora_layers)
715
-
716
- if args.scale_lr:
717
- args.learning_rate = (
718
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
719
- )
720
-
721
- # Enable TF32 for faster training on Ampere GPUs,
722
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
723
- if args.allow_tf32:
724
- torch.backends.cuda.matmul.allow_tf32 = True
725
-
726
- if args.scale_lr:
727
- args.learning_rate = (
728
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
729
- )
730
-
731
- # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
732
- if args.use_8bit_adam:
733
- try:
734
- import bitsandbytes as bnb
735
- except ImportError:
736
- raise ImportError(
737
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
738
- )
739
-
740
- optimizer_class = bnb.optim.AdamW8bit
741
- else:
742
- optimizer_class = torch.optim.AdamW
743
-
744
- # Optimizer creation
745
- optimizer = optimizer_class(
746
- lora_layers.parameters(),
747
- lr=args.learning_rate,
748
- betas=(args.adam_beta1, args.adam_beta2),
749
- weight_decay=args.adam_weight_decay,
750
- eps=args.adam_epsilon,
751
- )
752
-
753
- # Dataset and DataLoaders creation:
754
- train_dataset = DreamBoothDataset(
755
- instance_data_root=args.instance_data_dir,
756
- instance_prompt=args.instance_prompt,
757
- class_data_root=args.class_data_dir if args.with_prior_preservation else None,
758
- class_prompt=args.class_prompt,
759
- tokenizer=tokenizer,
760
- size=args.resolution,
761
- center_crop=args.center_crop,
762
- )
763
-
764
- train_dataloader = torch.utils.data.DataLoader(
765
- train_dataset,
766
- batch_size=args.train_batch_size,
767
- shuffle=True,
768
- collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
769
- num_workers=args.dataloader_num_workers,
770
- )
771
-
772
- # Scheduler and math around the number of training steps.
773
- overrode_max_train_steps = False
774
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
775
- if args.max_train_steps is None:
776
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
777
- overrode_max_train_steps = True
778
-
779
- lr_scheduler = get_scheduler(
780
- args.lr_scheduler,
781
- optimizer=optimizer,
782
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
783
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
784
- num_cycles=args.lr_num_cycles,
785
- power=args.lr_power,
786
- )
787
-
788
- # Prepare everything with our `accelerator`.
789
- lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
790
- lora_layers, optimizer, train_dataloader, lr_scheduler
791
- )
792
-
793
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
794
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
795
- if overrode_max_train_steps:
796
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
797
- # Afterwards we recalculate our number of training epochs
798
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
799
-
800
- # We need to initialize the trackers we use, and also store our configuration.
801
- # The trackers initializes automatically on the main process.
802
- if accelerator.is_main_process:
803
- accelerator.init_trackers("dreambooth-lora", config=vars(args))
804
-
805
- # Train!
806
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
807
-
808
- logger.info("***** Running training *****")
809
- logger.info(f" Num examples = {len(train_dataset)}")
810
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
811
- logger.info(f" Num Epochs = {args.num_train_epochs}")
812
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
813
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
814
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
815
- logger.info(f" Total optimization steps = {args.max_train_steps}")
816
- global_step = 0
817
- first_epoch = 0
818
-
819
- # Potentially load in the weights and states from a previous save
820
- if args.resume_from_checkpoint:
821
- if args.resume_from_checkpoint != "latest":
822
- path = os.path.basename(args.resume_from_checkpoint)
823
- else:
824
- # Get the mos recent checkpoint
825
- dirs = os.listdir(args.output_dir)
826
- dirs = [d for d in dirs if d.startswith("checkpoint")]
827
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
828
- path = dirs[-1] if len(dirs) > 0 else None
829
-
830
- if path is None:
831
- accelerator.print(
832
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
833
- )
834
- args.resume_from_checkpoint = None
835
- else:
836
- accelerator.print(f"Resuming from checkpoint {path}")
837
- accelerator.load_state(os.path.join(args.output_dir, path))
838
- global_step = int(path.split("-")[1])
839
-
840
- resume_global_step = global_step * args.gradient_accumulation_steps
841
- first_epoch = global_step // num_update_steps_per_epoch
842
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
843
-
844
- # Only show the progress bar once on each machine.
845
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
846
- progress_bar.set_description("Steps")
847
-
848
- for epoch in range(first_epoch, args.num_train_epochs):
849
- unet.train()
850
- for step, batch in enumerate(train_dataloader):
851
- # Skip steps until we reach the resumed step
852
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
853
- if step % args.gradient_accumulation_steps == 0:
854
- progress_bar.update(1)
855
- continue
856
-
857
- with accelerator.accumulate(unet):
858
- # Convert images to latent space
859
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
860
- latents = latents * 0.18215
861
-
862
- # Sample noise that we'll add to the latents
863
- noise = torch.randn_like(latents)
864
- bsz = latents.shape[0]
865
- # Sample a random timestep for each image
866
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
867
- timesteps = timesteps.long()
868
-
869
- # Add noise to the latents according to the noise magnitude at each timestep
870
- # (this is the forward diffusion process)
871
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
872
-
873
- # Get the text embedding for conditioning
874
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
875
-
876
- # Predict the noise residual
877
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
878
-
879
- # Get the target for loss depending on the prediction type
880
- if noise_scheduler.config.prediction_type == "epsilon":
881
- target = noise
882
- elif noise_scheduler.config.prediction_type == "v_prediction":
883
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
884
- else:
885
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
886
-
887
- if args.with_prior_preservation:
888
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
889
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
890
- target, target_prior = torch.chunk(target, 2, dim=0)
891
-
892
- # Compute instance loss
893
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
894
-
895
- # Compute prior loss
896
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
897
-
898
- # Add the prior loss to the instance loss.
899
- loss = loss + args.prior_loss_weight * prior_loss
900
- else:
901
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
902
-
903
- accelerator.backward(loss)
904
- if accelerator.sync_gradients:
905
- params_to_clip = lora_layers.parameters()
906
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
907
- optimizer.step()
908
- lr_scheduler.step()
909
- optimizer.zero_grad()
910
-
911
- # Checks if the accelerator has performed an optimization step behind the scenes
912
- if accelerator.sync_gradients:
913
- progress_bar.update(1)
914
- global_step += 1
915
-
916
- if global_step % args.checkpointing_steps == 0:
917
- if accelerator.is_main_process:
918
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
919
- accelerator.save_state(save_path)
920
- logger.info(f"Saved state to {save_path}")
921
-
922
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
923
- progress_bar.set_postfix(**logs)
924
- accelerator.log(logs, step=global_step)
925
-
926
- if global_step >= args.max_train_steps:
927
- break
928
-
929
- if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
930
- logger.info(
931
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
932
- f" {args.validation_prompt}."
933
- )
934
- # create pipeline
935
- pipeline = DiffusionPipeline.from_pretrained(
936
- args.pretrained_model_name_or_path,
937
- unet=accelerator.unwrap_model(unet),
938
- text_encoder=accelerator.unwrap_model(text_encoder),
939
- revision=args.revision,
940
- torch_dtype=weight_dtype,
941
- )
942
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
943
- pipeline = pipeline.to(accelerator.device)
944
- pipeline.set_progress_bar_config(disable=True)
945
-
946
- # run inference
947
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
948
- prompt = args.num_validation_images * [args.validation_prompt]
949
- images = pipeline(prompt, num_inference_steps=25, generator=generator).images
950
-
951
- for tracker in accelerator.trackers:
952
- if tracker.name == "tensorboard":
953
- np_images = np.stack([np.asarray(img) for img in images])
954
- tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
955
- if tracker.name == "wandb":
956
- tracker.log(
957
- {
958
- "validation": [
959
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
960
- for i, image in enumerate(images)
961
- ]
962
- }
963
- )
964
-
965
- del pipeline
966
- torch.cuda.empty_cache()
967
-
968
- # Save the lora layers
969
- accelerator.wait_for_everyone()
970
- if accelerator.is_main_process:
971
- unet = unet.to(torch.float32)
972
- unet.save_attn_procs(args.output_dir)
973
-
974
- # Final inference
975
- # Load previous pipeline
976
- pipeline = DiffusionPipeline.from_pretrained(
977
- args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
978
- )
979
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
980
- pipeline = pipeline.to(accelerator.device)
981
-
982
- # load attention processors
983
- pipeline.unet.load_attn_procs(args.output_dir)
984
-
985
- # run inference
986
- if args.validation_prompt and args.num_validation_images > 0:
987
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
988
- prompt = args.num_validation_images * [args.validation_prompt]
989
- images = pipeline(prompt, num_inference_steps=25, generator=generator).images
990
-
991
- test_image_dir = Path(args.output_dir) / 'test_images'
992
- test_image_dir.mkdir()
993
- for i, image in enumerate(images):
994
- out_path = test_image_dir / f'image_{i}.png'
995
- image.save(out_path)
996
-
997
- for tracker in accelerator.trackers:
998
- if tracker.name == "tensorboard":
999
- np_images = np.stack([np.asarray(img) for img in images])
1000
- tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1001
- if tracker.name == "wandb":
1002
- tracker.log(
1003
- {
1004
- "test": [
1005
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1006
- for i, image in enumerate(images)
1007
- ]
1008
- }
1009
- )
1010
-
1011
- if args.push_to_hub:
1012
- save_model_card(
1013
- repo_name,
1014
- images=images,
1015
- base_model=args.pretrained_model_name_or_path,
1016
- prompt=args.instance_prompt,
1017
- repo_folder=args.output_dir,
1018
- )
1019
- repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1020
-
1021
- accelerator.end_training()
1022
-
1023
-
1024
- if __name__ == "__main__":
1025
- args = parse_args()
1026
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trainer.py CHANGED
@@ -6,61 +6,52 @@ import pathlib
6
  import shlex
7
  import shutil
8
  import subprocess
 
9
 
10
  import gradio as gr
11
- import PIL.Image
12
  import slugify
13
  import torch
14
  from huggingface_hub import HfApi
 
15
 
16
- from app_upload import LoRAModelUploader
17
  from utils import save_model_card
18
 
19
- URL_TO_JOIN_LORA_LIBRARY_ORG = 'https://huggingface.co/organizations/lora-library/share/hjetHAcKjnPHXhHfbeEcqnBqmhgilFfpOL'
20
 
21
-
22
- def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
23
- w, h = image.size
24
- if w == h:
25
- return image
26
- elif w > h:
27
- new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
28
- new_image.paste(image, (0, (w - h) // 2))
29
- return new_image
30
- else:
31
- new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
32
- new_image.paste(image, ((h - w) // 2, 0))
33
- return new_image
34
 
35
 
36
  class Trainer:
37
  def __init__(self, hf_token: str | None = None):
38
  self.hf_token = hf_token
39
  self.api = HfApi(token=hf_token)
40
- self.model_uploader = LoRAModelUploader(hf_token)
41
-
42
- def prepare_dataset(self, instance_images: list, resolution: int,
43
- instance_data_dir: pathlib.Path) -> None:
44
- shutil.rmtree(instance_data_dir, ignore_errors=True)
45
- instance_data_dir.mkdir(parents=True)
46
- for i, temp_path in enumerate(instance_images):
47
- image = PIL.Image.open(temp_path.name)
48
- image = pad_image(image)
49
- image = image.resize((resolution, resolution))
50
- image = image.convert('RGB')
51
- out_path = instance_data_dir / f'{i:03d}.jpg'
52
- image.save(out_path, format='JPEG', quality=100)
53
-
54
- def join_lora_library_org(self) -> None:
 
 
55
  subprocess.run(
56
  shlex.split(
57
- f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LORA_LIBRARY_ORG}'
58
  ))
59
 
60
  def run(
61
  self,
62
- instance_images: list | None,
63
- instance_prompt: str,
64
  output_model_name: str,
65
  overwrite_existing_model: bool,
66
  validation_prompt: str,
@@ -73,7 +64,6 @@ class Trainer:
73
  fp16: bool,
74
  use_8bit_adam: bool,
75
  checkpointing_steps: int,
76
- use_wandb: bool,
77
  validation_epochs: int,
78
  upload_to_hub: bool,
79
  use_private_repo: bool,
@@ -83,10 +73,10 @@ class Trainer:
83
  ) -> str:
84
  if not torch.cuda.is_available():
85
  raise gr.Error('CUDA is not available.')
86
- if instance_images is None:
87
- raise gr.Error('You need to upload images.')
88
- if not instance_prompt:
89
- raise gr.Error('The instance prompt is missing.')
90
  if not validation_prompt:
91
  raise gr.Error('The validation prompt is missing.')
92
 
@@ -94,7 +84,7 @@ class Trainer:
94
 
95
  if not output_model_name:
96
  timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
97
- output_model_name = f'lora-dreambooth-{timestamp}'
98
  output_model_name = slugify.slugify(output_model_name)
99
 
100
  repo_dir = pathlib.Path(__file__).parent
@@ -103,52 +93,52 @@ class Trainer:
103
  shutil.rmtree(output_dir, ignore_errors=True)
104
  output_dir.mkdir(parents=True)
105
 
106
- instance_data_dir = repo_dir / 'training_data' / output_model_name
107
- self.prepare_dataset(instance_images, resolution, instance_data_dir)
108
-
109
  if upload_to_hub:
110
- self.join_lora_library_org()
111
-
112
- command = f'''
113
- accelerate launch train_dreambooth_lora.py \
114
- --pretrained_model_name_or_path={base_model} \
115
- --instance_data_dir={instance_data_dir} \
116
- --output_dir={output_dir} \
117
- --instance_prompt="{instance_prompt}" \
118
- --resolution={resolution} \
119
- --train_batch_size=1 \
120
- --gradient_accumulation_steps={gradient_accumulation} \
121
- --learning_rate={learning_rate} \
122
- --lr_scheduler=constant \
123
- --lr_warmup_steps=0 \
124
- --max_train_steps={n_steps} \
125
- --checkpointing_steps={checkpointing_steps} \
126
- --validation_prompt="{validation_prompt}" \
127
- --validation_epochs={validation_epochs} \
128
- --seed={seed}
129
- '''
130
- if fp16:
131
- command += ' --mixed_precision fp16'
132
- if use_8bit_adam:
133
- command += ' --use_8bit_adam'
134
- if use_wandb:
135
- command += ' --report_to wandb'
136
-
137
- with open(output_dir / 'train.sh', 'w') as f:
138
- command_s = ' '.join(command.split())
139
- f.write(command_s)
 
 
 
140
  subprocess.run(shlex.split(command))
141
  save_model_card(save_dir=output_dir,
142
  base_model=base_model,
143
- instance_prompt=instance_prompt,
144
  test_prompt=validation_prompt,
145
- test_image_dir='test_images')
146
 
147
  message = 'Training completed!'
148
  print(message)
149
 
150
  if upload_to_hub:
151
- upload_message = self.model_uploader.upload_lora_model(
152
  folder_path=output_dir.as_posix(),
153
  repo_name=output_model_name,
154
  upload_to=upload_to,
6
  import shlex
7
  import shutil
8
  import subprocess
9
+ import sys
10
 
11
  import gradio as gr
 
12
  import slugify
13
  import torch
14
  from huggingface_hub import HfApi
15
+ from omegaconf import OmegaConf
16
 
17
+ from app_upload import ModelUploader
18
  from utils import save_model_card
19
 
20
+ sys.path.append('Tune-A-Video')
21
 
22
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class Trainer:
26
  def __init__(self, hf_token: str | None = None):
27
  self.hf_token = hf_token
28
  self.api = HfApi(token=hf_token)
29
+ self.model_uploader = ModelUploader(hf_token)
30
+
31
+ self.checkpoint_dir = pathlib.Path('checkpoints')
32
+ self.checkpoint_dir.mkdir(exist_ok=True)
33
+
34
+ def download_base_model(self, base_model_id: str) -> str:
35
+ model_dir = self.checkpoint_dir / base_model_id
36
+ if not model_dir.exists():
37
+ org_name = base_model_id.split('/')[0]
38
+ org_dir = self.checkpoint_dir / org_name
39
+ org_dir.mkdir(exist_ok=True)
40
+ subprocess.run(shlex.split(
41
+ f'git clone https://huggingface.co/{base_model_id}'),
42
+ cwd=org_dir)
43
+ return model_dir.as_posix()
44
+
45
+ def join_model_library_org(self) -> None:
46
  subprocess.run(
47
  shlex.split(
48
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
  ))
50
 
51
  def run(
52
  self,
53
+ training_video: str,
54
+ training_prompt: str,
55
  output_model_name: str,
56
  overwrite_existing_model: bool,
57
  validation_prompt: str,
64
  fp16: bool,
65
  use_8bit_adam: bool,
66
  checkpointing_steps: int,
 
67
  validation_epochs: int,
68
  upload_to_hub: bool,
69
  use_private_repo: bool,
73
  ) -> str:
74
  if not torch.cuda.is_available():
75
  raise gr.Error('CUDA is not available.')
76
+ if training_video is None:
77
+ raise gr.Error('You need to upload a video.')
78
+ if not training_prompt:
79
+ raise gr.Error('The training prompt is missing.')
80
  if not validation_prompt:
81
  raise gr.Error('The validation prompt is missing.')
82
 
84
 
85
  if not output_model_name:
86
  timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
87
+ output_model_name = f'tune-a-video-{timestamp}'
88
  output_model_name = slugify.slugify(output_model_name)
89
 
90
  repo_dir = pathlib.Path(__file__).parent
93
  shutil.rmtree(output_dir, ignore_errors=True)
94
  output_dir.mkdir(parents=True)
95
 
 
 
 
96
  if upload_to_hub:
97
+ self.join_model_library_org()
98
+
99
+ config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
100
+ config.pretrained_model_path = self.download_base_model(base_model)
101
+ config.output_dir = output_dir.as_posix()
102
+ config.train_data.video_path = training_video.name # type: ignore
103
+ config.train_data.prompt = training_prompt
104
+ config.train_data.n_sample_frames = 8
105
+ config.train_data.width = resolution
106
+ config.train_data.height = resolution
107
+ config.train_data.sample_start_idx = 0
108
+ config.train_data.sample_frame_rate = 1
109
+ config.validation_data.prompts = [validation_prompt]
110
+ config.validation_data.video_length = 8
111
+ config.validation_data.width = resolution
112
+ config.validation_data.height = resolution
113
+ config.validation_data.num_inference_steps = 50
114
+ config.validation_data.guidance_scale = 7.5
115
+ config.learning_rate = learning_rate
116
+ config.gradient_accumulation_steps = gradient_accumulation
117
+ config.train_batch_size = 1
118
+ config.max_train_steps = n_steps
119
+ config.checkpointing_steps = checkpointing_steps
120
+ config.validation_steps = validation_epochs
121
+ config.seed = seed
122
+ config.mixed_precision = 'fp16' if fp16 else ''
123
+ config.use_8bit_adam = use_8bit_adam
124
+
125
+ config_path = output_dir / 'config.yaml'
126
+ with open(config_path, 'w') as f:
127
+ OmegaConf.save(config, f)
128
+
129
+ command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
130
  subprocess.run(shlex.split(command))
131
  save_model_card(save_dir=output_dir,
132
  base_model=base_model,
133
+ training_prompt=training_prompt,
134
  test_prompt=validation_prompt,
135
+ test_image_dir='samples')
136
 
137
  message = 'Training completed!'
138
  print(message)
139
 
140
  if upload_to_hub:
141
+ upload_message = self.model_uploader.upload_model(
142
  folder_path=output_dir.as_posix(),
143
  repo_name=output_model_name,
144
  upload_to=upload_to,
utils.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  import pathlib
4
 
5
 
6
- def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
  exp_root_dir = repo_dir / 'experiments'
9
  if not exp_root_dir.exists():
@@ -11,46 +11,45 @@ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
11
  exp_dirs = sorted(exp_root_dir.glob('*'))
12
  exp_dirs = [
13
  exp_dir for exp_dir in exp_dirs
14
- if (exp_dir / 'pytorch_lora_weights.bin').exists()
15
  ]
16
- if ignore_repo:
17
- exp_dirs = [
18
- exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
19
- ]
20
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
21
 
22
 
23
  def save_model_card(
24
  save_dir: pathlib.Path,
25
  base_model: str,
26
- instance_prompt: str,
27
  test_prompt: str = '',
28
  test_image_dir: str = '',
29
  ) -> None:
30
  image_str = ''
31
  if test_prompt and test_image_dir:
32
- image_paths = sorted((save_dir / test_image_dir).glob('*'))
33
  if image_paths:
34
- image_str = f'Test prompt: {test_prompt}\n'
35
- for image_path in image_paths:
36
- rel_path = image_path.relative_to(save_dir)
37
- image_str += f'![{image_path.stem}]({rel_path})\n'
38
 
39
  model_card = f'''---
40
  license: creativeml-openrail-m
41
  base_model: {base_model}
42
- instance_prompt: {instance_prompt}
43
  tags:
44
  - stable-diffusion
45
  - stable-diffusion-diffusers
46
  - text-to-image
47
  - diffusers
48
- - lora
49
- inference: true
 
50
  ---
51
- # LoRA DreamBooth - {save_dir.name}
52
 
53
- These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
 
 
 
 
54
 
55
  {image_str}
56
  '''
3
  import pathlib
4
 
5
 
6
+ def find_exp_dirs() -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
  exp_root_dir = repo_dir / 'experiments'
9
  if not exp_root_dir.exists():
11
  exp_dirs = sorted(exp_root_dir.glob('*'))
12
  exp_dirs = [
13
  exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
  ]
 
 
 
 
16
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
 
18
 
19
  def save_model_card(
20
  save_dir: pathlib.Path,
21
  base_model: str,
22
+ training_prompt: str,
23
  test_prompt: str = '',
24
  test_image_dir: str = '',
25
  ) -> None:
26
  image_str = ''
27
  if test_prompt and test_image_dir:
28
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
  if image_paths:
30
+ image_path = image_paths[-1]
31
+ rel_path = image_path.relative_to(save_dir)
32
+ image_str = f'Test prompt: {test_prompt}\n' + f'![{image_path.stem}]({rel_path})\n'
 
33
 
34
  model_card = f'''---
35
  license: creativeml-openrail-m
36
  base_model: {base_model}
37
+ training_prompt: {training_prompt}
38
  tags:
39
  - stable-diffusion
40
  - stable-diffusion-diffusers
41
  - text-to-image
42
  - diffusers
43
+ - text-to-video
44
+ - tune-a-video
45
+ inference: false
46
  ---
 
47
 
48
+ # Tune-A-Video - {save_dir.name}
49
+
50
+ Base model: [{base_model}](https://huggingface.co/{base_model}).
51
+
52
+ Training prompt: {training_prompt}
53
 
54
  {image_str}
55
  '''
wheel/xformers-0.0.16+bc08bbc.d20230130-cp310-cp310-linux_x86_64.whl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:016219e017ce06b351ef0f98fc074ee60be06ee1d700cfe0a45c9b59e25bb938
3
+ size 134437916