Has anyone run this using an M1?

Discussion #17, opened by sgt101

I tried changing the device to MPS, but no dice.

`torch.backends.mps.is_available()` returns `True`, but autocast is defeating me.

Any ideas?

This script works on Apple M1:

```python
import requests
from io import BytesIO

import torch
from PIL import Image
from torch import autocast

# image_to_image is a local module (not installed with diffusers) that is
# expected to provide StableDiffusionImg2ImgPipeline and preprocess()
from image_to_image import *

# There is no MPS check here, so on an M1 this falls back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'running on {device}')

pipei2i = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    #revision="fp16",
    #torch_dtype=torch.float16,
    use_auth_token=True
).to(device)

# Download an example image and prepare it for img2img
response = requests.get('https://pbs.twimg.com/media/Fa1_7_vWYAEwfX-.png')
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image = preprocess(init_image)

prompt = "a cat, artstation"
samples = 2
steps = 50
strength = 0.75
scale = 7.5
outputs = []
if device == 'cuda':
    # autocast is only used on CUDA
    with autocast("cuda"):
        outputs = pipei2i(prompt=[prompt]*samples,
            init_image=init_image,
            strength=strength,
            num_inference_steps=steps,
            guidance_scale=scale)
else:
    outputs = pipei2i(prompt=[prompt]*samples,
        init_image=init_image,
        strength=strength,
        num_inference_steps=steps,
        guidance_scale=scale)

# Split the results using the safety checker's per-image flags
safe_images = []
unsafe_images = []
# {'sample': [<PIL.Image.Image image mode=RGB size=512x512 at 0x7FEE48615510>], 'nsfw_content_detected': [False]}
for i, image in enumerate(outputs["sample"]):
    if outputs["nsfw_content_detected"][i]:
        unsafe_images.append(image)
    else:
        safe_images.append(image)

for index, image in enumerate(safe_images):
    image.save(f"safe_{index}.png")
for index, image in enumerate(unsafe_images):
    image.save(f"unsafe_{index}.png")
```
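The `if`/`else` in the script repeats the pipeline call only to switch autocast on for CUDA. A minimal sketch of one way to avoid the duplication, assuming autocast is only wanted on CUDA: choose a context manager per device and use `contextlib.nullcontext()` for `'cpu'` and `'mps'`. The helper name `inference_context` is hypothetical, not a diffusers or torch API.

```python
import contextlib


def inference_context(device: str):
    """Return CUDA autocast on 'cuda' and a no-op context manager otherwise.

    `inference_context` is a hypothetical helper, not part of diffusers.
    """
    if device == "cuda":
        # Imported lazily so CPU/MPS callers never touch the autocast path
        from torch import autocast
        return autocast("cuda")
    # CPU and MPS: run in the pipeline's default precision without autocast
    return contextlib.nullcontext()
```

With it, both branches collapse into a single `with inference_context(device):` block around one `pipei2i(...)` call.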

The requirements are:

scipy
torch
transformers
diffusers
ftfy
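The list does not pin versions, and later replies in this thread hinge on which diffusers release is installed. A minimal sketch, using only the standard library (`importlib.metadata`, Python 3.8+), that prints the installed version of each listed package so it can be pasted into a reply; the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional

# Package names from the requirements list
REQUIREMENTS: List[str] = ["scipy", "torch", "transformers", "diffusers", "ftfy"]


def installed_versions(packages: List[str] = REQUIREMENTS) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is missing."""
    report: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```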

This is what I can see in Activity Monitor while it runs:

[Screenshot: Activity Monitor, 2022-08-26 at 20.57.31]

See here for more details and dependencies.

@patrickvonplaten I now get:

```
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```

After setting the fallback with `export PYTORCH_ENABLE_MPS_FALLBACK=1`, the fallback kicks in, as I can clearly see from the logs:

```
UserWarning: The operator 'aten::index.Tensor' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at  /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
```
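`PYTORCH_ENABLE_MPS_FALLBACK` is an environment variable, so exporting it in the shell (as above) is the simplest route. A minimal sketch for setting it from inside the script instead, assuming the variable is only honoured if it is set before `torch` is first imported (hence the `sys.modules` check). The helper name `enable_mps_fallback` is hypothetical.

```python
import os
import sys
import warnings


def enable_mps_fallback() -> bool:
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1 for the current process.

    Returns False (and warns) if torch was already imported, since the
    variable is assumed to be read when torch loads.
    """
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    if "torch" in sys.modules:
        warnings.warn(
            "torch was imported before PYTORCH_ENABLE_MPS_FALLBACK was set; "
            "the fallback may not be active in this process"
        )
        return False
    return True


enable_mps_fallback()
# Import torch and diffusers (which imports torch) only after this point.
```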

In the code, torch recognizes MPS, and I'm loading the pipeline in what I assume is fp32:

```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'mps' if torch.backends.mps.is_available() else device
print(f'running on {device}')

...

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3",
    scheduler=lms,
    use_auth_token=True
).to(device)
```
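The two `device = ...` lines let MPS override CUDA whenever both checks pass. A minimal sketch that makes the preference order explicit and also picks a dtype, assuming half precision is only wanted on CUDA (matching the commented-out `torch_dtype=torch.float16` in the first script) and float32 everywhere else, since MPS rejects float64. The helper `pick_device_and_dtype` is hypothetical; the availability flags are passed in so the logic can be checked without a GPU.

```python
from typing import Tuple


def pick_device_and_dtype(cuda_available: bool, mps_available: bool) -> Tuple[str, str]:
    """Return (device, dtype name), preferring CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda", "float16"
    if mps_available:
        # MPS has no float64; float32 keeps the pipeline in full precision
        return "mps", "float32"
    return "cpu", "float32"
```

In the script this could be wired up as `device, dtype_name = pick_device_and_dtype(torch.cuda.is_available(), torch.backends.mps.is_available())`, then passing `torch_dtype=getattr(torch, dtype_name)` to `from_pretrained`.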

I'm using the latest diffusers release, whereas the possible solutions in the thread you shared rely on forks, hacks, etc.

The full traceback is:

```
Traceback (most recent call last):
  File "diffuser.py", line 65, in <module>
    guidance_scale=scale,
  File "/Projects/bloom/.venv/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Projects/bloom/pipeline_stable_diffusion.py", line 142, in __call__
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
  File "/Users/musixmatch/Documents/Projects/bloom/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Projects/bloom/.venv/lib/python3.7/site-packages/diffusers/models/unet_2d_condition.py", line 134, in forward
    timesteps = timesteps[None].to(sample.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
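The failing line moves the timestep tensor to the sample's device, and the error says that tensor is float64, which MPS cannot hold. One possible workaround, sketched under the assumption that the timestep is a float64 tensor (which is what `torch.from_numpy` produces from a default NumPy float array), is to downcast it to float32 before the `.to(device)` call. `mps_safe` is a hypothetical helper, not a diffusers API; the cast would have to be applied wherever the timestep is created, e.g. in the local `pipeline_stable_diffusion.py` from the traceback.

```python
import torch


def mps_safe(t: torch.Tensor) -> torch.Tensor:
    """Downcast float64 tensors to float32; leave every other dtype unchanged."""
    if t.dtype == torch.float64:
        return t.to(torch.float32)
    return t


# A 0-dim float64 timestep, like the one reaching the UNet in the traceback
t = torch.tensor(999.0, dtype=torch.float64)

# Older PyTorch builds have no MPS backend, so guard the attribute lookup
has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = "mps" if has_mps else "cpu"

# Same shape of operation as the failing line, but on a float32 tensor
t = mps_safe(t)[None].to(device)
```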

Thanks!

Reply from a CompVis org member:

Hi @loretoparisi !

We are working on official support for MPS in Diffusers; hold tight for a couple of days :)
