"""Gradio app that describes a bird photo using the selamw/BirdWatcher PaliGemma model."""
import functools
import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
)

# Hugging Face access token read from the environment (None when unset).
access_token = os.getenv('HF_token')
model_id = "selamw/BirdWatcher"
# Ask transformers to load the weights in 8-bit.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Section labels that convert_to_markdown moves to their own bold paragraph.
_SECTION_LABELS = ('Look:', 'Cool Fact!:', 'Habitat:', 'Food:', 'Birdie Behaviors:')

# Prompt pre-filled in the text box and used for every example image.
_DEFAULT_PROMPT = "Describe this bird species"


def convert_to_markdown(input_text):
    """Format bird information text as Markdown.

    The text before the first ":" becomes a bold title. In the remainder,
    every known section label (see ``_SECTION_LABELS``) is placed on its own
    paragraph and wrapped in ``**...**`` so it renders bold.

    Args:
        input_text (str): The input text containing bird information.

    Returns:
        str: The formatted Markdown text, or "" when the input is empty or
        only whitespace.
    """
    # Without this guard an empty input would produce "****".
    if not input_text or not input_text.strip():
        return ""

    # partition() splits on the first ":" only; with no ":" the whole text
    # is the title and the content is empty.
    title, _, content = input_text.partition(":")
    title = title.strip()
    content = content.strip()

    # The closing "**" is required, otherwise Markdown shows the asterisks
    # literally instead of bolding the label.
    for label in _SECTION_LABELS:
        content = content.replace(label, f'\n\n**{label}**')

    heading = f"**{title}**" if title else ""
    # Keep a space between the title and content that does not start with a
    # section label, so the two do not run together.
    separator = " " if heading and content and not content.startswith("\n") else ""
    return f"{heading}{separator}{content}".strip()


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Load the quantized model and its processor once and reuse them afterwards."""
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, quantization_config=bnb_config, token=access_token
    )
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
    return model, processor


@spaces.GPU
def infer_fin_pali(image, question):
    """Generate a Markdown description for an image and a text prompt.

    Args:
        image: The image from the ``gr.Image`` input (None when nothing was
            uploaded).
        question (str): The text prompt passed to the model.

    Returns:
        str: The model's answer formatted by ``convert_to_markdown``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload a bird image first.")

    model, processor = _load_model_and_processor()

    # Put the inputs on whatever device the model actually lives on.
    inputs = processor(images=image, text=question, return_tensors="pt").to(model.device)
    predictions = model.generate(**inputs, max_new_tokens=512)

    # generate() returns prompt tokens followed by new tokens; drop the prompt
    # by token count rather than by character count of the question.
    prompt_length = inputs["input_ids"].shape[-1]
    decoded_output = processor.decode(
        predictions[0][prompt_length:], skip_special_tokens=True
    ).lstrip("\n")

    # Ensure proper Markdown formatting
    return convert_to_markdown(decoded_output)


# Custom CSS for the Blocks app (text kept exactly as in the original string).
css = (
    " #mkd { height: 500px; overflow: auto; border: 1px solid #ccc; }"
    " h1 { text-align: center; }"
    " h3 { text-align: center; }"
    " h2 { text-align: center; }"
    " span.gray-text { color: gray; } "
)

with gr.Blocks(css=css) as demo:
    # NOTE(review): the banner markup looks like it was lost upstream (the CSS
    # styles h1/h2/h3 and span.gray-text, none of which appear here) - confirm
    # against the deployed app before editing this text.
    gr.HTML(
        "\n(There are over 11,000 bird species in the world, "
        "and this model was fine-tuned with over 500)\n"
    )
    # NOTE(review): the nesting below was reconstructed from the `with`
    # statements because the original indentation was lost - verify the layout.
    with gr.Tab(label="Bird Identification"):
        with gr.Row():
            input_img = gr.Image(label="Input Bird Image")
            with gr.Column():
                with gr.Row():
                    question = gr.Text(
                        label="Default Prompt",
                        value=_DEFAULT_PROMPT,
                        elem_id="default-prompt",
                        interactive=True,
                    )
                with gr.Row():
                    submit_btn = gr.Button(value="Run")
                with gr.Row():
                    # Use Markdown component to display output
                    output = gr.Markdown(label="Response")

        submit_btn.click(infer_fin_pali, [input_img, question], [output])
        gr.Examples(
            # Example images 01.jpg .. 06.jpg, each paired with the default prompt.
            [[f"{index:02d}.jpg", _DEFAULT_PROMPT] for index in range(1, 7)],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_fin_pali,
            label='Examples 👇',
        )

demo.launch(debug=True, share=True)