"""Gradio demo for Ferret-UI: ask a question about an uploaded image.

Loads a checkpoint via ``ferretui.eval.model_UI.load_model`` and serves a
simple web UI that forwards (image, question) pairs to
``ferretui.eval.model_UI.inference``.
"""
import argparse
from io import BytesIO

import gradio as gr
import torch
from PIL import Image

from ferretui.eval.model_UI import load_model, inference


class interface:
    """Bind a loaded model to a callable suitable for ``gr.Interface``.

    The lowercase class name is kept for backward compatibility with any code
    that imports it.
    """

    def __init__(self, args, tokenizer, model, image_processor) -> None:
        self.args = args
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor

    def run(self, image, qs):
        """Answer question ``qs`` about ``image``.

        Args:
            image: PIL image from the Gradio input, or ``None`` when the user
                submitted without uploading one.
            qs: The question text.

        Returns:
            The ``(answer_text, output_image)`` pair returned by ``inference``.
            When no image was supplied, a prompt message and ``None`` instead.
        """
        # Gradio passes None for an empty Image input; don't feed that to the
        # model pipeline, which would fail on it.
        if image is None:
            return "Please upload an image.", None
        output, out_image = inference(
            self.args, image, qs, self.tokenizer, self.model, self.image_processor
        )
        return output, out_image


# Map the --data_type CLI choice to the torch dtype stored back on ``args``
# (and thus seen by load_model / inference).
_DTYPES = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def _build_parser():
    """Return the argument parser for the demo's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
    parser.add_argument("--vision_model_path", type=str, default=None)
    parser.add_argument("--model_base", type=str, default=None)
    parser.add_argument("--image_path", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--answers_file", type=str, default="")
    parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
                        help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)
    parser.add_argument("--image_w", type=int, default=336)  # 224
    parser.add_argument("--image_h", type=int, default=336)  # 224
    parser.add_argument("--add_region_feature", action="store_true")
    parser.add_argument("--region_format", type=str, default="box",
                        choices=["point", "box", "segment", "free_shape"])
    parser.add_argument("--no_coor", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.01)
    parser.add_argument("--top_p", type=float, default=0.3)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--data_type", type=str, default="fp16",
                        choices=["fp16", "bf16", "fp32"])
    return parser


def main(argv=None):
    """Parse CLI options, load the model and launch the Gradio UI.

    Args:
        argv: Optional argument list; ``None`` means use ``sys.argv``.
    """
    args = _build_parser().parse_args(argv)
    # Replace the string choice with the actual torch dtype; argparse
    # ``choices`` guarantees the key exists.
    args.data_type = _DTYPES[args.data_type]

    tokenizer, model, image_processor, _context_len = load_model(args)
    gin = interface(args, tokenizer, model, image_processor)

    iface = gr.Interface(
        fn=gin.run,
        inputs=[gr.Image(type="pil", label="Input image"),
                gr.Textbox(label="Question")],
        outputs=[gr.Textbox(label="Answer"),
                 gr.Image(type="pil", label="Output image")],
    )
    iface.launch()


if __name__ == "__main__":
    main()