luodian commited on
Commit
f49524d
1 Parent(s): e09648c
Files changed (2) hide show
  1. app.py +166 -0
  2. utils.py +82 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ import json
4
+ import base64
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import hashlib
8
+ import requests
9
+ from utils import build_logger
10
+ import io
11
+
12
+ LOGDIR = "log"
13
+ logger = build_logger("otter", LOGDIR)
14
+
15
+ no_change_btn = gr.Button.update()
16
+ enable_btn = gr.Button.update(interactive=True)
17
+ disable_btn = gr.Button.update(interactive=False)
18
+
19
+
20
+ def decode_image(encoded_image: str) -> Image:
21
+ decoded_bytes = base64.b64decode(encoded_image.encode("utf-8"))
22
+ buffer = io.BytesIO(decoded_bytes)
23
+ image = Image.open(buffer)
24
+ return image
25
+
26
+
27
+ def encode_image(image: Image.Image, format: str = "PNG") -> str:
28
+ with io.BytesIO() as buffer:
29
+ image.save(buffer, format=format)
30
+ encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
31
+ return encoded_image
32
+
33
+
34
+ def get_conv_log_filename():
35
+ t = datetime.datetime.now()
36
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
37
+ return name
38
+
39
+
40
+ def get_conv_image_dir():
41
+ name = os.path.join(LOGDIR, "images")
42
+ os.makedirs(name, exist_ok=True)
43
+ return name
44
+
45
+
46
+ def get_image_name(image, image_dir=None):
47
+ buffer = io.BytesIO()
48
+ image.save(buffer, format="PNG")
49
+ image_bytes = buffer.getvalue()
50
+ md5 = hashlib.md5(image_bytes).hexdigest()
51
+
52
+ if image_dir is not None:
53
+ image_name = os.path.join(image_dir, md5 + ".png")
54
+ else:
55
+ image_name = md5 + ".png"
56
+
57
+ return image_name
58
+
59
+
60
+ def resize_image(image, max_size):
61
+ width, height = image.size
62
+ aspect_ratio = float(width) / float(height)
63
+
64
+ if width > height:
65
+ new_width = max_size
66
+ new_height = int(new_width / aspect_ratio)
67
+ else:
68
+ new_height = max_size
69
+ new_width = int(new_height * aspect_ratio)
70
+
71
+ resized_image = image.resize((new_width, new_height))
72
+ return resized_image
73
+
74
+ def http_bot(image_input, text_input, request: gr.Request):
75
+ logger.info(f"http_bot. ip: {request.client.host}")
76
+ print(f"Prompt request: {text_input}")
77
+
78
+ base64_image_str = encode_image(image_input)
79
+
80
+ payload = {
81
+ "content": [
82
+ {
83
+ "prompt": text_input,
84
+ "image": base64_image_str,
85
+ }
86
+ ],
87
+ "token": "sk-OtterHD",
88
+ }
89
+
90
+ print(
91
+ "request: ",
92
+ {
93
+ "prompt": text_input,
94
+ "image": base64_image_str[:10],
95
+ },
96
+ )
97
+
98
+ url = "http://10.128.0.40:8890/app/otter"
99
+ headers = {"Content-Type": "application/json"}
100
+
101
+ response = requests.post(url, headers=headers, data=json.dumps(payload))
102
+ results = response.json()
103
+ print("response: ", {"result": results["result"]})
104
+ return results["result"]
105
+
106
+ title = """
107
+ # OTTER-HD: A High-Resolution Multi-modality Model
108
+ [[Otter Codebase]](https://github.com/Luodian/Otter) [[Paper]]() [[Checkpoints & Benchmarks]](https://huggingface.co/Otter-AI)
109
+
110
+ """
111
+
112
+ css = """
113
+ #mkd {
114
+ height: 1000px;
115
+ overflow: auto;
116
+ border: 1px solid #ccc;
117
+ }
118
+ """
119
+
120
+ if __name__ == "__main__":
121
+ with gr.Blocks(css=css) as demo:
122
+ gr.Markdown(title)
123
+ dialog_state = gr.State()
124
+ input_state = gr.State()
125
+ with gr.Tab("Ask a Question"):
126
+ with gr.Row(equal_height=True):
127
+ with gr.Column(scale=2):
128
+ image_input = gr.Image(label="Upload a High-Res Image", type="pil").style(height=600)
129
+ with gr.Column(scale=1):
130
+ vqa_output = gr.Textbox(label="Output").style(height=600)
131
+ text_input = gr.Textbox(label="Ask a Question")
132
+
133
+ vqa_btn = gr.Button("Send It")
134
+
135
+ gr.Examples(
136
+ [
137
+ [
138
+ "./assets/IMG_00095.png",
139
+ "How many camels are inside this image?",
140
+ ],
141
+ [
142
+ "./assets/IMG_00095.png",
143
+ "How many people are inside this image?",
144
+ ],
145
+ [
146
+ "./assets/IMG_00012.png",
147
+ "How many apples are there?",
148
+ ],
149
+ # ["./assets/./IMG_00012.png", "How many apples are there? Count them row by row."],
150
+ [
151
+ "./assets/IMG_00080.png",
152
+ "What is this and where is it from?",
153
+ ],
154
+ [
155
+ "./assets/IMG_00094.png",
156
+ "What's important on this website?",
157
+ ],
158
+ ],
159
+ inputs=[image_input, text_input],
160
+ outputs=[vqa_output],
161
+ fn=http_bot,
162
+ label="Click on any Examples below👇",
163
+ )
164
+ vqa_btn.click(fn=http_bot, inputs=[image_input, text_input], outputs=vqa_output)
165
+
166
+ demo.launch()
utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.handlers
3
+ import os
4
+ import sys
5
+
6
+ handler = None
7
+
8
+
9
+ def build_logger(logger_name, logger_dir):
10
+ global handler
11
+
12
+ formatter = logging.Formatter(
13
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
14
+ datefmt="%Y-%m-%d %H:%M:%S",
15
+ )
16
+
17
+ # Set the format of root handlers
18
+ if not logging.getLogger().handlers:
19
+ logging.basicConfig(level=logging.INFO)
20
+ logging.getLogger().handlers[0].setFormatter(formatter)
21
+
22
+ # Redirect stdout and stderr to loggers
23
+ stdout_logger = logging.getLogger("stdout")
24
+ stdout_logger.setLevel(logging.INFO)
25
+ sl = StreamToLogger(stdout_logger, logging.INFO)
26
+ sys.stdout = sl
27
+
28
+ stderr_logger = logging.getLogger("stderr")
29
+ stderr_logger.setLevel(logging.ERROR)
30
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
31
+ sys.stderr = sl
32
+
33
+ # Get logger
34
+ logger = logging.getLogger(logger_name)
35
+ logger.setLevel(logging.INFO)
36
+
37
+ # Add a file handler for all loggers
38
+ if handler is None:
39
+ os.makedirs(logger_dir, exist_ok=True)
40
+ filename = os.path.join(logger_dir, logger_name + ".log")
41
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
42
+ handler.setFormatter(formatter)
43
+
44
+ for name, item in logging.root.manager.loggerDict.items():
45
+ if isinstance(item, logging.Logger):
46
+ item.addHandler(handler)
47
+
48
+ return logger
49
+
50
+
51
+ class StreamToLogger(object):
52
+ """
53
+ Fake file-like stream object that redirects writes to a logger instance.
54
+ """
55
+
56
+ def __init__(self, logger, log_level=logging.INFO):
57
+ self.terminal = sys.stdout
58
+ self.logger = logger
59
+ self.log_level = log_level
60
+ self.linebuf = ""
61
+
62
+ def __getattr__(self, attr):
63
+ return getattr(self.terminal, attr)
64
+
65
+ def write(self, buf):
66
+ temp_linebuf = self.linebuf + buf
67
+ self.linebuf = ""
68
+ for line in temp_linebuf.splitlines(True):
69
+ # From the io.TextIOWrapper docs:
70
+ # On output, if newline is None, any '\n' characters written
71
+ # are translated to the system default line separator.
72
+ # By default sys.stdout.write() expects '\n' newlines and then
73
+ # translates them so this is still cross platform.
74
+ if line[-1] == "\n":
75
+ self.logger.log(self.log_level, line.rstrip())
76
+ else:
77
+ self.linebuf += line
78
+
79
+ def flush(self):
80
+ if self.linebuf != "":
81
+ self.logger.log(self.log_level, self.linebuf.rstrip())
82
+ self.linebuf = ""