merve HF staff committed on
Commit
8400add
1 Parent(s): 53f5134

Create app.py

Files changed (1)
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import re
+ import string
+
+ import gradio as gr
+ import torch
+ from transformers import BitsAndBytesConfig, pipeline
+
+ DESCRIPTION = "# LLaVA 🌋"
+
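+ # Model setup (assumption): `pipe` is used below but never defined in this file,
+ # so a minimal image-to-text pipeline is sketched here. The checkpoint id and the
+ # 4-bit quantization settings are assumptions, chosen to match the
+ # "4-bit quantization" note shown in the UI below.
+ model_id = "llava-hf/llava-1.5-7b-hf"
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+ )
+ pipe = pipeline("image-to-text", model=model_id,
+                 model_kwargs={"quantization_config": quantization_config})
+
+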
+ def extract_response_pairs(text):
+     # Split the running transcript into (user, assistant) turn pairs.
+     pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?:$|USER:)', re.DOTALL)
+     matches = pattern.findall(text)
+
+     pairs = [(user.strip(), assistant.strip()) for user, assistant in matches]
+
+     return pairs
+
+
+ def postprocess_output(output: str) -> str:
+     # End the generation with punctuation if the model stopped mid-sentence.
+     if output and output[-1] not in string.punctuation:
+         output += "."
+     return output
+
+
+ def chat(image, text, temperature, length_penalty,
+          repetition_penalty, max_length, min_length, num_beams, top_p,
+          history_chat):
+     # Prompt the model with the latest user turn; the running transcript is
+     # kept in `history_chat` for display.
+     prompt = f"USER: <image>\n{text}\nASSISTANT:"
+
+     outputs = pipe(image, prompt=prompt,
+                    generate_kwargs={"temperature": temperature,
+                                     "length_penalty": length_penalty,
+                                     "repetition_penalty": repetition_penalty,
+                                     "max_length": max_length,
+                                     "min_length": min_length,
+                                     "num_beams": num_beams,
+                                     "top_p": top_p})
+
+     output = postprocess_output(outputs[0]["generated_text"])
+     history_chat.append(output)
+     print(f"history_chat is {history_chat}")
+
+     chat_val = extract_response_pairs(" ".join(history_chat))
+     return chat_val, history_chat
+
+
+ css = """
+ #mkd {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.Markdown("LLaVA is now available in transformers with 4-bit quantization!")
+     chatbot = gr.Chatbot(label="Chat", show_label=False)
+     with gr.Row():
+         image = gr.Image(type="pil")
+         text_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)
+
+     history_chat = gr.State(value=[])
+     with gr.Row():
+         clear_chat_button = gr.Button("Clear")
+         chat_button = gr.Button("Submit", variant="primary")
+     with gr.Accordion(label="Advanced settings", open=False):
+         temperature = gr.Slider(
+             label="Temperature",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=1.0,
+         )
+         length_penalty = gr.Slider(
+             label="Length Penalty",
+             info="Set to a larger value for longer sequences; used with beam search.",
+             minimum=-1.0,
+             maximum=2.0,
+             step=0.2,
+             value=1.0,
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition Penalty",
+             info="Larger values prevent repetition.",
+             minimum=1.0,
+             maximum=5.0,
+             step=0.5,
+             value=1.5,
+         )
+         max_length = gr.Slider(
+             label="Max Length",
+             minimum=1,
+             maximum=512,
+             step=1,
+             value=50,
+         )
+         min_length = gr.Slider(
+             label="Minimum Length",
+             minimum=1,
+             maximum=100,
+             step=1,
+             value=1,
+         )
+         num_beams = gr.Slider(
+             label="Number of Beams",
+             minimum=1,
+             maximum=10,
+             step=1,
+             value=5,
+         )
+         top_p = gr.Slider(
+             label="Top P",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=0.9,
+         )
+     chat_output = [
+         chatbot,
+         history_chat,
+     ]
+     chat_button.click(
+         fn=chat,
+         inputs=[
+             image,
+             text_input,
+             temperature,
+             length_penalty,
+             repetition_penalty,
+             max_length,
+             min_length,
+             num_beams,
+             top_p,
+             history_chat,
+         ],
+         outputs=chat_output,
+         api_name="Chat",
+     )
+
+     chat_inputs = [
+         image,
+         text_input,
+         temperature,
+         length_penalty,
+         repetition_penalty,
+         max_length,
+         min_length,
+         num_beams,
+         top_p,
+         history_chat,
+     ]
+     text_input.submit(
+         fn=chat,
+         inputs=chat_inputs,
+         outputs=chat_output,
+     ).success(
+         fn=lambda: "",
+         outputs=text_input,  # clear the text box after a successful submit
+         queue=False,
+         api_name=False,
+     )
+     clear_chat_button.click(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[
+             chatbot,
+             history_chat,
+         ],
+         queue=False,
+         api_name="clear",
+     )
+     image.change(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[
+             chatbot,
+             history_chat,
+         ],
+         queue=False,
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch()