Commit
•
17b4c6b
1
Parent(s):
c04f31b
Update app.py
Browse files
app.py
CHANGED
@@ -11,8 +11,9 @@ feature_extractor = ViTImageProcessor.from_pretrained(encoder_checkpoint)
|
|
11 |
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
|
12 |
caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
|
13 |
|
14 |
-
#
|
15 |
-
|
|
|
16 |
|
17 |
def predict(image):
|
18 |
# Generate a caption from the image
|
@@ -22,7 +23,7 @@ def predict(image):
|
|
22 |
caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
|
23 |
|
24 |
# Generate an image from the caption
|
25 |
-
|
26 |
|
27 |
return caption_text, generated_images[0]
|
28 |
|
|
|
11 |
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
|
12 |
caption_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
|
13 |
|
14 |
+
# Load the Stable Diffusion model
|
15 |
+
diffusion_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
16 |
+
diffusion_model = diffusion_model.to(device)
|
17 |
|
18 |
def predict(image):
|
19 |
# Generate a caption from the image
|
|
|
23 |
caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
|
24 |
|
25 |
# Generate an image from the caption
|
26 |
+
generated_image = diffusion_model(caption_text)["sample"][0]
|
27 |
|
28 |
return caption_text, generated_images[0]
|
29 |
|