zanemotiwala committed
Commit 78f4578
1 Parent(s): 9278785

Update app.py

Files changed (1): app.py (+24, -19)

app.py CHANGED
```diff
@@ -1,6 +1,5 @@
-import torch
 import gradio as gr
-from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor, pipeline
+from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor
 from diffusers import StableDiffusionPipeline
 
 # Initialize device and models for captioning
@@ -16,32 +15,38 @@ caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
 diffusion_model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 diffusion_model = diffusion_model.to(device)
 
-def predict(image):
+def get_caption(image):
     # Generate a caption from the image
     image = image.convert('RGB')
     image_tensor = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
     caption_ids = caption_model.generate(image_tensor, max_length=128, num_beams=3)[0]
     caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
     return caption_text
-
-def generate(image):
-    caption=predict(image)
+
+def generate_image(caption):
     # Generate an image from the caption
     generated_image = diffusion_model(caption)["sample"][0]
+    return generated_image
 
-    return caption, generated_images[0]
-
 # Set up Gradio interface
-input = gr.Image(label="Upload any Image", type='pil')
-outputs = [gr.Textbox(label="Caption"), gr.Image(label="Generated Image")]
-#examples = ['example1.jpeg']
+image_input = gr.Image(label="Upload any Image", type='pil')
+caption_output = gr.Textbox(label="Caption")
+generated_image_output = gr.Image(label="Generated Image")
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            image = image_input()
+            get_caption_btn = gr.Button("Get Caption")
+        with gr.Column():
+            caption = caption_output()
+            generate_image_btn = gr.Button("Generate Image")
+    with gr.Row():
+        generated_image = generated_image_output()
+
+    caption.update(get_caption_btn.click(get_caption, inputs=image, outputs=caption))
+    generated_image.update(generate_image_btn.click(generate_image, inputs=caption, outputs=generated_image))
 
 title = "Image Captioning and Generation"
-interface = gr.Interface(
-    fn=generate,
-    inputs=input,
-    outputs=outputs,
-    #examples=examples,
-    title=title,
-)
-interface.launch()
+demo.launch(title=title)
+
```
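Review note: the new Blocks wiring will not run as committed. The components are created at module level and then called like functions (`image = image_input()`), but Gradio components are not callable; they are normally instantiated directly inside the `with gr.Blocks()` context. The `caption.update(...)` and `generated_image.update(...)` wrappers are also unnecessary, since `.click(fn, inputs=..., outputs=...)` already routes each handler's return value to its output component. `demo.launch(title=title)` passes `title` to `launch()`, which has no such parameter (the page title belongs on `gr.Blocks(title=...)`), and `diffusion_model(caption)["sample"]` is the output format of very early diffusers releases; current versions return generations under `.images`. Dropping `import torch` is also risky if the unchanged device-initialization block still calls `torch.cuda.is_available()`. A corrected sketch follows; the captioning checkpoint and the `device` line are elided in the unchanged part of the diff, so `nlpconnect/vit-gpt2-image-captioning` and the device selection below are assumed stand-ins consistent with the classes the file imports.

```python
import torch
import gradio as gr
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor
from diffusers import StableDiffusionPipeline

# Assumed: device/model setup lives in the unchanged lines of app.py.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed checkpoint: the real one is elided by the diff; this public one
# matches the AutoTokenizer/ViTImageProcessor/VisionEncoderDecoderModel trio.
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
feature_extractor = ViTImageProcessor.from_pretrained(model_checkpoint)
caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)

diffusion_model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
diffusion_model = diffusion_model.to(device)

def get_caption(image):
    # Caption the uploaded image with the ViT encoder / text decoder.
    image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    caption_ids = caption_model.generate(pixel_values, max_length=128, num_beams=3)[0]
    return tokenizer.decode(caption_ids, skip_special_tokens=True)

def generate_image(caption):
    # Current diffusers returns generations under .images; the ["sample"]
    # key only existed in very early releases.
    return diffusion_model(caption).images[0]

# Components are created inside the Blocks context; the page title is a
# Blocks argument, not a launch() argument.
with gr.Blocks(title="Image Captioning and Generation") as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload any Image", type="pil")
            get_caption_btn = gr.Button("Get Caption")
        with gr.Column():
            caption = gr.Textbox(label="Caption")
            generate_image_btn = gr.Button("Generate Image")
    with gr.Row():
        generated_image = gr.Image(label="Generated Image")

    # .click() is the whole event wiring: the handler's return value is
    # written to the output component, so no .update() wrapper is needed.
    get_caption_btn.click(get_caption, inputs=image, outputs=caption)
    generate_image_btn.click(generate_image, inputs=caption, outputs=generated_image)

demo.launch()
```

With this wiring, each button's `.click()` registers the event itself, so the two handlers stay decoupled: captioning and generation can be triggered independently, which is presumably the point of splitting the old `generate()` into `get_caption` and `generate_image`.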