muhammad.i.fidatama commited on
Commit
fef08a3
1 Parent(s): 154c710

add accelerate

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,7 +7,7 @@ from io import BytesIO
7
  import streamlit as st
8
  from diffusers import StableDiffusionPipeline
9
  import torch
10
-
11
  #model_id = "CompVis/stable-diffusion-v1-4"
12
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
13
 
@@ -15,6 +15,7 @@ import torch
15
  #pipe = pipe.to("cpu")
16
 
17
  image_html = ""
 
18
 
19
  # Function to display an example image
20
  def display_example_image(url):
@@ -26,7 +27,8 @@ def display_example_image(url):
26
  def generate_images_using_huggingface_diffusers(text):
27
  # pipe = StableDiffusionPipeline.from_pretrained("sd-dreambooth-library/cat-toy", torch_dtype=torch.float16)
28
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
29
- pipe = pipe.to("cpu")
 
30
  prompt = text
31
  image = pipe(prompt).images[0]
32
  return image
 
7
  import streamlit as st
8
  from diffusers import StableDiffusionPipeline
9
  import torch
10
+ from accelerate import Accelerator
11
  #model_id = "CompVis/stable-diffusion-v1-4"
12
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
13
 
 
15
  #pipe = pipe.to("cpu")
16
 
17
  image_html = ""
18
+ accelerator = Accelerator()
19
 
20
  # Function to display an example image
21
  def display_example_image(url):
 
27
  def generate_images_using_huggingface_diffusers(text):
28
  # pipe = StableDiffusionPipeline.from_pretrained("sd-dreambooth-library/cat-toy", torch_dtype=torch.float16)
29
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
30
+ # pipe = pipe.to("cpu")
31
+ pipe = pipe.to(accelerator.device)
32
  prompt = text
33
  image = pipe(prompt).images[0]
34
  return image
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  diffusers==0.29.2
2
  transformers==4.42.3
3
- torch==2.3.1
 
 
1
  diffusers==0.29.2
2
  transformers==4.42.3
3
+ torch==2.3.1
4
+ accelerate==0.31.0