import base64
import io
import os
import uuid

import numpy
import sahi.predict
import sahi.utils.cv
import streamlit.components.v1 as components
from PIL import Image

TEMP_DIR = "temp"


def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard (full-image) and sliced inference on the same image.

    Returns two PIL images: the visualization of the standard-inference
    result and the visualization of the sliced-inference result.
    """
    # standard (full-image) inference
    prediction_result_1 = sahi.predict.get_prediction(
        image=image, detection_model=detection_model, image_size=image_size
    )
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])

    # sliced inference
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )

    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
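

# A minimal usage sketch (not called anywhere in this module): shows how a
# detection model might be built and handed to `sahi_mmdet_inference`. It
# assumes a SAHI release that still exposes `sahi.model.MmdetDetectionModel`
# and an MMDetection config/checkpoint pair; every file path below is a
# placeholder, not part of the original code.
def _example_sahi_mmdet_inference():
    # imported lazily so the module does not depend on this older API path
    import sahi.model

    detection_model = sahi.model.MmdetDetectionModel(
        model_path="checkpoints/model.pth",  # placeholder checkpoint path
        config_path="configs/model_config.py",  # placeholder mmdet config
        confidence_threshold=0.4,
        device="cpu",  # or "cuda:0"
    )
    image = Image.open("demo.jpg")  # placeholder input image
    output_standard, output_sliced = sahi_mmdet_inference(image, detection_model)
    return output_standard, output_sliced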


def pillow_to_base64(image: Image.Image) -> str:
    # encode a PIL image as a base64 JPEG data URI, entirely in memory
    in_mem_file = io.BytesIO()
    image.save(in_mem_file, format="JPEG", subsampling=0, quality=100)
    img_bytes = in_mem_file.getvalue()
    image_str = base64.b64encode(img_bytes).decode("utf-8")
    base64_src = f"data:image/jpeg;base64,{image_str}"
    return base64_src


def local_file_to_base64(image_path: str) -> str:
    # read an image file from disk and encode it as a base64 JPEG data URI
    with open(image_path, "rb") as file_:
        img_bytes = file_.read()
    image_str = base64.b64encode(img_bytes).decode("utf-8")
    base64_src = f"data:image/jpeg;base64,{image_str}"
    return base64_src


def pillow_local_file_to_base64(image: Image.Image) -> str:
    # save the PIL image to a temporary JPEG file, then base64-encode that file
    os.makedirs(TEMP_DIR, exist_ok=True)
    img_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.jpg")
    image.save(img_path, subsampling=0, quality=100)
    base64_src = local_file_to_base64(img_path)
    return base64_src


def image_comparison(
    img1: str,
    img2: str,
    label1: str = "1",
    label2: str = "2",
    width: int = 700,
    show_labels: bool = True,
    starting_position: int = 50,
    make_responsive: bool = True,
    in_memory=False,
):
    """Create a new juxtapose component.
    Parameters
    ----------
    img1: str, PosixPath, PIL.Image or URL
        Input image to compare
    img2: str, PosixPath, PIL.Image or URL
        Input image to compare
    label1: str or None
        Label for image 1
    label2: str or None
        Label for image 2
    width: int or None
        Width of the component in px
    show_labels: bool or None
        Show given labels on images
    starting_position: int or None
        Starting position of the slider as percent (0-100)
    make_responsive: bool or None
        Enable responsive mode
    in_memory: bool or None
        Handle pillow to base64 conversion in memory without saving to local disk
    Returns
    -------
    static_component
        A static Streamlit component rendering a juxtapose image-comparison slider
    """
    # prepare images: read inputs as PIL first so that str/Path/URL inputs
    # work, then derive the component height from the image aspect ratio
    img1_pillow = sahi.utils.cv.read_image_as_pil(img1)
    img2_pillow = sahi.utils.cv.read_image_as_pil(img2)

    img_width, img_height = img1_pillow.size
    h_to_w = img_height / img_width
    height = int((width * h_to_w) * 0.95)

    if in_memory:
        # create base64 str from pillow images
        img1 = pillow_to_base64(img1_pillow)
        img2 = pillow_to_base64(img2_pillow)
    else:
        # clean temp dir
        os.makedirs(TEMP_DIR, exist_ok=True)
        for file_ in os.listdir(TEMP_DIR):
            if file_.endswith(".jpg"):
                os.remove(os.path.join(TEMP_DIR, file_))
        # create base64 str from pillow images
        img1 = pillow_local_file_to_base64(img1_pillow)
        img2 = pillow_local_file_to_base64(img2_pillow)

    # load css + js
    cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
    css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
    js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'

    # write html block
    htmlcode = f"""
        {css_block}
        {js_block}
        <div id="foo" style="height: {height}px; width: {str(width) + 'px' if width else '100%'};"></div>
        <script>
        slider = new juxtapose.JXSlider('#foo',
            [
                {{
                    src: '{img1}',
                    label: '{label1}',
                }},
                {{
                    src: '{img2}',
                    label: '{label2}',
                }}
            ],
            {{
                animate: true,
                showLabels: {'true' if show_labels else 'false'},
                showCredits: true,
                startingPosition: "{starting_position}%",
                makeResponsive: {'true' if make_responsive else 'false'},
            }});
        </script>
        """
    static_component = components.html(htmlcode, height=height, width=width)

    return static_component
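

# Minimal demo sketch: renders the comparison slider when this file is run
# directly with Streamlit (e.g. `streamlit run <this_file>.py`). The two image
# paths and labels below are placeholders, not part of the original module;
# any pair of local images or PIL images should work.
if __name__ == "__main__":
    import streamlit as st

    st.title("Image comparison demo")
    image_comparison(
        img1="before.jpg",  # placeholder path
        img2="after.jpg",  # placeholder path
        label1="Standard",
        label2="Sliced",
        width=700,
        in_memory=True,
    )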