valhalla committed on
Commit
f15f2f0
·
1 Parent(s): 4bccd61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -18
app.py CHANGED
@@ -1,31 +1,21 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
  from PIL import Image
4
 
5
  import re
6
  import os
7
  import requests
8
 
 
 
9
  from share_btn import community_icon_html, loading_icon_html, share_js
10
 
11
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
12
- word_list = word_list_dataset["train"]['text']
 
 
13
 
14
- is_gpu_busy = False
15
- def infer(prompt, negative, scale):
16
- global is_gpu_busy
17
- for filter in word_list:
18
- if re.search(rf"\b{filter}\b", prompt):
19
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
20
-
21
- images = []
22
- url = os.getenv('JAX_BACKEND_URL')
23
- payload = {'prompt': prompt, 'negative_prompt': negative, 'guidance_scale': scale}
24
- images_request = requests.post(url, json = payload)
25
- for image in images_request.json()["images"]:
26
- image_b64 = (f"data:image/jpeg;base64,{image}")
27
- images.append(image_b64)
28
-
29
  return images
30
 
31
 
 
1
import os
import re

import gradio as gr
import requests
import torch
from PIL import Image
from muse import PipelineMuse

from share_btn import community_icon_html, loading_icon_html, share_js
11
 
12
# Select the compute device: prefer GPU when available, otherwise fall back
# to CPU. (Fixes the original conditional expression, which was missing its
# `else` branch — a SyntaxError in Python.)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the MUSE text-to-image pipeline once at module import and move it to
# the chosen device.
pipe = PipelineMuse.from_pretrained("openMUSE/muse-cc12m-200k").to(device)

if device == "cuda":
    # On GPU, enable xformers memory-efficient attention to reduce VRAM use.
    pipe.transformer.enable_xformers_memory_efficient_attention()
16
 
17
def infer(prompt, negative, scale):
    """Generate images from a text prompt with the MUSE pipeline.

    Parameters
    ----------
    prompt : str
        The text prompt to condition generation on.
    negative : str
        Negative prompt from the UI.
        NOTE(review): currently accepted but never forwarded to `pipe` —
        confirm whether `PipelineMuse.__call__` supports a negative prompt.
    scale : float
        Classifier-free guidance scale.

    Returns
    -------
    The pipeline's image output (4 images per prompt).
    """
    # fp16 only makes sense on GPU; `device` is the module-level selection.
    use_half = device == "cuda"
    images = pipe(
        prompt,
        timesteps=12,
        guidance_scale=scale,
        num_images_per_prompt=4,
        use_fp16=use_half,
    )
    return images
20
 
21