CUDA OOM when using with_prior_preservation with v2

#11
by acul3 - opened

I get a CUDA out-of-memory (OOM) error when using with_prior_preservation for person training with the v2_512 model.
The same setup works fine with the 1.5 model.

Possibly related issue: https://github.com/huggingface/diffusers/issues/696

Traceback:

  File "/home/ubuntu/env/lib/python3.9/site-packages/gradio/routes.py", line 292, in run_predict
    output = await app.blocks.process_api(
  File "/home/ubuntu/env/lib/python3.9/site-packages/gradio/blocks.py", line 1007, in process_api
    result = await self.call_function(fn_index, inputs, iterator, request)
  File "/home/ubuntu/env/lib/python3.9/site-packages/gradio/blocks.py", line 848, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/ubuntu/env/lib/python3.9/site-packages/anyio/to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/ubuntu/env/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "/home/ubuntu/env/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 867, in run
    result = context.run(func, *args)
  File "/home/ubuntu/dreambooth-training/app.py", line 253, in train
    run_training(args_general)
  File "/home/ubuntu/dreambooth-training/train_dreambooth.py", line 737, in run_training
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/accelerate/utils/operations.py", line 507, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py", line 367, in forward
    sample = upsample_block(
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py", line 1249, in forward
    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/ubuntu/env/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py", line 1245, in custom_forward
    return module(*inputs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/diffusers/models/resnet.py", line 450, in forward
    hidden_states = self.norm1(hidden_states)
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 273, in forward
    return F.group_norm(
  File "/home/ubuntu/env/lib/python3.9/site-packages/torch/nn/functional.py", line 2528, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
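The traceback fails inside the UNet forward pass (model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample), which matches the diffusers train_dreambooth.py example. Assuming this Space's train_dreambooth.py follows that example, prior preservation appends the class (prior) images to the instance images in the same batch, so the UNet processes twice as many samples per step, and the prediction is split back into two halves for the loss. The sketch below illustrates that pattern with plain Python lists standing in for tensors; the function names are hypothetical, and prior_loss_weight=1.0 is an assumed default. It shows why enabling the option raises peak memory, but it does not by itself explain why v2 runs out where 1.5 does not.

```python
# Hypothetical sketch of the prior-preservation batching pattern, assuming the
# script mirrors diffusers' train_dreambooth.py. Lists stand in for tensors.

def build_unet_batch(instance_examples, class_examples, with_prior_preservation):
    """Return the examples that go through the UNet in a single forward pass."""
    batch = list(instance_examples)
    if with_prior_preservation:
        # Class images are concatenated into the same batch (one forward pass
        # instead of two), so the UNet sees 2x the instance batch size.
        batch += list(class_examples)
    return batch


def mse(pred, target):
    """Mean squared error over all elements."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(model_pred, target, with_prior_preservation, prior_loss_weight=1.0):
    """Combine instance and prior losses (prior_loss_weight default is an assumption)."""
    if not with_prior_preservation:
        return mse(model_pred, target)
    # Split the joint prediction back into its instance half and prior half.
    half = len(model_pred) // 2
    instance_loss = mse(model_pred[:half], target[:half])
    prior_loss = mse(model_pred[half:], target[half:])
    return instance_loss + prior_loss_weight * prior_loss


if __name__ == "__main__":
    batch = build_unet_batch(["person_0"], ["class_0"], with_prior_preservation=True)
    print(len(batch))  # the UNet batch is twice the instance batch
    print(dreambooth_loss([1.0, 2.0], [0.0, 0.0], with_prior_preservation=True))
```

With a per-step instance batch of 1, the UNet batch becomes 2 once prior preservation is on, so activation memory in the forward and backward passes grows accordingly even with gradient checkpointing enabled (as the checkpoint frames in the traceback show).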

acul3 changed discussion status to closed