zanemotiwala committed on
Commit c04f31b
Parent: be2767a

Create app.py

Files changed (1)
app.py  +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor
+ from diffusers import StableDiffusionPipeline
+
+ # Initialize the device and the captioning model
+ device = 'cpu'
+ model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+ feature_extractor = ViTImageProcessor.from_pretrained(model_checkpoint)
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+ caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
+
+ # Initialize the image generation model (Stable Diffusion via diffusers;
+ # transformers' pipeline() has no "text-to-image" task)
+ image_gen_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
+
+ def predict(image):
+     # Generate a caption from the uploaded image
+     image = image.convert('RGB')
+     image_tensor = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
+     caption_ids = caption_model.generate(image_tensor, max_length=128, num_beams=3)[0]
+     caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
+
+     # Generate an image from the caption
+     generated_images = image_gen_model(caption_text, num_images_per_prompt=1).images
+
+     return caption_text, generated_images[0]
+
+ # Set up the Gradio interface
+ image_input = gr.Image(label="Upload any Image", type='pil')
+ outputs = [gr.Textbox(label="Caption"), gr.Image(label="Generated Image")]
+ examples = [f"example{i}.jpeg" for i in range(1, 3)]
+
+ title = "Image Captioning and Generation"
+ interface = gr.Interface(
+     fn=predict,
+     inputs=image_input,
+     outputs=outputs,
+     examples=examples,
+     title=title,
+ )
+ interface.launch(debug=True)