Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	
		alessandro trinca tornidor
		
	commited on
		
		
					Commit 
							
							·
						
						c5fe4a2
	
1
								Parent(s):
							
							f623930
								
[feat] restart from an example app to check what's working
Browse files- app.py +26 -377
 - utils/session_logger.py +36 -0
 
    	
        app.py
    CHANGED
    
    | 
         @@ -1,390 +1,39 @@ 
     | 
|
| 1 | 
         
            -
            import argparse
         
     | 
| 2 | 
         
            -
            import os
         
     | 
| 3 | 
         
            -
            import re
         
     | 
| 4 | 
         
            -
            import sys
         
     | 
| 5 | 
         
            -
            import logging
         
     | 
| 6 | 
         
            -
            from typing import Callable
         
     | 
| 7 | 
         
            -
             
     | 
| 8 | 
         
            -
            from fastapi import FastAPI, File, UploadFile, Request
         
     | 
| 9 | 
         
            -
            from fastapi.responses import HTMLResponse, RedirectResponse
         
     | 
| 10 | 
         
            -
            from fastapi.staticfiles import StaticFiles
         
     | 
| 11 | 
         
            -
            from fastapi.templating import Jinja2Templates
         
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
            import cv2
         
     | 
| 14 | 
         
             
            import gradio as gr
         
     | 
| 15 | 
         
            -
            import  
     | 
| 16 | 
         
            -
            import numpy as np
         
     | 
| 17 | 
         
            -
            import torch
         
     | 
| 18 | 
         
            -
            import torch.nn.functional as F
         
     | 
| 19 | 
         
            -
            from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
         
     | 
| 20 | 
         
            -
             
     | 
| 21 | 
         
            -
            from model.LISA import LISAForCausalLM
         
     | 
| 22 | 
         
            -
            from model.llava import conversation as conversation_lib
         
     | 
| 23 | 
         
            -
            from model.llava.mm_utils import tokenizer_image_token
         
     | 
| 24 | 
         
            -
            from model.segment_anything.utils.transforms import ResizeLongestSide
         
     | 
| 25 | 
         
            -
            from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
         
     | 
| 26 | 
         
            -
                                     DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
         
     | 
| 27 | 
         | 
| 
         | 
|
| 28 | 
         | 
| 29 | 
         
            -
            CUSTOM_GRADIO_PATH = "/gradio"
         
     | 
| 30 | 
         
            -
            app = FastAPI()
         
     | 
| 31 | 
         | 
| 32 | 
         
            -
             
     | 
| 33 | 
         
            -
             
     | 
| 34 | 
         
            -
            app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
         
     | 
| 35 | 
         
            -
            templates = Jinja2Templates(directory="templates")
         
     | 
| 36 | 
         | 
| 37 | 
         | 
| 38 | 
         
            -
             
     | 
| 39 | 
         
            -
             
     | 
| 40 | 
         
            -
             
     | 
| 41 | 
         
            -
             
     | 
| 42 | 
         
            -
             
     | 
| 43 | 
         
            -
             
     | 
| 44 | 
         
            -
             
     | 
| 45 | 
         
            -
             
     | 
| 46 | 
         
            -
             
     | 
| 47 | 
         
            -
                        "code",
         
     | 
| 48 | 
         
            -
                        "em",
         
     | 
| 49 | 
         
            -
                        "i",
         
     | 
| 50 | 
         
            -
                        "li",
         
     | 
| 51 | 
         
            -
                        "ol",
         
     | 
| 52 | 
         
            -
                        "strong",
         
     | 
| 53 | 
         
            -
                        "ul",
         
     | 
| 54 | 
         
            -
                    },
         
     | 
| 55 | 
         
            -
                    attributes={
         
     | 
| 56 | 
         
            -
                        "a": {"href", "title"},
         
     | 
| 57 | 
         
            -
                        "abbr": {"title"},
         
     | 
| 58 | 
         
            -
                        "acronym": {"title"},
         
     | 
| 59 | 
         
            -
                    },
         
     | 
| 60 | 
         
            -
                    url_schemes={"http", "https", "mailto"},
         
     | 
| 61 | 
         
            -
                    link_rel=None,
         
     | 
| 62 | 
         
            -
                )
         
     | 
| 63 | 
         
            -
                return input_str
         
     | 
| 64 | 
         | 
| 65 | 
         | 
| 66 | 
         
            -
            @ 
     | 
| 67 | 
         
            -
             
     | 
| 68 | 
         
            -
                logging.info( 
     | 
| 69 | 
         
            -
                 
     | 
| 70 | 
         
            -
                logging.info(f" 
     | 
| 71 | 
         
            -
                return  
     | 
| 72 | 
         
            -
                    "home.html", {"clean_request": clean_request}
         
     | 
| 73 | 
         
            -
                )
         
     | 
| 74 | 
         | 
| 75 | 
         | 
| 76 | 
         
            -
             
     | 
| 77 | 
         
            -
             
     | 
| 78 | 
         
            -
                [
         
     | 
| 79 | 
         
            -
                     
     | 
| 80 | 
         
            -
                    "./resources/imgs/example1.jpg",
         
     | 
| 81 | 
         
             
                ],
         
     | 
| 82 | 
         
            -
                [
         
     | 
| 83 | 
         
            -
                     
     | 
| 84 | 
         
            -
                    "./resources/imgs/example2.jpg",
         
     | 
| 85 | 
         
             
                ],
         
     | 
| 86 | 
         
            -
             
     | 
| 87 | 
         
            -
                    "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
         
     | 
| 88 | 
         
            -
                    "./resources/imgs/example1.jpg",
         
     | 
| 89 | 
         
            -
                ],
         
     | 
| 90 | 
         
            -
                [
         
     | 
| 91 | 
         
            -
                    "What can make the woman stand higher? Please output segmentation mask and explain why.",
         
     | 
| 92 | 
         
            -
                    "./resources/imgs/example3.jpg",
         
     | 
| 93 | 
         
            -
                ],
         
     | 
| 94 | 
         
            -
            ]
         
     | 
| 95 | 
         
            -
            output_labels = ["Segmentation Output"]
         
     | 
| 96 | 
         
            -
             
     | 
| 97 | 
         
            -
            title = "LISA: Reasoning Segmentation via Large Language Model"
         
     | 
| 98 | 
         
            -
             
     | 
| 99 | 
         
            -
            description = """
         
     | 
| 100 | 
         
            -
            <font size=4>
         
     | 
| 101 | 
         
            -
            This is the online demo of LISA. \n
         
     | 
| 102 | 
         
            -
            If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n
         
     | 
| 103 | 
         
            -
            **Note**: **Different prompts can lead to significantly varied results**. \n
         
     | 
| 104 | 
         
            -
            **Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
         
     | 
| 105 | 
         
            -
            **Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n
         
     | 
| 106 | 
         
            -
            **Usage**: <br>
         
     | 
| 107 | 
         
            -
             (1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask."; <br>
         
     | 
| 108 | 
         
            -
             (2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why."; <br>
         
     | 
| 109 | 
         
            -
             (3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA). <br>
         
     | 
| 110 | 
         
            -
            Hope you can enjoy our work!
         
     | 
| 111 | 
         
            -
            </font>
         
     | 
| 112 | 
         
            -
            """
         
     | 
| 113 | 
         
            -
             
     | 
| 114 | 
         
            -
            article = """
         
     | 
| 115 | 
         
            -
            <p style='text-align: center'>
         
     | 
| 116 | 
         
            -
            <a href='https://arxiv.org/abs/2308.00692' target='_blank'>
         
     | 
| 117 | 
         
            -
            Preprint Paper
         
     | 
| 118 | 
         
            -
            </a>
         
     | 
| 119 | 
         
            -
            \n
         
     | 
| 120 | 
         
            -
            <p style='text-align: center'>
         
     | 
| 121 | 
         
            -
            <a href='https://github.com/dvlab-research/LISA' target='_blank'>   Github Repo </a></p>
         
     | 
| 122 | 
         
            -
            """
         
     | 
| 123 | 
         
            -
             
     | 
| 124 | 
         
            -
             
     | 
| 125 | 
         
            -
            def parse_args(args_to_parse):
         
     | 
| 126 | 
         
            -
                parser = argparse.ArgumentParser(description="LISA chat")
         
     | 
| 127 | 
         
            -
                parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
         
     | 
| 128 | 
         
            -
                parser.add_argument("--vis_save_path", default="./vis_output", type=str)
         
     | 
| 129 | 
         
            -
                parser.add_argument(
         
     | 
| 130 | 
         
            -
                    "--precision",
         
     | 
| 131 | 
         
            -
                    default="fp16",
         
     | 
| 132 | 
         
            -
                    type=str,
         
     | 
| 133 | 
         
            -
                    choices=["fp32", "bf16", "fp16"],
         
     | 
| 134 | 
         
            -
                    help="precision for inference",
         
     | 
| 135 | 
         
            -
                )
         
     | 
| 136 | 
         
            -
                parser.add_argument("--image_size", default=1024, type=int, help="image size")
         
     | 
| 137 | 
         
            -
                parser.add_argument("--model_max_length", default=512, type=int)
         
     | 
| 138 | 
         
            -
                parser.add_argument("--lora_r", default=8, type=int)
         
     | 
| 139 | 
         
            -
                parser.add_argument(
         
     | 
| 140 | 
         
            -
                    "--vision-tower", default="openai/clip-vit-large-patch14", type=str
         
     | 
| 141 | 
         
            -
                )
         
     | 
| 142 | 
         
            -
                parser.add_argument("--local-rank", default=0, type=int, help="node rank")
         
     | 
| 143 | 
         
            -
                parser.add_argument("--load_in_8bit", action="store_true", default=False)
         
     | 
| 144 | 
         
            -
                parser.add_argument("--load_in_4bit", action="store_true", default=False)
         
     | 
| 145 | 
         
            -
                parser.add_argument("--use_mm_start_end", action="store_true", default=True)
         
     | 
| 146 | 
         
            -
                parser.add_argument(
         
     | 
| 147 | 
         
            -
                    "--conv_type",
         
     | 
| 148 | 
         
            -
                    default="llava_v1",
         
     | 
| 149 | 
         
            -
                    type=str,
         
     | 
| 150 | 
         
            -
                    choices=["llava_v1", "llava_llama_2"],
         
     | 
| 151 | 
         
            -
                )
         
     | 
| 152 | 
         
            -
                return parser.parse_args(args_to_parse)
         
     | 
| 153 | 
         
            -
             
     | 
| 154 | 
         
            -
             
     | 
| 155 | 
         
            -
            def set_image_precision_by_args(input_image, precision):
         
     | 
| 156 | 
         
            -
                if precision == "bf16":
         
     | 
| 157 | 
         
            -
                    input_image = input_image.bfloat16()
         
     | 
| 158 | 
         
            -
                elif precision == "fp16":
         
     | 
| 159 | 
         
            -
                    input_image = input_image.half()
         
     | 
| 160 | 
         
            -
                else:
         
     | 
| 161 | 
         
            -
                    input_image = input_image.float()
         
     | 
| 162 | 
         
            -
                return input_image
         
     | 
| 163 | 
         
            -
             
     | 
| 164 | 
         
            -
             
     | 
| 165 | 
         
            -
            def preprocess(
         
     | 
| 166 | 
         
            -
                x,
         
     | 
| 167 | 
         
            -
                pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
         
     | 
| 168 | 
         
            -
                pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
         
     | 
| 169 | 
         
            -
                img_size=1024,
         
     | 
| 170 | 
         
            -
            ) -> torch.Tensor:
         
     | 
| 171 | 
         
            -
                """Normalize pixel values and pad to a square input."""
         
     | 
| 172 | 
         
            -
                # Normalize colors
         
     | 
| 173 | 
         
            -
                x = (x - pixel_mean) / pixel_std
         
     | 
| 174 | 
         
            -
                # Pad
         
     | 
| 175 | 
         
            -
                h, w = x.shape[-2:]
         
     | 
| 176 | 
         
            -
                padh = img_size - h
         
     | 
| 177 | 
         
            -
                padw = img_size - w
         
     | 
| 178 | 
         
            -
                x = F.pad(x, (0, padw, 0, padh))
         
     | 
| 179 | 
         
            -
                return x
         
     | 
| 180 | 
         
            -
             
     | 
| 181 | 
         
            -
             
     | 
| 182 | 
         
            -
            def get_model(args_to_parse):
         
     | 
| 183 | 
         
            -
                os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
         
     | 
| 184 | 
         
            -
             
     | 
| 185 | 
         
            -
                # global tokenizer, tokenizer
         
     | 
| 186 | 
         
            -
                # Create model
         
     | 
| 187 | 
         
            -
                _tokenizer = AutoTokenizer.from_pretrained(
         
     | 
| 188 | 
         
            -
                    args_to_parse.version,
         
     | 
| 189 | 
         
            -
                    cache_dir=None,
         
     | 
| 190 | 
         
            -
                    model_max_length=args_to_parse.model_max_length,
         
     | 
| 191 | 
         
            -
                    padding_side="right",
         
     | 
| 192 | 
         
            -
                    use_fast=False,
         
     | 
| 193 | 
         
            -
                )
         
     | 
| 194 | 
         
            -
                _tokenizer.pad_token = _tokenizer.unk_token
         
     | 
| 195 | 
         
            -
                args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
         
     | 
| 196 | 
         
            -
                torch_dtype = torch.float32
         
     | 
| 197 | 
         
            -
                if args_to_parse.precision == "bf16":
         
     | 
| 198 | 
         
            -
                    torch_dtype = torch.bfloat16
         
     | 
| 199 | 
         
            -
                elif args_to_parse.precision == "fp16":
         
     | 
| 200 | 
         
            -
                    torch_dtype = torch.half
         
     | 
| 201 | 
         
            -
                kwargs = {"torch_dtype": torch_dtype}
         
     | 
| 202 | 
         
            -
                if args_to_parse.load_in_4bit:
         
     | 
| 203 | 
         
            -
                    kwargs.update(
         
     | 
| 204 | 
         
            -
                        {
         
     | 
| 205 | 
         
            -
                            "torch_dtype": torch.half,
         
     | 
| 206 | 
         
            -
                            "load_in_4bit": True,
         
     | 
| 207 | 
         
            -
                            "quantization_config": BitsAndBytesConfig(
         
     | 
| 208 | 
         
            -
                                load_in_4bit=True,
         
     | 
| 209 | 
         
            -
                                bnb_4bit_compute_dtype=torch.float16,
         
     | 
| 210 | 
         
            -
                                bnb_4bit_use_double_quant=True,
         
     | 
| 211 | 
         
            -
                                bnb_4bit_quant_type="nf4",
         
     | 
| 212 | 
         
            -
                                llm_int8_skip_modules=["visual_model"],
         
     | 
| 213 | 
         
            -
                            ),
         
     | 
| 214 | 
         
            -
                        }
         
     | 
| 215 | 
         
            -
                    )
         
     | 
| 216 | 
         
            -
                elif args_to_parse.load_in_8bit:
         
     | 
| 217 | 
         
            -
                    kwargs.update(
         
     | 
| 218 | 
         
            -
                        {
         
     | 
| 219 | 
         
            -
                            "torch_dtype": torch.half,
         
     | 
| 220 | 
         
            -
                            "quantization_config": BitsAndBytesConfig(
         
     | 
| 221 | 
         
            -
                                llm_int8_skip_modules=["visual_model"],
         
     | 
| 222 | 
         
            -
                                load_in_8bit=True,
         
     | 
| 223 | 
         
            -
                            ),
         
     | 
| 224 | 
         
            -
                        }
         
     | 
| 225 | 
         
            -
                    )
         
     | 
| 226 | 
         
            -
                _model = LISAForCausalLM.from_pretrained(
         
     | 
| 227 | 
         
            -
                    args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
         
     | 
| 228 | 
         
            -
                )
         
     | 
| 229 | 
         
            -
                _model.config.eos_token_id = _tokenizer.eos_token_id
         
     | 
| 230 | 
         
            -
                _model.config.bos_token_id = _tokenizer.bos_token_id
         
     | 
| 231 | 
         
            -
                _model.config.pad_token_id = _tokenizer.pad_token_id
         
     | 
| 232 | 
         
            -
                _model.get_model().initialize_vision_modules(_model.get_model().config)
         
     | 
| 233 | 
         
            -
                vision_tower = _model.get_model().get_vision_tower()
         
     | 
| 234 | 
         
            -
                vision_tower.to(dtype=torch_dtype)
         
     | 
| 235 | 
         
            -
                if args_to_parse.precision == "bf16":
         
     | 
| 236 | 
         
            -
                    _model = _model.bfloat16().cuda()
         
     | 
| 237 | 
         
            -
                elif (
         
     | 
| 238 | 
         
            -
                        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
         
     | 
| 239 | 
         
            -
                ):
         
     | 
| 240 | 
         
            -
                    vision_tower = _model.get_model().get_vision_tower()
         
     | 
| 241 | 
         
            -
                    _model.model.vision_tower = None
         
     | 
| 242 | 
         
            -
                    import deepspeed
         
     | 
| 243 | 
         
            -
             
     | 
| 244 | 
         
            -
                    model_engine = deepspeed.init_inference(
         
     | 
| 245 | 
         
            -
                        model=_model,
         
     | 
| 246 | 
         
            -
                        dtype=torch.half,
         
     | 
| 247 | 
         
            -
                        replace_with_kernel_inject=True,
         
     | 
| 248 | 
         
            -
                        replace_method="auto",
         
     | 
| 249 | 
         
            -
                    )
         
     | 
| 250 | 
         
            -
                    _model = model_engine.module
         
     | 
| 251 | 
         
            -
                    _model.model.vision_tower = vision_tower.half().cuda()
         
     | 
| 252 | 
         
            -
                elif args_to_parse.precision == "fp32":
         
     | 
| 253 | 
         
            -
                    _model = _model.float().cuda()
         
     | 
| 254 | 
         
            -
                vision_tower = _model.get_model().get_vision_tower()
         
     | 
| 255 | 
         
            -
                vision_tower.to(device=args_to_parse.local_rank)
         
     | 
| 256 | 
         
            -
                _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
         
     | 
| 257 | 
         
            -
                _transform = ResizeLongestSide(args_to_parse.image_size)
         
     | 
| 258 | 
         
            -
                _model.eval()
         
     | 
| 259 | 
         
            -
                return _model, _clip_image_processor, _tokenizer, _transform
         
     | 
| 260 | 
         
            -
             
     | 
| 261 | 
         
            -
             
     | 
| 262 | 
         
            -
            def get_inference_model_by_args(args_to_parse):
         
     | 
| 263 | 
         
            -
                model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
         
     | 
| 264 | 
         
            -
             
     | 
| 265 | 
         
            -
                ## to be implemented
         
     | 
| 266 | 
         
            -
                def inference(input_str, input_image):
         
     | 
| 267 | 
         
            -
                    ## filter out special chars
         
     | 
| 268 | 
         
            -
             
     | 
| 269 | 
         
            -
                    input_str = get_cleaned_input(input_str)
         
     | 
| 270 | 
         
            -
                    logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
         
     | 
| 271 | 
         
            -
                    logging.info(f"input_str: {input_str}.")
         
     | 
| 272 | 
         
            -
             
     | 
| 273 | 
         
            -
                    ## input valid check
         
     | 
| 274 | 
         
            -
                    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
         
     | 
| 275 | 
         
            -
                        output_str = "[Error] Invalid input: ", input_str
         
     | 
| 276 | 
         
            -
                        # output_image = np.zeros((128, 128, 3))
         
     | 
| 277 | 
         
            -
                        ## error happened
         
     | 
| 278 | 
         
            -
                        output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
         
     | 
| 279 | 
         
            -
                        return output_image, output_str
         
     | 
| 280 | 
         
            -
             
     | 
| 281 | 
         
            -
                    # Model Inference
         
     | 
| 282 | 
         
            -
                    conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
         
     | 
| 283 | 
         
            -
                    conv.messages = []
         
     | 
| 284 | 
         
            -
             
     | 
| 285 | 
         
            -
                    prompt = input_str
         
     | 
| 286 | 
         
            -
                    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
         
     | 
| 287 | 
         
            -
                    if args_to_parse.use_mm_start_end:
         
     | 
| 288 | 
         
            -
                        replace_token = (
         
     | 
| 289 | 
         
            -
                            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
         
     | 
| 290 | 
         
            -
                        )
         
     | 
| 291 | 
         
            -
                        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
         
     | 
| 292 | 
         
            -
             
     | 
| 293 | 
         
            -
                    conv.append_message(conv.roles[0], prompt)
         
     | 
| 294 | 
         
            -
                    conv.append_message(conv.roles[1], "")
         
     | 
| 295 | 
         
            -
                    prompt = conv.get_prompt()
         
     | 
| 296 | 
         
            -
             
     | 
| 297 | 
         
            -
                    image_np = cv2.imread(input_image)
         
     | 
| 298 | 
         
            -
                    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
         
     | 
| 299 | 
         
            -
                    original_size_list = [image_np.shape[:2]]
         
     | 
| 300 | 
         
            -
             
     | 
| 301 | 
         
            -
                    image_clip = (
         
     | 
| 302 | 
         
            -
                        clip_image_processor.preprocess(image_np, return_tensors="pt")[
         
     | 
| 303 | 
         
            -
                            "pixel_values"
         
     | 
| 304 | 
         
            -
                        ][0]
         
     | 
| 305 | 
         
            -
                        .unsqueeze(0)
         
     | 
| 306 | 
         
            -
                        .cuda()
         
     | 
| 307 | 
         
            -
                    )
         
     | 
| 308 | 
         
            -
                    logging.info(f"image_clip type: {type(image_clip)}.")
         
     | 
| 309 | 
         
            -
                    image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
         
     | 
| 310 | 
         
            -
             
     | 
| 311 | 
         
            -
                    image = transform.apply_image(image_np)
         
     | 
| 312 | 
         
            -
                    resize_list = [image.shape[:2]]
         
     | 
| 313 | 
         
            -
             
     | 
| 314 | 
         
            -
                    image = (
         
     | 
| 315 | 
         
            -
                        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
         
     | 
| 316 | 
         
            -
                        .unsqueeze(0)
         
     | 
| 317 | 
         
            -
                        .cuda()
         
     | 
| 318 | 
         
            -
                    )
         
     | 
| 319 | 
         
            -
                    logging.info(f"image_clip type: {type(image_clip)}.")
         
     | 
| 320 | 
         
            -
                    image = set_image_precision_by_args(image, args_to_parse.precision)
         
     | 
| 321 | 
         
            -
             
     | 
| 322 | 
         
            -
                    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
         
     | 
| 323 | 
         
            -
                    input_ids = input_ids.unsqueeze(0).cuda()
         
     | 
| 324 | 
         
            -
             
     | 
| 325 | 
         
            -
                    output_ids, pred_masks = model.evaluate(
         
     | 
| 326 | 
         
            -
                        image_clip,
         
     | 
| 327 | 
         
            -
                        image,
         
     | 
| 328 | 
         
            -
                        input_ids,
         
     | 
| 329 | 
         
            -
                        resize_list,
         
     | 
| 330 | 
         
            -
                        original_size_list,
         
     | 
| 331 | 
         
            -
                        max_new_tokens=512,
         
     | 
| 332 | 
         
            -
                        tokenizer=tokenizer,
         
     | 
| 333 | 
         
            -
                    )
         
     | 
| 334 | 
         
            -
                    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
         
     | 
| 335 | 
         
            -
             
     | 
| 336 | 
         
            -
                    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
         
     | 
| 337 | 
         
            -
                    text_output = text_output.replace("\n", "").replace("  ", " ")
         
     | 
| 338 | 
         
            -
                    text_output = text_output.split("ASSISTANT: ")[-1]
         
     | 
| 339 | 
         
            -
             
     | 
| 340 | 
         
            -
                    logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")
         
     | 
| 341 | 
         
            -
                    save_img = None
         
     | 
| 342 | 
         
            -
                    for i, pred_mask in enumerate(pred_masks):
         
     | 
| 343 | 
         
            -
                        if pred_mask.shape[0] == 0:
         
     | 
| 344 | 
         
            -
                            continue
         
     | 
| 345 | 
         
            -
             
     | 
| 346 | 
         
            -
                        pred_mask = pred_mask.detach().cpu().numpy()[0]
         
     | 
| 347 | 
         
            -
                        pred_mask = pred_mask > 0
         
     | 
| 348 | 
         
            -
             
     | 
| 349 | 
         
            -
                        save_img = image_np.copy()
         
     | 
| 350 | 
         
            -
                        save_img[pred_mask] = (
         
     | 
| 351 | 
         
            -
                            image_np * 0.5
         
     | 
| 352 | 
         
            -
                            + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
         
     | 
| 353 | 
         
            -
                        )[pred_mask]
         
     | 
| 354 | 
         
            -
             
     | 
| 355 | 
         
            -
                    output_str = "ASSITANT: " + text_output  # input_str
         
     | 
| 356 | 
         
            -
                    if save_img is not None:
         
     | 
| 357 | 
         
            -
                        output_image = save_img  # input_image
         
     | 
| 358 | 
         
            -
                    else:
         
     | 
| 359 | 
         
            -
                        ## no seg output
         
     | 
| 360 | 
         
            -
                        output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
         
     | 
| 361 | 
         
            -
                    return output_image, output_str
         
     | 
| 362 | 
         
            -
             
     | 
| 363 | 
         
            -
                return inference
         
     | 
| 364 | 
         
            -
             
     | 
| 365 | 
         
            -
             
     | 
| 366 | 
         
            -
            def get_gradio_interface(
         
     | 
| 367 | 
         
            -
                    fn_inference: Callable
         
     | 
| 368 | 
         
            -
                ):
         
     | 
| 369 | 
         
            -
                return gr.Interface(
         
     | 
| 370 | 
         
            -
                    fn_inference,
         
     | 
| 371 | 
         
            -
                    inputs=[
         
     | 
| 372 | 
         
            -
                        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
         
     | 
| 373 | 
         
            -
                        gr.Image(type="filepath", label="Input Image")
         
     | 
| 374 | 
         
            -
                    ],
         
     | 
| 375 | 
         
            -
                    outputs=[
         
     | 
| 376 | 
         
            -
                        gr.Image(type="pil", label="Segmentation Output"),
         
     | 
| 377 | 
         
            -
                        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
         
     | 
| 378 | 
         
            -
                    ],
         
     | 
| 379 | 
         
            -
                    title=title,
         
     | 
| 380 | 
         
            -
                    description=description,
         
     | 
| 381 | 
         
            -
                    article=article,
         
     | 
| 382 | 
         
            -
                    examples=examples,
         
     | 
| 383 | 
         
            -
                    allow_flagging="auto",
         
     | 
| 384 | 
         
            -
                )
         
     | 
| 385 | 
         
            -
             
     | 
| 386 | 
         
            -
             
     | 
| 387 | 
         
            -
            args = parse_args(sys.argv[1:])
         
     | 
| 388 | 
         
            -
            inference_fn = get_inference_model_by_args(args)
         
     | 
| 389 | 
         
            -
            io = get_gradio_interface(inference_fn)
         
     | 
| 390 | 
         
             
            app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            import gradio as gr
         
     | 
| 2 | 
         
            +
            from fastapi import FastAPI
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 3 | 
         | 
| 4 | 
         
            +
            from utils import session_logger
         
     | 
| 5 | 
         | 
| 
         | 
|
| 
         | 
|
| 6 | 
         | 
| 7 | 
         
            +
            CUSTOM_GRADIO_PATH = "/"
         
     | 
| 8 | 
         
            +
            app = FastAPI(title="lisa_app", version="1.0")
         
     | 
| 
         | 
|
| 
         | 
|
| 9 | 
         | 
| 10 | 
         | 
| 11 | 
         
            +
            @app.get("/health")
         
     | 
| 12 | 
         
            +
            @session_logger.set_uuid_logging
         
     | 
| 13 | 
         
            +
            def health() -> str:
         
     | 
| 14 | 
         
            +
                try:
         
     | 
| 15 | 
         
            +
                    logging.info("health check")
         
     | 
| 16 | 
         
            +
                    return json.dumps({"msg": "ok"})
         
     | 
| 17 | 
         
            +
                except Exception as e:
         
     | 
| 18 | 
         
            +
                    logging.error(f"exception:{e}.")
         
     | 
| 19 | 
         
            +
                    return json.dumps({"msg": "request failed"})
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 20 | 
         | 
| 21 | 
         | 
| 22 | 
         
            +
            @session_logger.set_uuid_logging
         
     | 
| 23 | 
         
            +
            def request_formatter(text: str) -> str:
         
     | 
| 24 | 
         
            +
                logging.info("start request formatting...")
         
     | 
| 25 | 
         
            +
                formatted_text = f"transformed {text}."
         
     | 
| 26 | 
         
            +
                logging.info(f"formatted request as {formatted_text}.")
         
     | 
| 27 | 
         
            +
                return formatted_text
         
     | 
| 
         | 
|
| 
         | 
|
| 28 | 
         | 
| 29 | 
         | 
| 30 | 
         
            +
            io = gr.Interface(
         
     | 
| 31 | 
         
            +
                request_formatter,
         
     | 
| 32 | 
         
            +
                inputs=[
         
     | 
| 33 | 
         
            +
                    gr.Textbox(lines=1, placeholder=None, label="Text input"),
         
     | 
| 
         | 
|
| 34 | 
         
             
                ],
         
     | 
| 35 | 
         
            +
                outputs=[
         
     | 
| 36 | 
         
            +
                    gr.Textbox(lines=1, placeholder=None, label="Text Output"),
         
     | 
| 
         | 
|
| 37 | 
         
             
                ],
         
     | 
| 38 | 
         
            +
            )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 39 | 
         
             
            app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
         
     | 
    	
        utils/session_logger.py
    ADDED
    
    | 
         @@ -0,0 +1,36 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import contextvars
         
     | 
| 2 | 
         
            +
            import logging
         
     | 
| 3 | 
         
            +
            from functools import wraps
         
     | 
| 4 | 
         
            +
            from typing import Callable
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            logging_uuid = contextvars.ContextVar("uuid")
         
     | 
| 7 | 
         
            +
            formatter = '%(asctime)s | %(uuid)s [%(pathname)s:%(module)s %(lineno)d] %(levelname)s | %(message)s'
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            loggingType = logging.CRITICAL | logging.ERROR | logging.WARNING | logging.INFO | logging.DEBUG
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            def change_logging(level_log: loggingType = logging.INFO) -> None:
         
     | 
| 14 | 
         
            +
                old_factory = logging.getLogRecordFactory()
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                def record_factory(*args, **kwargs):
         
     | 
| 17 | 
         
            +
                    record = old_factory(*args, **kwargs)
         
     | 
| 18 | 
         
            +
                    record.uuid = logging_uuid.get("uuid")
         
     | 
| 19 | 
         
            +
                    if isinstance(record.msg, str):
         
     | 
| 20 | 
         
            +
                        record.msg = record.msg.replace("\\", "\\\\").replace("\n", "\\n")
         
     | 
| 21 | 
         
            +
                    return record
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                logging.setLogRecordFactory(record_factory)
         
     | 
| 24 | 
         
            +
                logging.basicConfig(level=level_log, format=formatter, force=True)
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            def set_uuid_logging(func: Callable) -> Callable:
         
     | 
| 28 | 
         
            +
                @wraps(func)
         
     | 
| 29 | 
         
            +
                def wrapper(*args, **kwargs):
         
     | 
| 30 | 
         
            +
                    import uuid
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                    current_uuid = f"{uuid.uuid4()}"
         
     | 
| 33 | 
         
            +
                    logging_uuid.set(current_uuid)
         
     | 
| 34 | 
         
            +
                    return func(*args, **kwargs)
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                return wrapper
         
     |