Model does not work with device set to `mps` #2

by akbir - opened

Inference on the M1 GPU (`device="mps"`) does not work.

The tensor shapes end up wrong here; I'm not sure whether this is a PyTorch bug or a problem with the model implementation.

Can you share a code snippet that reproduces this?

