File size: 4,382 Bytes
c30a8ce
 
 
 
 
7d22bdf
 
c30a8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d22bdf
 
 
 
 
 
 
 
 
c30a8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d223e5
c30a8ce
7d22bdf
 
 
c30a8ce
 
 
 
 
 
 
 
 
4d223e5
c30a8ce
 
 
 
7d22bdf
c30a8ce
 
 
7d22bdf
c30a8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
7d22bdf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import base64
import io
import json

import numpy
import sahi.predict
import sahi.utils
import streamlit.components.v1 as components
from PIL import Image


def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced SAHI inference on one image and visualize both.

    Parameters
    ----------
    image:
        Input image. It is passed unchanged to the sahi prediction functions
        and converted with ``numpy.array`` for drawing (presumably a PIL
        image -- confirm against callers).
    detection_model:
        Model object forwarded to both sahi prediction calls.
    slice_height, slice_width, overlap_height_ratio, overlap_width_ratio:
        Slicing parameters forwarded to ``sahi.predict.get_sliced_prediction``.
    image_size:
        Inference size forwarded to both prediction calls.
    postprocess_type, postprocess_match_metric, postprocess_match_threshold,
    postprocess_class_agnostic:
        Prediction-merging options forwarded to
        ``sahi.predict.get_sliced_prediction``.

    Returns
    -------
    tuple
        ``(standard_visual, sliced_visual)``: two images built with
        ``PIL.Image.fromarray`` showing the predictions of each pass.
    """

    def render(result):
        # Each visualization gets its own fresh array copy of the input image.
        drawn = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=result.object_prediction_list,
        )
        return Image.fromarray(drawn["image"])

    # Pass 1: whole-image inference.
    whole_result = sahi.predict.get_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
    )
    standard_visual = render(whole_result)

    # Pass 2: tiled inference, with per-slice predictions merged back together.
    slicing_options = {
        "image_size": image_size,
        "slice_height": slice_height,
        "slice_width": slice_width,
        "overlap_height_ratio": overlap_height_ratio,
        "overlap_width_ratio": overlap_width_ratio,
        "postprocess_type": postprocess_type,
        "postprocess_match_metric": postprocess_match_metric,
        "postprocess_match_threshold": postprocess_match_threshold,
        "postprocess_class_agnostic": postprocess_class_agnostic,
    }
    tiled_result = sahi.predict.get_sliced_prediction(
        image=image, detection_model=detection_model, **slicing_options
    )
    sliced_visual = render(tiled_result)

    return standard_visual, sliced_visual


def pillow_to_base64(image: Image.Image):
    """Encode a PIL image as a base64 PNG ``data:`` URI.

    The image is serialized to PNG in an in-memory buffer, base64-encoded,
    and prefixed with ``data:image/png;base64,``. The result can be used
    directly as an HTML/JS image ``src``.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded


def imagecompare(
    img1: str,
    img2: str,
    label1: str = "1",
    label2: str = "2",
    width: int = 700,
    show_labels: bool = True,
    starting_position: int = 50,
    make_responsive: bool = True,
):
    """Create a juxtapose image-comparison slider component.

    Parameters
    ----------
    img1: str, PosixPath, PIL.Image or URL
        First image to compare. It is normalized with
        ``sahi.utils.cv.read_image_as_pil``.
    img2: str, PosixPath, PIL.Image or URL
        Second image to compare. It is normalized the same way.
    label1: str
        Label for image 1. It is converted with ``str()``.
    label2: str
        Label for image 2. It is converted with ``str()``.
    width: int
        Width of the component in px. The component height is derived from
        this width and img1's aspect ratio, scaled by 0.95.
    show_labels: bool
        Show the given labels on the images.
    starting_position: int
        Starting position of the slider as a percent (0-100).
    make_responsive: bool
        Enable juxtapose's responsive mode.

    Returns
    -------
    static_component
        Whatever ``streamlit.components.v1.html`` returns for the rendered
        HTML block.
    """
    # Normalize both inputs to PIL images *before* reading any size, so that
    # paths and URLs work as documented (not only PIL images).
    pil1 = sahi.utils.cv.read_image_as_pil(img1)
    pil2 = sahi.utils.cv.read_image_as_pil(img2)

    # Derive the component height from img1's aspect ratio. Use whole pixels
    # for both the CSS and the iframe.
    img_width, img_height = pil1.size
    height = int(width * (img_height / img_width) * 0.95)
    css_width = f"{width}px" if width else "100%"

    src1 = pillow_to_base64(pil1)
    src2 = pillow_to_base64(pil2)

    def js_string(value):
        # A JSON string literal is a valid JS string literal. Escaping "</"
        # stops a label from closing the surrounding <script> element early.
        return json.dumps(str(value)).replace("</", "<\\/")

    # load css + js
    cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
    css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
    js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'

    # write html block
    htmlcode = f"""
        {css_block}
        {js_block}
        <div id="foo" style="height: {height}px; width: {css_width};"></div>
        <script>
        slider = new juxtapose.JXSlider('#foo',
            [
                {{
                    src: {js_string(src1)},
                    label: {js_string(label1)},
                }},
                {{
                    src: {js_string(src2)},
                    label: {js_string(label2)},
                }}
            ],
            {{
                animate: true,
                showLabels: {'true' if show_labels else 'false'},
                showCredits: true,
                startingPosition: "{starting_position}%",
                makeResponsive: {'true' if make_responsive else 'false'},
            }});
        </script>
        """
    static_component = components.html(htmlcode, height=height, width=width)

    return static_component