zanemotiwala commited on
Commit
0858412
1 Parent(s): 900454a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -22,12 +22,15 @@ def predict(image):
22
  image_tensor = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
23
  caption_ids = caption_model.generate(image_tensor, max_length=128, num_beams=3)[0]
24
  caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
25
-
 
 
 
26
  # Generate an image from the caption
27
- generated_image = diffusion_model(caption_text)["sample"][0]
28
-
29
- return caption_text, generated_images[0]
30
 
 
 
31
  # Set up Gradio interface
32
  input = gr.Image(label="Upload any Image", type='pil')
33
  outputs = [gr.Textbox(label="Caption"), gr.Image(label="Generated Image")]
@@ -41,4 +44,4 @@ interface = gr.Interface(
41
  #examples=examples,
42
  title=title,
43
  )
44
- interface.launch(debug=True)
 
22
  image_tensor = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
23
  caption_ids = caption_model.generate(image_tensor, max_length=128, num_beams=3)[0]
24
  caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
25
+ return caption_text
26
+
27
+ def generate(image):
28
+ caption=predict(image)
29
  # Generate an image from the caption
30
+ generated_image = diffusion_model(caption)["sample"][0]
 
 
31
 
32
+ return caption, generated_images[0]
33
+
34
  # Set up Gradio interface
35
  input = gr.Image(label="Upload any Image", type='pil')
36
  outputs = [gr.Textbox(label="Caption"), gr.Image(label="Generated Image")]
 
44
  #examples=examples,
45
  title=title,
46
  )
47
+ interface.launch(share=True)