ysharma HF staff commited on
Commit
8bf5a04
β€’
1 Parent(s): cf81290

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import torch
2
  import requests
 
 
 
 
3
  from PIL import Image
4
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
5
- import rembg
6
 
7
  # Load the pipeline
8
  pipeline = DiffusionPipeline.from_pretrained(
@@ -23,7 +26,9 @@ pipeline.to('cuda:0')
23
  def inference(input_img, num_inference_steps, guidance_scale, seed ):
24
  # Download an example image.
25
  cond = Image.open(input_img)
26
-
 
 
27
  # Run the pipeline!
28
  #result = pipeline(cond, num_inference_steps=75).images[0]
29
  result = pipeline(cond, num_inference_steps=num_inference_steps,
@@ -40,13 +45,17 @@ def inference(input_img, num_inference_steps, guidance_scale, seed ):
40
  return result
41
 
42
  def remove_background(result):
43
- result = Image.open(result)
44
- result = rembg.remove(result)
 
 
 
 
 
45
  return result
 
46
 
47
 
48
- import gradio as gr
49
-
50
  with gr.Blocks() as demo:
51
  gr.Markdown("<h1><center> Zero123++ Demo</center></h1>")
52
  with gr.Column():
 
1
  import torch
2
  import requests
3
+ import rembg
4
+ import random
5
+ import gradio as gr
6
+
7
  from PIL import Image
8
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
 
9
 
10
  # Load the pipeline
11
  pipeline = DiffusionPipeline.from_pretrained(
 
26
  def inference(input_img, num_inference_steps, guidance_scale, seed ):
27
  # Download an example image.
28
  cond = Image.open(input_img)
29
+ if seed==0:
30
+ seed = random.randint(1, 1000000)
31
+
32
  # Run the pipeline!
33
  #result = pipeline(cond, num_inference_steps=75).images[0]
34
  result = pipeline(cond, num_inference_steps=num_inference_steps,
 
45
  return result
46
 
47
  def remove_background(result):
48
+ # Check if the variable is a PIL Image
49
+ if isinstance(result, Image.Image):
50
+ result = rembg.remove(result)
51
+ # Check if the variable is a str filepath
52
+ elif isinstance(result, str):
53
+ result = Image.open(result)
54
+ result = rembg.remove(result)
55
  return result
56
+
57
 
58
 
 
 
59
  with gr.Blocks() as demo:
60
  gr.Markdown("<h1><center> Zero123++ Demo</center></h1>")
61
  with gr.Column():