File size: 4,957 Bytes
3cc543c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Class to handle flagging in Gradio to Gantry.

Originally written by the FSDL educators at https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/app_gradio/flagging.py
and adjusted for the geolocator project.
"""

import os
from typing import List, Optional, Union

import gantry
import gradio as gr
from gradio.components import Component
from smart_open import open

from .s3_util import (
    add_access_policy,
    enable_bucket_versioning,
    get_or_create_bucket,
    get_uri_of,
    make_key,
)
from .string_img_util import read_b64_string


class GantryImageToTextLogger(gr.FlaggingCallback):
    """
    A FlaggingCallback that logs flagged image-to-text data to Gantry via S3.

    On each flag, the image is written to an S3 bucket and its URI is logged
    to Gantry together with two text outputs and the flag feedback.
    """

    # ``flag`` reads one image and two text values, so at least this many
    # components must be wired into the interface.
    _REQUIRED_COMPONENTS = 3

    def __init__(
        self,
        application: str,
        version: Union[int, str, None] = None,
        api_key: Optional[str] = None,
    ):
        """Logs image-to-text data that was flagged in Gradio to Gantry.

        Images are logged to Amazon Web Services' Simple Storage Service (S3).

        The flagging_dir provided to the Gradio interface is used to set the
        name of the bucket on S3 into which images are logged.

        See the following tutorial by Dan Bader for a quick overview of S3 and the AWS SDK
        for Python, boto3: https://realpython.com/python-boto3-aws-s3/

        See https://gradio.app/docs/#flagging for details on how
        flagging data is handled by Gradio.

        See https://docs.gantry.io for information about logging data to Gantry.

        Parameters
        ----------
        application
            The name of the application on Gantry to which flagged data should be uploaded.
            Gantry validates and monitors data per application.
        version
            The schema version to use during validation by Gantry. If not provided, Gantry
            will use the latest version. A new version will be created if the provided version
            does not exist yet.
        api_key
            Optionally, provide your Gantry API key here. Provided for convenience
            when testing and developing locally or in notebooks. The API key can
            alternatively be provided via the GANTRY_API_KEY environment variable.
        """
        self.application = application
        self.version = version

        # State filled in by ``setup``; declared here so the attributes always
        # exist and ``flag`` can detect a missing ``setup`` call.
        self._counter = 0
        self.bucket = None
        self.image_component_idx = None
        self.text_component_idx = None
        self.text_component2_idx = None

        gantry.init(api_key=api_key)

    def setup(self, components: List[Component], flagging_dir: str):
        """Sets up the GantryImageToTextLogger by creating or attaching to an S3 Bucket.

        Parameters
        ----------
        components
            The interface components whose values are later passed to ``flag``.
        flagging_dir
            Passed to ``get_or_create_bucket`` to obtain the bucket that receives images.

        Raises
        ------
        ValueError
            If fewer than three components are provided.
        """
        # Validate the component layout first so that a mis-wired interface
        # fails before any S3 resources are created or modified.
        (
            self.image_component_idx,
            self.text_component_idx,
            self.text_component2_idx,
        ) = self._find_image_video_and_text_components(components)

        self._counter = 0
        self.bucket = get_or_create_bucket(flagging_dir)
        enable_bucket_versioning(self.bucket)
        add_access_policy(self.bucket)

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
        """Sends flagged outputs and feedback to Gantry and image inputs to S3.

        Parameters
        ----------
        flag_data
            Sequence of component values; the image and the two texts are read
            from the positions determined in ``setup``.
        flag_option
            The flag chosen by the user; logged as the ``"flag"`` feedback entry.
        flag_index
            Accepted for interface compatibility; not used by this logger.
        username
            If not None, logged as the ``"user"`` feedback entry.

        Returns
        -------
        int
            The number of records flagged since the last ``setup`` call.

        Raises
        ------
        RuntimeError
            If called before ``setup``.
        """
        if self.bucket is None:
            raise RuntimeError(
                "GantryImageToTextLogger.setup() must be called before flag()."
            )

        image = flag_data[self.image_component_idx]
        text = flag_data[self.text_component_idx]
        text2 = flag_data[self.text_component2_idx]

        feedback = {"flag": flag_option}
        if username is not None:
            feedback["user"] = username

        # The image is expected as a base64 string; it is decoded into a
        # readable buffer and uploaded, and only its URI is sent to Gantry.
        data_type, image_buffer = read_b64_string(image, return_data_type=True)
        image_url = self._to_s3(image_buffer.read(), filetype=data_type)

        self._to_gantry(
            input_image_url=image_url,
            pred_location=text,
            pred_coordinates=text2,
            feedback=feedback,
        )
        self._counter += 1

        return self._counter

    def _to_gantry(self, input_image_url, pred_location, pred_coordinates, feedback):
        """Logs one record (image URI, both predictions, feedback) to Gantry."""
        inputs = {"image": input_image_url}
        outputs = {"location": pred_location, "coordinates": pred_coordinates}

        gantry.log_record(
            self.application,
            self.version,
            inputs=inputs,
            outputs=outputs,
            feedback=feedback,
        )

    def _to_s3(self, image_bytes, key=None, filetype=None):
        """Writes ``image_bytes`` to the bucket and returns the object's URI.

        When ``key`` is None, one is generated with ``make_key`` from the image
        bytes and ``filetype``.
        """
        if key is None:
            key = make_key(image_bytes, filetype=filetype)

        s3_uri = get_uri_of(self.bucket, key)

        # ``open`` is smart_open's ``open`` (imported at module level), which
        # is what allows writing directly to the URI.
        with open(s3_uri, "wb") as s3_object:
            s3_object.write(image_bytes)

        return s3_uri

    def _find_image_video_and_text_components(self, components: List[Component]):
        """
        Manual indexing of images and text components

        The layout is fixed: the image is component 0 and the two texts are
        components 1 and 2. The component types are not inspected; only the
        count is checked.

        Raises
        ------
        ValueError
            If fewer than three components are provided.
        """
        if len(components) < self._REQUIRED_COMPONENTS:
            raise ValueError(
                "GantryImageToTextLogger expects at least "
                f"{self._REQUIRED_COMPONENTS} components (image, text, text), "
                f"got {len(components)}."
            )

        image_component_idx = 0
        text_component_idx = 1
        text_component2_idx = 2

        return (
            image_component_idx,
            text_component_idx,
            text_component2_idx,
        )


def get_api_key() -> Optional[str]:
    """Return the value of the GANTRY_API_KEY environment variable, or None if unset."""
    return os.getenv("GANTRY_API_KEY")