Spaces:
Runtime error
Runtime error
fancyfeast
commited on
Commit
·
2a3d557
1
Parent(s):
f4d3067
Man the chatinterface is weird #3
Browse files
app.py
CHANGED
|
@@ -72,10 +72,10 @@ assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)
|
|
| 72 |
|
| 73 |
@spaces.GPU()
|
| 74 |
@torch.no_grad()
|
| 75 |
-
def chat_joycaption(message: dict, history, temperature: float, max_new_tokens: int) -> Generator[str, None, None]:
|
| 76 |
torch.cuda.empty_cache()
|
| 77 |
|
| 78 |
-
|
| 79 |
|
| 80 |
# Prompts are always stripped in training for now
|
| 81 |
prompt = message['text'].strip()
|
|
@@ -88,7 +88,8 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
|
|
| 88 |
image = Image.open(message["files"][0])
|
| 89 |
|
| 90 |
# Log the prompt
|
| 91 |
-
|
|
|
|
| 92 |
|
| 93 |
# Preprocess image
|
| 94 |
# NOTE: I found the default processor for so400M to have worse results than just using PIL directly
|
|
@@ -148,7 +149,7 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
|
|
| 148 |
use_cache=True,
|
| 149 |
temperature=temperature,
|
| 150 |
top_k=None,
|
| 151 |
-
top_p=
|
| 152 |
streamer=streamer,
|
| 153 |
)
|
| 154 |
|
|
@@ -170,14 +171,14 @@ textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single")
|
|
| 170 |
with gr.Blocks() as demo:
|
| 171 |
gr.HTML(TITLE)
|
| 172 |
gr.Markdown(DESCRIPTION)
|
| 173 |
-
gr.ChatInterface(
|
| 174 |
fn=chat_joycaption,
|
| 175 |
chatbot=chatbot,
|
| 176 |
type="messages",
|
| 177 |
fill_height=True,
|
| 178 |
multimodal=True,
|
| 179 |
textbox=textbox,
|
| 180 |
-
additional_inputs_accordion=
|
| 181 |
additional_inputs=[
|
| 182 |
gr.Slider(minimum=0,
|
| 183 |
maximum=1,
|
|
@@ -185,23 +186,27 @@ with gr.Blocks() as demo:
|
|
| 185 |
value=0.6,
|
| 186 |
label="Temperature",
|
| 187 |
render=False),
|
| 188 |
-
gr.Slider(minimum=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
maximum=4096,
|
| 190 |
step=1,
|
| 191 |
value=1024,
|
| 192 |
label="Max new tokens",
|
| 193 |
render=False ),
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
['How to setup a human base on Mars? Give short answer.'],
|
| 197 |
-
['Explain theory of relativity to me like I’m 8 years old.'],
|
| 198 |
-
['What is 9,000 * 9,000?'],
|
| 199 |
-
['Write a pun-filled happy birthday message to my friend Alex.'],
|
| 200 |
-
['Justify why a penguin might make a good king of the jungle.']
|
| 201 |
-
],
|
| 202 |
-
cache_examples=False,
|
| 203 |
)
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
if __name__ == "__main__":
|
| 207 |
demo.launch()
|
|
|
|
| 72 |
|
| 73 |
@spaces.GPU()
|
| 74 |
@torch.no_grad()
|
| 75 |
+
def chat_joycaption(message: dict, history, temperature: float, top_p: float, max_new_tokens: int, log_prompt: bool) -> Generator[str, None, None]:
|
| 76 |
torch.cuda.empty_cache()
|
| 77 |
|
| 78 |
+
chat_interface.chatbot_state
|
| 79 |
|
| 80 |
# Prompts are always stripped in training for now
|
| 81 |
prompt = message['text'].strip()
|
|
|
|
| 88 |
image = Image.open(message["files"][0])
|
| 89 |
|
| 90 |
# Log the prompt
|
| 91 |
+
if log_prompt:
|
| 92 |
+
print(f"Prompt: {prompt}")
|
| 93 |
|
| 94 |
# Preprocess image
|
| 95 |
# NOTE: I found the default processor for so400M to have worse results than just using PIL directly
|
|
|
|
| 149 |
use_cache=True,
|
| 150 |
temperature=temperature,
|
| 151 |
top_k=None,
|
| 152 |
+
top_p=top_p,
|
| 153 |
streamer=streamer,
|
| 154 |
)
|
| 155 |
|
|
|
|
| 171 |
with gr.Blocks() as demo:
|
| 172 |
gr.HTML(TITLE)
|
| 173 |
gr.Markdown(DESCRIPTION)
|
| 174 |
+
chat_interface = gr.ChatInterface(
|
| 175 |
fn=chat_joycaption,
|
| 176 |
chatbot=chatbot,
|
| 177 |
type="messages",
|
| 178 |
fill_height=True,
|
| 179 |
multimodal=True,
|
| 180 |
textbox=textbox,
|
| 181 |
+
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=True),
|
| 182 |
additional_inputs=[
|
| 183 |
gr.Slider(minimum=0,
|
| 184 |
maximum=1,
|
|
|
|
| 186 |
value=0.6,
|
| 187 |
label="Temperature",
|
| 188 |
render=False),
|
| 189 |
+
gr.Slider(minimum=0,
|
| 190 |
+
maximum=1,
|
| 191 |
+
step=0.05,
|
| 192 |
+
value=0.9,
|
| 193 |
+
label="Top p",
|
| 194 |
+
render=False),
|
| 195 |
+
gr.Slider(minimum=8,
|
| 196 |
maximum=4096,
|
| 197 |
step=1,
|
| 198 |
value=1024,
|
| 199 |
label="Max new tokens",
|
| 200 |
render=False ),
|
| 201 |
+
gr.Checkbox(label="Help improve JoyCaption by logging your text query", default=True, render=True),
|
| 202 |
+
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
)
|
| 204 |
|
| 205 |
+
def new_trim_history(self, message, history_with_input):
|
| 206 |
+
return message, []
|
| 207 |
+
|
| 208 |
+
chat_interface._process_msg_and_trim_history = new_trim_history.__get__(chat_interface, chat_interface.__class__)
|
| 209 |
+
|
| 210 |
|
| 211 |
if __name__ == "__main__":
|
| 212 |
demo.launch()
|