from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,  # for type hints
    TextIteratorStreamer,
    YolosForObjectDetection,
    YolosImageProcessor,
)

# --- YOLOS --- #
yolos_id = "hustvl/yolos-small-300"
yolos_processor: YolosImageProcessor = YolosImageProcessor.from_pretrained(yolos_id)
yolos_model: YolosForObjectDetection = YolosForObjectDetection.from_pretrained(yolos_id)

# --- Moondream --- #
# Moondream does not support the HuggingFace pipeline system, so we have to load it manually
moondream_id = "vikhyatk/moondream2"
moondream_revision = "2024-04-02"
moondream_tokenizer = AutoTokenizer.from_pretrained(moondream_id, revision=moondream_revision)
moondream_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    moondream_id, trust_remote_code=True, revision=moondream_revision
)
moondream_model.eval()


def answer_question(img, prompt):
    """
    Submits an image and prompt to the Moondream model.

    :param img: image to ask a question about
    :param prompt: question to ask about the image
    :return: yields the output buffer string as it is generated
    """
    image_embeds = moondream_model.encode_image(img)
    streamer = TextIteratorStreamer(moondream_tokenizer, skip_special_tokens=True)
    # Generation runs in a background thread so tokens can be streamed as they arrive
    thread = Thread(
        target=moondream_model.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": moondream_tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()


def detect_objects(img: Image.Image):
    """
    Submits an image to the YOLOS-Small-300 model for object detection.

    :param img: image to run detection on
    :return: list of (cropped image, caption) tuples for the gallery
    """
    inputs = yolos_processor(img, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = yolos_model(**inputs)

    # post_process_object_detection expects target sizes as (height, width)
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    results = yolos_processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes
    )[0]

    box_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {yolos_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
        box_images.append((
            img.crop((box[0], box[1], box[2], box[3])),
            f"{yolos_model.config.id2label[label.item()]} ({round(score.item(), 3)})",
        ))
    box_images.append((img, "original"))
    return box_images


def get_selected_index(evt: gr.SelectData) -> int:
    """
    Listener for the gallery selection event.

    :return: index of the currently selected image
    """
    return evt.index


def to_moondream(images: list[tuple[Image.Image, str | None]], index: int) -> tuple[gr.Tabs, Image.Image]:
    """
    Listener that sends the selected gallery image to the Moondream model.

    :param images: list of images from yolos_gallery
    :param index: index of the selected gallery image
    :return: selected tab and selected image (no caption)
    """
    return gr.Tabs(selected='moondream'), images[index][0]


def enable_button() -> gr.Button:
    """
    Helper function for Gradio event listeners.

    :return: a button with ``interactive=True`` and ``variant="primary"``
    """
    return gr.Button(interactive=True, variant="primary")


def disable_button() -> gr.Button:
    """
    Helper function for Gradio event listeners.

    :return: a button with ``interactive=False`` and ``variant="secondary"``
    """
    return gr.Button(interactive=False, variant="secondary")


if __name__ == "__main__":
    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Food Identifier

            Final project for IAT 481 at Simon Fraser University, Spring 2024.

            **Models used:**
            - [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
            - [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2)
            """
        )
        # Hidden state holding the index of the selected gallery image
        # Referenced: https://github.com/gradio-app/gradio/issues/7726#issuecomment-2028051431
        selected_image = gr.Number(visible=False, precision=0)

        with gr.Tabs() as tabs:
            with gr.Tab("Object Detection", id='yolos'):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        yolos_submit = gr.Button("Detect Objects", interactive=False)
                        yolos_input = gr.Image(label="Input Image", type="pil", interactive=True,
                                               mirror_webcam=False)
                    with gr.Column():
                        proceed_button = gr.Button("Select for Captioning", interactive=False)
                        yolos_gallery = gr.Gallery(label="Detected Objects", object_fit="scale-down",
                                                   columns=3, show_share_button=False,
                                                   selected_index=None, allow_preview=False,
                                                   type="pil", interactive=False)
            with gr.Tab("Captioning", id='moondream'):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        with gr.Group():
                            moon_prompt = gr.Textbox(label="Ask a question about the image:",
                                                     value="What is this food item? Include any text on labels.")
                            moon_submit = gr.Button("Submit", interactive=False)
                        moon_img = gr.Image(label="Image", type="pil", interactive=True,
                                            mirror_webcam=False)
                    moon_output = gr.TextArea(label="Answer", interactive=False)

        # --- YOLOS --- #
        yolos_input.upload(enable_button, None, yolos_submit)
        yolos_input.clear(disable_button, None, yolos_submit)
        yolos_submit.click(detect_objects, yolos_input, yolos_gallery)
        yolos_gallery.select(get_selected_index, None, selected_image)
        yolos_gallery.select(enable_button, None, proceed_button)
        proceed_button.click(to_moondream, [yolos_gallery, selected_image], [tabs, moon_img])
        proceed_button.click(enable_button, None, moon_submit)

        # --- Moondream --- #
        moon_img.upload(enable_button, None, moon_submit)
        moon_img.clear(disable_button, None, moon_submit)
        moon_submit.click(answer_question, [moon_img, moon_prompt], moon_output)

    app.queue().launch()