initial commit
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -20,7 +20,11 @@ def post_process(img: np.array) -> np.array: 
     | 
|
| 20 | 
         
             
                return img
         
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
            -
            def inference(img_array: np.array) -> np.array:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 24 | 
         
             
                ort_inputs = {ort_session.get_inputs()[0].name: img_array}
         
     | 
| 25 | 
         
             
                ort_outs = ort_session.run(None, ort_inputs)
         
     | 
| 26 | 
         | 
| 
         @@ -28,6 +32,7 @@ def inference(img_array: np.array) -> np.array: 
     | 
|
| 28 | 
         | 
| 29 | 
         | 
| 30 | 
         
             
            def convert_pil_to_cv2(image):
         
     | 
| 
         | 
|
| 31 | 
         
             
                open_cv_image = np.array(image)
         
     | 
| 32 | 
         
             
                # RGB to BGR
         
     | 
| 33 | 
         
             
                open_cv_image = open_cv_image[:, :, ::-1].copy()
         
     | 
| 
         @@ -35,6 +40,7 @@ def convert_pil_to_cv2(image): 
     | 
|
| 35 | 
         | 
| 36 | 
         | 
| 37 | 
         
             
            def upscale(image):
         
     | 
| 
         | 
|
| 38 | 
         
             
                img = convert_pil_to_cv2(image)
         
     | 
| 39 | 
         
             
                if img.ndim == 2:
         
     | 
| 40 | 
         
             
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         
     | 
| 
         @@ -42,26 +48,20 @@ def upscale(image): 
     | 
|
| 42 | 
         
             
                if img.shape[2] == 4:
         
     | 
| 43 | 
         
             
                    alpha = img[:, :, 3]  # GRAY
         
     | 
| 44 | 
         
             
                    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
         
     | 
| 45 | 
         
            -
                    alpha_output = post_process(inference(pre_process(alpha)))  # BGR
         
     | 
| 46 | 
         
             
                    alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY
         
     | 
| 47 | 
         | 
| 48 | 
         
             
                    img = img[:, :, 0:3]  # BGR
         
     | 
| 49 | 
         
            -
                    image_output = post_process(inference(pre_process(img)))  # BGR
         
     | 
| 50 | 
         
             
                    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
         
     | 
| 51 | 
         
             
                    image_output[:, :, 3] = alpha_output
         
     | 
| 52 | 
         | 
| 53 | 
         
             
                elif img.shape[2] == 3:
         
     | 
| 54 | 
         
            -
                    image_output = post_process(inference(pre_process(img)))  # BGR
         
     | 
| 55 | 
         | 
| 56 | 
         
             
                return image_output
         
     | 
| 57 | 
         | 
| 58 | 
         | 
| 59 | 
         
            -
            model_path = "models/model.ort"
         
     | 
| 60 | 
         
            -
            options = onnxruntime.SessionOptions()
         
     | 
| 61 | 
         
            -
            options.intra_op_num_threads = 1
         
     | 
| 62 | 
         
            -
            options.inter_op_num_threads = 1
         
     | 
| 63 | 
         
            -
            ort_session = onnxruntime.InferenceSession(model_path, options)
         
     | 
| 64 | 
         
            -
             
     | 
| 65 | 
         
             
            examples = [f"examples/example_{i+1}.png" for i in range(5)]
         
     | 
| 66 | 
         
             
            css = ".output-image, .input-image, .image-preview {height: 480px !important} "
         
     | 
| 67 | 
         | 
| 
         | 
|
| 20 | 
         
             
                return img
         
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
            +
            def inference(model_path: str, img_array: np.array) -> np.array:
         
     | 
| 24 | 
         
            +
                options = onnxruntime.SessionOptions()
         
     | 
| 25 | 
         
            +
                options.intra_op_num_threads = 1
         
     | 
| 26 | 
         
            +
                options.inter_op_num_threads = 1
         
     | 
| 27 | 
         
            +
                ort_session = onnxruntime.InferenceSession(model_path, options)
         
     | 
| 28 | 
         
             
                ort_inputs = {ort_session.get_inputs()[0].name: img_array}
         
     | 
| 29 | 
         
             
                ort_outs = ort_session.run(None, ort_inputs)
         
     | 
| 30 | 
         | 
| 
         | 
|
| 32 | 
         | 
| 33 | 
         | 
| 34 | 
         
             
            def convert_pil_to_cv2(image):
         
     | 
| 35 | 
         
            +
                # pil_image = image.convert("RGB")
         
     | 
| 36 | 
         
             
                open_cv_image = np.array(image)
         
     | 
| 37 | 
         
             
                # RGB to BGR
         
     | 
| 38 | 
         
             
                open_cv_image = open_cv_image[:, :, ::-1].copy()
         
     | 
| 
         | 
|
| 40 | 
         | 
| 41 | 
         | 
| 42 | 
         
             
            def upscale(image):
         
     | 
| 43 | 
         
            +
                model_path = "models/model.ort"
         
     | 
| 44 | 
         
             
                img = convert_pil_to_cv2(image)
         
     | 
| 45 | 
         
             
                if img.ndim == 2:
         
     | 
| 46 | 
         
             
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         
     | 
| 
         | 
|
| 48 | 
         
             
                if img.shape[2] == 4:
         
     | 
| 49 | 
         
             
                    alpha = img[:, :, 3]  # GRAY
         
     | 
| 50 | 
         
             
                    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
         
     | 
| 51 | 
         
            +
                    alpha_output = post_process(inference(model_path, pre_process(alpha)))  # BGR
         
     | 
| 52 | 
         
             
                    alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY
         
     | 
| 53 | 
         | 
| 54 | 
         
             
                    img = img[:, :, 0:3]  # BGR
         
     | 
| 55 | 
         
            +
                    image_output = post_process(inference(model_path, pre_process(img)))  # BGR
         
     | 
| 56 | 
         
             
                    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
         
     | 
| 57 | 
         
             
                    image_output[:, :, 3] = alpha_output
         
     | 
| 58 | 
         | 
| 59 | 
         
             
                elif img.shape[2] == 3:
         
     | 
| 60 | 
         
            +
                    image_output = post_process(inference(model_path, pre_process(img)))  # BGR
         
     | 
| 61 | 
         | 
| 62 | 
         
             
                return image_output
         
     | 
| 63 | 
         | 
| 64 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 65 | 
         
             
            examples = [f"examples/example_{i+1}.png" for i in range(5)]
         
     | 
| 66 | 
         
             
            css = ".output-image, .input-image, .image-preview {height: 480px !important} "
         
     | 
| 67 | 
         |