V2V example encounters OOM on an 80GB A100.

#25
opened by guyuchao

I tried to run the provided example code and hit an OOM error at the first spatial attention. I printed the shape and dtype at the OOM position: query: torch.Size([240, 9216, 64]) torch.float16, key: torch.Size([240, 9216, 64]) torch.float16. Have you encountered this problem?
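(A rough back-of-the-envelope, assuming naive attention materializes the full query-key score matrix for the reported shapes:)

```python
# Memory for the raw attention scores Q @ K^T with the reported shapes,
# assuming naive (non-memory-efficient) attention and fp16 scores.
batch_heads, seq_len = 240, 9216
bytes_per_elem = 2  # float16

scores_bytes = batch_heads * seq_len * seq_len * bytes_per_elem
print(f"{scores_bytes / 2**30:.1f} GiB")  # -> 38.0 GiB for a single spatial attention
```

The softmax output is the same size again, so one spatial attention layer alone can exhaust an 80GB card, which is why memory-efficient attention is needed at this resolution.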

Does it need to be run with xformers or some other technique?

Same problem here; apparently it needs xformers, but I have no idea how to add it to the Python script.

Just add pipe.enable_xformers_memory_efficient_attention() and GPU memory usage will drop significantly.
For me, it takes ~20.1GB to render 1024x576 at 32 frames.
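For anyone landing here later, a minimal sketch of where that call goes in a diffusers video-to-video script; the model ID, prompt, and dummy input frames below are placeholders, not the exact example code:

```python
import numpy as np
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Placeholder model ID; substitute the checkpoint from the V2V example.
pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16
).to("cuda")

# The one-liner from this thread: avoids materializing the full attention matrix.
pipe.enable_xformers_memory_efficient_attention()
# Optional extra savings: decode the VAE one frame at a time.
pipe.enable_vae_slicing()

# Dummy 32-frame 1024x576 clip just to keep the sketch self-contained;
# replace with the real frames you want to transform.
video = [Image.fromarray(np.zeros((576, 1024, 3), dtype=np.uint8)) for _ in range(32)]

# Depending on the diffusers version, the output may need .frames[0] instead.
video_frames = pipe("a placeholder prompt", video=video, strength=0.6).frames
export_to_video(video_frames, "output.mp4")
```

Note that xformers has to be installed for the call to work (pip install xformers); recent PyTorch/diffusers builds can get similar savings from the built-in scaled-dot-product attention, but the one-liner above is the fix this thread settled on.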
