"""Gradio demo: cross-modal lookup over text, images and audio with ImageBind.

A few (text, image, audio) triples are stored in a LanceDB table whose
``vector`` column is computed from each row's ``image_uri`` by the registry's
"imagebind" embedding function. Each tab of the UI takes one modality as a
query, finds the closest stored row, and shows that row's other two
modalities.
"""

from typing import Tuple

import lancedb
# NOTE(review): presumably imported so the "imagebind" embedding used below is
# available in the registry — confirm before removing.
import lancedb.embeddings.imagebind  # noqa: F401
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
import gradio as gr

from downloader import dowload_and_save_audio, dowload_and_save_image, base_path

# Embedding function referenced by the table schema below.
model = get_registry().get("imagebind").create()


class TextModel(LanceModel):
    """Row schema of the ``img_bind`` table: one (text, image, audio) triple."""

    text: str
    # Source column of the embedding: rows are added without a ``vector`` and
    # the embedding function fills it in from this value.
    image_uri: str = model.SourceField()
    audio_path: str
    # Embedding of ``image_uri``; its dimensionality comes from the model.
    vector: Vector(model.ndims()) = model.VectorField()


text_list = ["A bird", "A dragon", "A car"]
image_paths = dowload_and_save_image()
audio_paths = dowload_and_save_audio()

# Load data. zip() pairs the three lists positionally and stops at the
# shortest one.
inputs = [
    {"text": text, "audio_path": audio, "image_uri": image}
    for text, audio, image in zip(text_list, audio_paths, image_paths)
]

db = lancedb.connect("/tmp/lancedb")
# The table is rebuilt from scratch on every run. mode="overwrite" replaces a
# table left over from a previous launch. With the default mode, create_table
# raises because "img_bind" already exists.
table = db.create_table("img_bind", schema=TextModel, mode="overwrite")
table.add(inputs)


def _nearest(query: str) -> TextModel:
    """Return the stored row whose vector is closest to ``query``.

    ``query`` is a text prompt or a filepath to an image/audio file. It is
    embedded by the table's embedding function and matched against the
    ``vector`` column. NOTE(review): the embedding function presumably tells
    text from image/audio paths by itself — confirm against the LanceDB
    ImageBind integration.

    Raises:
        gr.Error: if ``query`` is empty; Gradio shows the message in the UI.
        IndexError: if the search returns no rows.
    """
    if not query:
        # Gradio passes None for a cleared Image/Audio input, and the textbox
        # can submit "". Neither is a meaningful query, so tell the user
        # instead of returning an arbitrary match.
        raise gr.Error("Please provide an input.")
    return (
        table.search(query, vector_column_name="vector")
        .limit(1)
        .to_pydantic(TextModel)[0]
    )


def process_image(inp_img: str) -> Tuple[str, str]:
    """Return (text, audio path) of the row closest to the image at ``inp_img``."""
    match = _nearest(inp_img)
    return match.text, match.audio_path


def process_text(inp_text: str) -> Tuple[str, str]:
    """Return (image uri, audio path) of the row closest to ``inp_text``."""
    match = _nearest(inp_text)
    return match.image_uri, match.audio_path


def process_audio(inp_audio: str) -> Tuple[str, str]:
    """Return (image uri, text) of the row closest to the audio at ``inp_audio``."""
    match = _nearest(inp_audio)
    return match.image_uri, match.text


im_to_at = gr.Interface(
    process_image,
    gr.Image(type="filepath", value=image_paths[0]),
    [gr.Text(label="Output Text"), gr.Audio(label="Output Audio")],
    examples=image_paths,
    allow_flagging="never",
)
txt_to_ia = gr.Interface(
    process_text,
    gr.Textbox(label="Enter a prompt:"),
    [gr.Image(label="Output Image"), gr.Audio(label="Output Audio")],
    allow_flagging="never",
    examples=text_list,
)
a_to_it = gr.Interface(
    process_audio,
    gr.Audio(type="filepath", value=audio_paths[0]),
    [gr.Image(label="Output Image"), gr.Text(label="Output Text")],
    examples=audio_paths,
    allow_flagging="never",
)
demo = gr.TabbedInterface(
    [im_to_at, txt_to_ia, a_to_it],
    ["Image to Text/Audio", "Text to Image/Audio", "Audio to Image/Text"],
)

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the downloaded sample files.
    demo.launch(share=True, allowed_paths=[f"{base_path}/test_inputs/"])