Spaces:
Running
Running
import os | |
import gradio as gr | |
from haystack.components.generators import HuggingFaceTGIGenerator | |
from haystack.components.builders.prompt_builder import PromptBuilder | |
from haystack import Pipeline | |
from haystack.utils import Secret | |
from image_captioner import ImageCaptioner | |
# Markdown rendered at the top of the app by gr.Markdown.
# The blank lines matter: in Markdown, an unindented line directly after a
# bullet is a "lazy continuation" of that bullet, so without them the three
# closing sentences would be folded into the last list item.
description = """
# Captionate 📸
### Create Instagram captions for your pics!

* Upload your photo or select one from the examples
* Choose your model
* ✨ Captionate! ✨

It uses the [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) model for the image-to-text captioning step.

For Instagrammable captions, `mistralai/Mistral-7B-Instruct-v0.2` performs best, but try different models to see how they react to the same prompt.

Built by [Bilge Yucel](https://twitter.com/bilgeycl) using [Haystack 2.0](https://github.com/deepset-ai/haystack) 💙
"""
# Prompt sent to the text-generation model. PromptBuilder substitutes the
# {{caption}} placeholder with the BLIP description produced by the
# image_to_text component (they are connected inside generate_caption).
prompt_template = (
    "\n"
    "You will receive a descriptive text of a photo.\n"
    "Try to generate a nice Instagram caption with a phrase rhyming with the text. Include emojis in the caption.\n"
    "Descriptive text: {{caption}};\n"
    "Instagram Caption:\n"
)
# Hugging Face access token used to authenticate the TGI text generator.
# It is read once at import time; fail fast with an actionable message
# instead of a bare KeyError if it has not been configured.
try:
    hf_api_key = os.environ["HF_API_KEY"]
except KeyError as err:
    raise RuntimeError(
        "The HF_API_KEY environment variable is not set. "
        "Set it to your Hugging Face access token before starting the app."
    ) from err
def _build_captioning_pipeline(model_name, max_new_tokens):
    """Assemble the image -> description -> prompt -> LLM pipeline.

    Components, wired in order:
      * ``image_to_text``: BLIP captioner producing a plain description.
      * ``prompt_builder``: fills ``{{caption}}`` in ``prompt_template``.
      * ``generator``: Hugging Face TGI model ``model_name``.
    """
    image_to_text = ImageCaptioner(
        model_name="Salesforce/blip-image-captioning-base",
    )
    prompt_builder = PromptBuilder(template=prompt_template)
    generator = HuggingFaceTGIGenerator(
        model=model_name,
        token=Secret.from_token(hf_api_key),
        generation_kwargs={"max_new_tokens": max_new_tokens},
    )

    captioning_pipeline = Pipeline()
    captioning_pipeline.add_component("image_to_text", image_to_text)
    captioning_pipeline.add_component("prompt_builder", prompt_builder)
    captioning_pipeline.add_component("generator", generator)
    # The BLIP description feeds the {{caption}} variable of the template.
    captioning_pipeline.connect("image_to_text.caption", "prompt_builder.caption")
    captioning_pipeline.connect("prompt_builder", "generator")
    return captioning_pipeline


def generate_caption(image_file_path, model_name, max_new_tokens=50):
    """Generate an Instagram-style caption for an uploaded image.

    Args:
        image_file_path: Path of the image from the ``gr.Image(type="filepath")``
            input. It may be empty/None if the user clicks the button without
            selecting an image.
        model_name: Hugging Face model id used by the TGI text generator.
        max_new_tokens: Upper bound on tokens generated for the caption
            (defaults to the previously hard-coded 50).

    Returns:
        The first reply produced by the generator.

    Raises:
        gr.Error: If no image was provided or the model returned no reply.
            Gradio displays this message to the user instead of a generic error.
    """
    if not image_file_path:
        raise gr.Error("Please upload an image or pick one of the examples first.")

    captioning_pipeline = _build_captioning_pipeline(model_name, max_new_tokens)
    result = captioning_pipeline.run(
        {"image_to_text": {"image_file_path": image_file_path}}
    )

    # Guard against an empty generation instead of failing with IndexError.
    replies = result.get("generator", {}).get("replies") or []
    if not replies:
        raise gr.Error(f"{model_name} did not return a caption. Try another model.")
    return replies[0]
# Text-generation models offered in the dropdown; the first one is preselected.
MODEL_CHOICES = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-7b",
    "HuggingFaceH4/starchat-beta",
    "bigscience/bloom",
    "google/flan-t5-xxl",
]
# Sample images users can click to load into the image input.
EXAMPLE_IMAGES = ["./whale.png", "./rainbow.jpeg", "./selfie.png"]

with gr.Blocks(theme="soft") as demo:
    gr.Markdown(value=description)
    with gr.Row():
        image = gr.Image(type="filepath")
        with gr.Column():
            model_name = gr.Dropdown(
                MODEL_CHOICES,
                value=MODEL_CHOICES[0],
                label="Choose your model!",
            )
            gr.Examples(EXAMPLE_IMAGES, inputs=image, label="Click on any example")
            submit_btn = gr.Button("✨ Captionate ✨")
            caption = gr.Textbox(label="Caption", show_copy_button=True)
    # Button click: generate_caption(image path, model id) -> caption textbox.
    submit_btn.click(
        fn=generate_caption,
        inputs=[image, model_name],
        outputs=[caption],
    )

if __name__ == "__main__":
    demo.launch()