zanemotiwala committed on
Commit
17b4c6b
1 Parent(s): c04f31b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -11,8 +11,9 @@ feature_extractor = ViTImageProcessor.from_pretrained(encoder_checkpoint)
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
  caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
 
14
- # Initialize the image generation model (e.g., Stable Diffusion)
15
- image_gen_model = pipeline("text-to-image", model="CompVis/stable-diffusion-v1-4")
 
16
 
17
  def predict(image):
18
  # Generate a caption from the image
@@ -22,7 +23,7 @@ def predict(image):
22
  caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
23
 
24
  # Generate an image from the caption
25
- generated_images = image_gen_model(caption_text, num_images=1)
26
 
27
  return caption_text, generated_images[0]
28
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
  caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
 
14
+ # Load the Stable Diffusion model
15
+ diffusion_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
16
+ diffusion_model = diffusion_model.to(device)
17
 
18
  def predict(image):
19
  # Generate a caption from the image
 
23
  caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
24
 
25
  # Generate an image from the caption
26
+ generated_image = diffusion_model(caption_text)["sample"][0]
27
 
28
  return caption_text, generated_images[0]
29