File size: 4,680 Bytes
c30a8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864cb02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
import os
import pathlib
import uuid

import numpy
import sahi.predict
import sahi.utils
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image

# Directory named "static" inside the installed streamlit package; the image
# comparison component below writes its temporary PNGs here.
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0], "static")


def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced SAHI inference on one image and draw both results.

    Parameters
    ----------
    image:
        Input image. It is forwarded unchanged to both SAHI prediction calls
        and converted with ``numpy.array`` for drawing (presumably a PIL
        image -- confirm against callers).
    detection_model:
        SAHI detection model used for both inference passes.
    slice_height, slice_width, overlap_height_ratio, overlap_width_ratio:
        Slicing options forwarded to ``sahi.predict.get_sliced_prediction``.
    image_size:
        Forwarded to both ``get_prediction`` and ``get_sliced_prediction``.
    postprocess_type, postprocess_match_metric, postprocess_match_threshold,
    postprocess_class_agnostic:
        Post-processing options forwarded to ``get_sliced_prediction``.

    Returns
    -------
    tuple of PIL.Image.Image
        ``(standard_visual, sliced_visual)``: the input with the standard
        predictions drawn on it, then the input with the sliced predictions
        drawn on it.
    """

    def _draw(prediction_result):
        # Build a fresh array from the input on every call so each rendering
        # starts from the original image.
        rendered = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=prediction_result.object_prediction_list,
        )
        return Image.fromarray(rendered["image"])

    # Pass 1: single whole-image inference.
    standard_visual = _draw(
        sahi.predict.get_prediction(
            image=image, detection_model=detection_model, image_size=image_size
        )
    )

    # Pass 2: sliced inference followed by merging of per-slice predictions.
    sliced_result = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    sliced_visual = _draw(sliced_result)

    return standard_visual, sliced_visual


def _clear_generated_images():
    """Delete PNGs that ``_save_to_static`` wrote into the Streamlit static dir."""
    # Snapshot the listing first so we never delete while iterating the dir.
    for entry in list(STREAMLIT_STATIC_PATH.iterdir()):
        if entry.suffix != ".png":
            continue
        # Only remove files named like ours ("<uuid>.png"), so other PNGs that
        # live in Streamlit's package directory (e.g. its favicon) are kept.
        try:
            uuid.UUID(entry.stem)
        except ValueError:
            continue
        try:
            entry.unlink()
        except FileNotFoundError:
            # Already removed (e.g. by a concurrent call); nothing to do.
            pass


def _save_to_static(image):
    """Save *image* as a uniquely named PNG in the Streamlit static dir.

    Returns ``(file_name, absolute_path, pil_image)``.
    """
    pil_image = sahi.utils.cv.read_image_as_pil(image)
    file_name = f"{uuid.uuid4()}.png"
    file_path = str((STREAMLIT_STATIC_PATH / file_name).resolve())
    pil_image.save(file_path)
    return file_name, file_path, pil_image


def _js_string(value):
    """Return ``str(value)`` as a JavaScript string literal safe inside <script>."""
    # json.dumps escapes quotes, backslashes and control characters; also
    # break up "</" so the value cannot terminate the enclosing <script>.
    return json.dumps(str(value)).replace("</", "<\\/")


def imagecompare(
    img1: str,
    img2: str,
    label1: str = "1",
    label2: str = "2",
    width: int = 700,
    show_labels: bool = True,
    starting_position: int = 50,
    make_responsive: bool = True,
):
    """Render a juxtapose.js before/after slider comparing two images.

    Both images are saved as PNGs in Streamlit's static directory and
    referenced from a juxtapose slider embedded with ``components.html``.
    PNGs written by earlier calls are deleted first.

    Parameters
    ----------
    img1: str, PosixPath, PIL.Image or URL
        First image to compare; anything ``sahi.utils.cv.read_image_as_pil``
        accepts. Its aspect ratio determines the component height.
    img2: str, PosixPath, PIL.Image or URL
        Second image to compare.
    label1: str
        Label for image 1 (converted with ``str``).
    label2: str
        Label for image 2 (converted with ``str``).
    width: int
        Width of the component in px.
    show_labels: bool
        Show the given labels on the images.
    starting_position: int
        Starting position of the slider as percent (0-100).
    make_responsive: bool
        Enable juxtapose's responsive mode.

    Returns
    -------
    tuple
        ``(static_component, image_1_path, image_2_path)``: the value returned
        by ``components.html`` and the absolute paths of the two saved PNGs.
    """
    # prepare images
    # NOTE(review): cleanup also removes images another session may still be
    # displaying -- revisit if the app serves concurrent users.
    _clear_generated_images()
    image_1_name, image_1_path, image_1_pil = _save_to_static(img1)
    image_2_name, image_2_path, _ = _save_to_static(img2)

    # Use the decoded image's size: img1 itself may be a path or URL.
    img_width, img_height = image_1_pil.size
    h_to_w = img_height / img_width
    # Streamlit documents the iframe height as an int; the 20 px reduction is
    # kept from the original (its rationale is not evident from this file).
    height = int(width * h_to_w - 20)

    # load css + js
    cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
    css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
    js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'

    # write html block
    # The iframe created by components.html already has the requested width,
    # so the container simply fills it. All values interpolated into the
    # script are JSON-encoded so quotes or "</script>" cannot break the page.
    # NOTE(review): the bare file names rely on the iframe resolving relative
    # URLs against Streamlit's static root -- confirm for the deployed version.
    htmlcode = f"""
        {css_block}
        {js_block}
        <div id="foo" style="height: 100%; width: 100%;"></div>
        <script>
        slider = new juxtapose.JXSlider('#foo',
            [
                {{
                    src: {_js_string(image_1_name)},
                    label: {_js_string(label1)},
                }},
                {{
                    src: {_js_string(image_2_name)},
                    label: {_js_string(label2)},
                }}
            ],
            {{
                animate: true,
                showLabels: {json.dumps(bool(show_labels))},
                showCredits: true,
                startingPosition: {_js_string(f"{starting_position}%")},
                makeResponsive: {json.dumps(bool(make_responsive))},
            }});
        </script>
        """
    static_component = components.html(htmlcode, height=height, width=width)

    return static_component, image_1_path, image_2_path