import torch
import gradio as gr
from transformers import pipeline
import ast
CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
}
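# Any other image-to-text checkpoint from the Hub could be registered the same
# way, e.g. (hypothetical extra entry, not part of the original app):
# 'git-base-coco': 'microsoft/git-base-coco'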
# Create a dictionary to store loaded models
loaded_models = {}
# Simple caption creation
def caption_image(model_choice, image_input, url_inputs, load_in_8bit, device):
    if image_input is not None:
        input_data = [image_input]
    else:
        # Interpret the input string as a Python list of URL strings
        input_data = ast.literal_eval(url_inputs)

    captions = []
    # Tuple key representing the unique combination of model and 8-bit loading
    model_key = (model_choice, load_in_8bit)

    # Check if this model/precision combination is already loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        # Set dtype based on the value of load_in_8bit
        dtype = torch.float16 if load_in_8bit else torch.float32
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=device,  # use the selected device; with 8-bit loading, accelerate may manage placement itself
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,  # set the floating-point precision
                             use_fast=True)
        # Store the loaded model for reuse
        loaded_models[model_key] = captioner

    for input_item in input_data:
        caption = captioner(input_item)[0]['generated_text']
        captions.append(str(caption).strip())
    # Join into one string so the single Textbox output renders cleanly
    return '\n'.join(captions)
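# A minimal sketch of calling caption_image directly, outside the Gradio UI
# (the URL below is a hypothetical placeholder, not part of the app):
#   print(caption_image('blip-base', None, "['https://example.com/cat.jpg']", False, 'cpu'))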
def launch(model_choice, image_input, url_inputs, load_in_8bit, device):
    return caption_image(model_choice, image_input, url_inputs, load_in_8bit, device)
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # gr.Image takes a single image; use the URL box to batch multiple inputs
url_inputs = gr.Textbox(label="Input URLs", info="Enter URLs in a list format, e.g., ['url1', 'url2', 'url3']")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')
iface = gr.Interface(launch, inputs=[model_dropdown, image_input, url_inputs, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))
iface.launch()
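# Note: iface.launch(share=True) would additionally expose a temporary public
# URL; the default call above serves the app locally (or inside the Space's container).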