cocktailpeanut commited on
Commit
4104e3d
1 Parent(s): 32b7aba
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -22,15 +22,18 @@ else:
22
  MAX_SEED = np.iinfo(np.int32).max
23
  CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
24
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
25
- #TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
26
 
27
 
28
  pipe_id = "SPRIGHT-T2I/spright-t2i-sd2"
29
  pipe = DiffusionPipeline.from_pretrained(
30
  pipe_id,
31
- torch_dtype=torch.float16,
32
  use_safetensors=True,
33
- # token=TOKEN,
34
  ).to(device)
35
 
36
 
@@ -114,14 +117,14 @@ with gr.Blocks(css="style.css") as demo:
114
  minimum=256,
115
  maximum=MAX_IMAGE_SIZE,
116
  step=32,
117
- value=1024,
118
  )
119
  height = gr.Slider(
120
  label="Height",
121
  minimum=256,
122
  maximum=MAX_IMAGE_SIZE,
123
  step=32,
124
- value=1024,
125
  )
126
  with gr.Row():
127
  guidance_scale = gr.Slider(
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
24
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
25
+ DEFAULT_IMAGE_SIZE = 1024
26
+ torch_dtype = torch.float16
27
+ if device == "cpu" or device == "mps":
28
+ DEFAULT_IMAGE_SIZE = 512
29
+ torch_dtype = torch.float32
30
 
31
 
32
  pipe_id = "SPRIGHT-T2I/spright-t2i-sd2"
33
  pipe = DiffusionPipeline.from_pretrained(
34
  pipe_id,
35
+ torch_dtype=torch_dtype,
36
  use_safetensors=True,
 
37
  ).to(device)
38
 
39
 
 
117
  minimum=256,
118
  maximum=MAX_IMAGE_SIZE,
119
  step=32,
120
+ value=DEFAULT_IMAGE_SIZE,
121
  )
122
  height = gr.Slider(
123
  label="Height",
124
  minimum=256,
125
  maximum=MAX_IMAGE_SIZE,
126
  step=32,
127
+ value=DEFAULT_IMAGE_SIZE,
128
  )
129
  with gr.Row():
130
  guidance_scale = gr.Slider(