# Spaces: Runtime error
import argparse | |
import torch | |
import gradio as gr | |
from PIL import Image | |
from io import BytesIO | |
from ferretui.eval.model_UI import load_model, inference | |
class interface:
    """Bind the loaded model objects to a two-argument callable.

    The instance keeps the parsed arguments, tokenizer, model and image
    processor so that ``run`` only needs the per-request inputs (an image
    and a question) and can be used directly as a UI callback.
    """

    def __init__(self, args, tokenizer, model, image_processor) -> None:
        """Store the objects that every call to ``inference`` needs."""
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.args = args

    def run(self, image, qs):
        """Run ``inference`` on one image/question pair.

        Args:
            image: The input image (the UI in this file supplies a PIL image).
            qs: The question text.

        Returns:
            The two values produced by ``inference``, as a tuple. In this
            file they are shown in an "Answer" textbox and an "Output image"
            component, so presumably they are the answer text and a
            processed image -- confirm against ``inference``.
        """
        answer, result_image = inference(
            self.args,
            image,
            qs,
            self.tokenizer,
            self.model,
            self.image_processor,
        )
        return answer, result_image
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the model, and serve a
    # Gradio UI that forwards each (image, question) pair to interface.run.
    cli_parser = argparse.ArgumentParser()

    # Model and data locations.
    cli_parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
    cli_parser.add_argument("--vision_model_path", type=str, default=None)
    cli_parser.add_argument("--model_base", type=str, default=None)
    cli_parser.add_argument("--image_path", type=str, default="")
    cli_parser.add_argument("--data_path", type=str, default="")
    cli_parser.add_argument("--answers_file", type=str, default="")
    cli_parser.add_argument(
        "--conv_mode",
        type=str,
        default="ferret_gemma_instruct",
        help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]",
    )
    cli_parser.add_argument("--num_chunks", type=int, default=1)
    cli_parser.add_argument("--chunk_idx", type=int, default=0)

    # Image size and region options.
    cli_parser.add_argument("--image_w", type=int, default=336)  # 224
    cli_parser.add_argument("--image_h", type=int, default=336)  # 224
    cli_parser.add_argument("--add_region_feature", action="store_true")
    cli_parser.add_argument(
        "--region_format",
        type=str,
        default="box",
        choices=["point", "box", "segment", "free_shape"],
    )
    cli_parser.add_argument("--no_coor", action="store_true")

    # Generation settings.
    cli_parser.add_argument("--temperature", type=float, default=0.01)
    cli_parser.add_argument("--top_p", type=float, default=0.3)
    cli_parser.add_argument("--num_beams", type=int, default=1)
    cli_parser.add_argument("--max_new_tokens", type=int, default=128)
    cli_parser.add_argument(
        "--data_type",
        type=str,
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
    )

    args = cli_parser.parse_args()

    # Replace the dtype name with the torch dtype object; any name other
    # than fp16/bf16 maps to float32.
    dtype_by_name = {"fp16": torch.float16, "bf16": torch.bfloat16}
    args.data_type = dtype_by_name.get(args.data_type, torch.float32)

    # The fourth value returned by load_model is not used in this script.
    tokenizer, model, image_processor, _context_len = load_model(args)
    handler = interface(args, tokenizer, model, image_processor)

    # Inputs: the image and the question. Outputs: the answer and an image.
    input_components = [
        gr.Image(type="pil", label="Input image"),
        gr.Textbox(label="Question"),
    ]
    output_components = [
        gr.Textbox(label="Answer"),
        gr.Image(type="pil", label="Output image"),
    ]
    demo = gr.Interface(
        fn=handler.run,
        inputs=input_components,
        outputs=output_components,
    )
    demo.launch()