atharvapawar commited on
Commit
0672323
1 Parent(s): cae188c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -2,17 +2,47 @@ import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
  import torch
4
 
 
5
  pipeline = AutoPipelineForText2Image.from_pretrained("sd-dreambooth-library/herge-style", torch_dtype=torch.float16).to("cuda")
6
 
7
  def generate_image(prompt):
 
8
  pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
9
  image = pipeline(prompt).images[0]
10
- return image.permute(1, 2, 0).cpu().numpy()
11
 
12
- title = "DreamBooth Gradio App"
13
- description = "Generate images based on prompts using the DreamBooth model."
 
 
 
14
 
15
- prompt_textbox = gr.Textbox(lines=3, label="Prompt")
16
- image_output = gr.Image(draw_fn=generate_image, label="Generated Image")
 
 
17
 
18
- gr.Interface(fn=generate_image, inputs=prompt_textbox, outputs=image_output, title=title, description=description).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from diffusers import AutoPipelineForText2Image
3
  import torch
4
 
5
+ # Load Dreambooth model
6
  pipeline = AutoPipelineForText2Image.from_pretrained("sd-dreambooth-library/herge-style", torch_dtype=torch.float16).to("cuda")
7
 
8
  def generate_image(prompt):
9
+ # Generate image based on prompt
10
  pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
11
  image = pipeline(prompt).images[0]
12
+ return image
13
 
14
+ def image_to_base64(image):
15
+ # Convert image to base64
16
+ buffered = BytesIO()
17
+ image.save(buffered, format="JPEG")
18
+ return base64.b64encode(buffered.getvalue()).decode()
19
 
20
+ def base64_to_image(base64_str):
21
+ # Convert base64 to image
22
+ image_data = base64.b64decode(base64_str)
23
+ return Image.open(BytesIO(image_data))
24
 
25
+ def handle_prompt_image(prompt):
26
+ # Generate image based on prompt and convert to base64
27
+ image = generate_image(prompt)
28
+ base64_str = image_to_base64(image)
29
+ return base64_str
30
+
31
+ def main():
32
+ # Interface setup
33
+ image_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
34
+ prompt_output = gr.Textbox(label="Base64 Encoded Image", readonly=True)
35
+
36
+ iface = gr.Interface(
37
+ fn=handle_prompt_image,
38
+ inputs=image_input,
39
+ outputs=prompt_output,
40
+ title="Dreambooth Image Generator",
41
+ description="Enter a prompt to generate an image using the Dreambooth model.",
42
+ theme="compact"
43
+ )
44
+
45
+ iface.launch(share=True)
46
+
47
+ if __name__ == "__main__":
48
+ main()