callum-canavan committed on
Commit
6d057c8
·
1 Parent(s): b6fc2cb

Change to diffusion

Browse files
Files changed (2) hide show
  1. app.py +57 -13
  2. requirements.txt +6 -1
app.py CHANGED
@@ -1,24 +1,68 @@
 
 
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
def predict(input_img):
    """Run the hot-dog classifier on *input_img*.

    Returns the unmodified input image together with a mapping of
    predicted label -> confidence score, as expected by the Gradio
    Image + Label output pair.
    """
    label_scores = {}
    for prediction in pipeline(input_img):
        label_scores[prediction["label"]] = prediction["score"]
    return input_img, label_scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
# Gradio UI: upload/webcam image in, processed image + top-2 label scores out.
gradio_app = gr.Interface(
    predict,
    title="Hot Dog? Or Not?",
    inputs=gr.Image(label="Select hot dog candidate", sources=["upload", "webcam"], type="pil"),
    outputs=[
        gr.Image(label="Processed Image"),
        gr.Label(label="Result", num_top_classes=2),
    ],
)
23
 
24
  if __name__ == "__main__":
 
1
+ from diffusers import DiffusionPipeline
2
+ from diffusers.utils import pt_to_pil
3
  import gradio as gr
4
+ import torch
5
 
6
+
7
# Stage 1: DeepFloyd IF base text-to-image model (fp16 weights).
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0", variant="fp16", torch_dtype=torch.float16
)
stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()  # offload weights to CPU between forward passes to fit GPU memory

# Stage 2: IF super-resolution model; text_encoder=None because prompt
# embeddings are computed once by stage_1 and passed in directly.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()

# stage 3
# Stable Diffusion x4 upscaler; reuse stage_1's safety components so the
# final output is still safety-checked and watermarked.
safety_modules = {
    "feature_extractor": stage_1.feature_extractor,
    "safety_checker": stage_1.safety_checker,
    "watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    **safety_modules,
    torch_dtype=torch.float16
)
stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
34
 
35
 
36
def predict(input_img):
    """Generate an image from a text prompt via the 3-stage DeepFloyd IF pipeline.

    Args:
        input_img: the user-supplied text prompt (the Gradio ``"text"`` input;
            the parameter name is a leftover from the image-classifier version).

    Returns:
        The final upscaled PIL image.
    """
    # Bug fix: the prompt used to be a hard-coded kangaroo/Eiffel-tower string,
    # so the user's text input was silently ignored. Use the actual input.
    prompt = input_img

    prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
    # Fixed seed: deterministic output for a given prompt.
    generator = torch.manual_seed(0)
    # Stage 1: base text-to-image; keep tensors ("pt") for the next stage.
    image = stage_1(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        generator=generator,
        output_type="pt",
    ).images
    # Stage 2: super-resolution conditioned on the same prompt embeddings.
    image = stage_2(
        image=image,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        generator=generator,
        output_type="pt",
    ).images
    # Stage 3: x4 upscale to the final image.
    image = stage_3(
        prompt=prompt, image=image, generator=generator, noise_level=100
    ).images
    # Bug fix: ``.images`` is a list, but Gradio's "image" output expects a
    # single image — return the first (and only) element.
    return image[0]
58
 
59
 
60
# Simple text-in / image-out UI around predict().
gradio_app = gr.Interface(
    fn=predict,
    title="Text to Image Generator",
    description="Enter a text string to generate an image.",
    inputs="text",
    outputs="image",
)
67
 
68
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,3 +1,8 @@
 
 
1
  gradio
 
 
2
  transformers
3
- torch
 
 
1
+ accelerate
2
+ diffusers
3
  gradio
4
+ safetensors
5
+ sentencepiece
6
  transformers
7
+ torch
8
+ xformers