hugo1234 commited on
Commit
e64ca11
·
1 Parent(s): a43103a

Create app_old.py

Browse files
Files changed (1) hide show
  1. app_old.py +674 -0
app_old.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Hugging Face's logo Hugging Face
3
+
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Docs
8
+ Pricing
9
+
10
+ Hugging Face is way more fun with friends and colleagues! 🤗 Join an organization
11
+ Spaces:
12
+ hugo1234
13
+ /
14
+ galileo
15
+ private
16
+ App
17
+ Files
18
+ Community
19
+ Settings
20
+ galileo
21
+ / app.py
22
+ hugo1234's picture
23
+ hugo1234
24
+ Update app.py
25
+ a43103a
26
+ about 16 hours ago
27
+ raw
28
+ history
29
+ blame
30
+ 22.8 kB
31
+ import os
32
+ os.system('pip install bitsandbytes')
33
+ os.system('pip install -q datasets loralib sentencepiece accelerate')
34
+ # os.system('pip install -q git+https://github.com/zphang/transformers@c3dc391')
35
+ # os.system('pip install -q git+https://github.com/huggingface/transformers')
36
+ os.system('pip install -q git+https://github.com/mbehm/transformers')
37
+ os.system('pip install -q git+https://github.com/huggingface/peft.git')
38
+ # os.system('pip install gradio')
39
+ # os.system('pip install torch')
40
+ # os.system('pip install peft')
41
+ # os.system('pip install transformers')
42
+ os.system('pip install tenacity')
43
+ os.system('pip install scipy')
44
+ # os.system('pip install sentencepiece')
45
+
46
+ import re
47
+ import yaml
48
+ import gc
49
+ import copy
50
+ import time
51
+ from tenacity import RetryError
52
+ from tenacity import retry, stop_after_attempt, wait_fixed
53
+ import gradio as gr
54
+ # import torch
55
+ from peft import PeftModel
56
+ from transformers import (
57
+ LLaMATokenizer,
58
+ LlamaForCausalLM,
59
+ GenerationConfig,
60
+ AutoModelForCausalLM,
61
+ AutoModelForSeq2SeqLM,
62
+ AutoTokenizer,
63
+ LogitsProcessorList,
64
+ MinNewTokensLengthLogitsProcessor,
65
+ TemperatureLogitsWarper,
66
+ TopPLogitsWarper,
67
+ MinLengthLogitsProcessor
68
+ )
69
+
70
+ # assert torch.cuda.is_available(), "Change the runtime type to GPU"
71
+
72
+ # constants
73
+ num_of_characters_to_keep = 1000
74
+
75
+ # regex
76
+ html_tag_pattern = re.compile(r"<.*?>")
77
+ multi_line_pattern = re.compile(r"\n+")
78
+ multi_space_pattern = re.compile(r"( )")
79
+ multi_br_tag_pattern = re.compile(re.compile(r'<br>\s*(<br>\s*)*'))
80
+
81
+ # repl is short for replacement
82
+ repl_linebreak = "\n"
83
+ repl_empty_str = ""
84
+
85
+ TITLE = "Galileo"
86
+
87
+ ABSTRACT = """
88
+ Stambecco is a Italian Instruction-following model based on the [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) model. It comes in two versions: 7b and 13b parameters. It is trained on an Italian version of the [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset, a dataset of `GPT-4` generated instruction-following data.
89
+ This demo is intended to show and evaluate the conversational capabilities of the model.
90
+ For more information, please visit [the project's website](https://github.com/mchl-labs/stambecco).
91
+ NOTE: Too long input (context, instruction) will not be allowed. Please keep context < 500 and instruction < 150
92
+ """
93
+
94
+ BOTTOM_LINE = """
95
+ By default, this demo runs with streaming mode, but you can also run with dynamic batch generation model.
96
+ Stambecco is built on the same concept as Standford Alpaca project, but using LoRA it lets us train and inference on a smaller GPUs such as RTX4090 for 7B version. Also, we could build very small size of checkpoints on top of base models thanks to [🤗 transformers](https://huggingface.co/docs/transformers/index), [🤗 peft](https://github.com/huggingface/peft), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes/tree/main) libraries.
97
+ This demo currently runs 8Bit 7b version of the model.
98
+ """
99
+
100
+ DEFAULT_EXAMPLES = {
101
+ "Typical Questions": [
102
+ {
103
+ "title": "Parlami di Giulio Cesare.",
104
+ "examples": [
105
+ ["1", "Scrivi un articolo su Giulio Cesare"],
106
+ ["2", "Davvero?"],
107
+ ["3", "Quanto era ricco Giulio Cesare?"],
108
+ ["4", "Chi è stato il suo successore?"],
109
+ ]
110
+ },
111
+ {
112
+ "title": "Parigi",
113
+ "examples": [
114
+ ["1", "Scrivi un tema sulla città di Parigi"],
115
+ ["2", "Fai un elenco di 5 posti da visitare assolutamente"],
116
+ ["3", "Quali eventi importanti della Storia sono avvenuti a Parigi?"],
117
+ ["4", "Quale è il periodo migliore per visitare Parigi?"],
118
+ ]
119
+ },
120
+ {
121
+ "title": "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci",
122
+ "examples": [
123
+ ["1", "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci"],
124
+ ["2", "Potresti spiegarmi come funziona il codice?"],
125
+ ["3", "Cos'è la ricorsione?"],
126
+ ]
127
+ }
128
+ ],
129
+ }
130
+
131
+ SPECIAL_STRS = {
132
+ "continue": "continua",
133
+ "summarize": "Di cosa abbiamo discusso finora? Descrivi nella user's view."
134
+ }
135
+
136
+ PARENT_BLOCK_CSS = """
137
+ #col_container {
138
+ width: 95%;
139
+ margin-left: auto;
140
+ margin-right: auto;
141
+ }
142
+ #chatbot {
143
+ height: 500px;
144
+ overflow: auto;
145
+ }
146
+ """
147
+
148
+ def load_model(
149
+ base="decapoda-research/llama-7b-hf",
150
+ finetuned="mchl-labs/stambecco-7b-plus",
151
+ ):
152
+ tokenizer = LLaMATokenizer.from_pretrained(base)
153
+ tokenizer.pad_token_id = 0
154
+ tokenizer.padding_side = "left"
155
+
156
+ model = LlamaForCausalLM.from_pretrained(
157
+ base,
158
+ load_in_8bit=True,
159
+ device_map="from_pretrained",
160
+ # load_in_8bit_fp32_cpu_offload=True
161
+ )
162
+ # model = PeftModel.from_pretrained(model, finetuned, device_map={'': 0})
163
+
164
+ model = PeftModel.from_pretrained(model, finetuned)
165
+ return model, tokenizer
166
+
167
+ def get_generation_config(path):
168
+ with open(path, 'rb') as f:
169
+ generation_config = yaml.safe_load(f.read())
170
+
171
+ return GenerationConfig(**generation_config["generation_config"])
172
+
173
+ def generate_prompt(prompt, histories, ctx=None, partial=False):
174
+ convs = f"""Di seguito è riportata una cronologia delle istruzioni che descrivono le tasks, abbinate a un input che fornisce ulteriore contesto. Scrivi una risposta che completi adeguatamente la richiesta ricordando la cronologia della conversazione.
175
+ """
176
+
177
+ if ctx is not None:
178
+ convs = f"""### Input: {ctx}
179
+ """
180
+
181
+ sub_convs = ""
182
+ start_idx = 0
183
+
184
+ for idx, history in enumerate(histories):
185
+ history_prompt = history[0]
186
+ history_response = history[1]
187
+ if history_response == "✅ Riepilogo della conversazione effettuato e impostato come contesto" or history_prompt == SPECIAL_STRS["summarize"]:
188
+ start_idx = idx
189
+
190
+ # drop the previous conversations if user has summarized
191
+ for history in histories[start_idx if start_idx == 0 else start_idx+1:]:
192
+ history_prompt = history[0]
193
+ history_response = history[1]
194
+
195
+ history_response = history_response.replace("<br>", "\n")
196
+ history_response = re.sub(
197
+ html_tag_pattern, repl_empty_str, history_response
198
+ )
199
+
200
+ sub_convs = sub_convs + f"""### Istruzione: {history_prompt}
201
+ ### Risposta: {history_response}
202
+ """
203
+
204
+ sub_convs = sub_convs + f"""### Istruzione: {prompt}
205
+ ### Risposta:"""
206
+
207
+ convs = convs + sub_convs
208
+ return sub_convs if partial else convs, len(sub_convs)
209
+
210
+ def common_post_process(original_str):
211
+ original_str = re.sub(
212
+ multi_line_pattern, repl_linebreak, original_str
213
+ )
214
+ return original_str
215
+
216
+ def post_process_stream(bot_response):
217
+ # sometimes model spits out text containing
218
+ # "### Risposta:" and "### Istruzione: -> in this case, we want to stop generating
219
+ if "### Risposta:" in bot_response or "### Input:" in bot_response:
220
+ bot_response = bot_response.replace("### Risposta:", '').replace("### Input:", '').strip()
221
+ return bot_response, True
222
+
223
+ return common_post_process(bot_response), False
224
+
225
+ def post_process_batch(bot_response):
226
+ bot_response = bot_response.split("### Risposta:")[-1].strip()
227
+ return common_post_process(bot_response)
228
+
229
+ def post_processes_batch(bot_responses):
230
+ return [post_process_batch(r) for r in bot_responses]
231
+
232
+ def get_output_batch(
233
+ model, tokenizer, prompts, generation_config
234
+ ):
235
+ if len(prompts) == 1:
236
+ encoding = tokenizer(prompts, return_tensors="pt")
237
+ input_ids = encoding["input_ids"].cuda()
238
+ generated_id = model.generate(
239
+ input_ids=input_ids,
240
+ generation_config=generation_config,
241
+ max_new_tokens=256
242
+ )
243
+
244
+ decoded = tokenizer.batch_decode(generated_id)
245
+ del input_ids, generated_id
246
+ torch.cuda.empty_cache()
247
+ return decoded
248
+ else:
249
+ encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
250
+ generated_ids = model.generate(
251
+ **encodings,
252
+ generation_config=generation_config,
253
+ max_new_tokens=256
254
+ )
255
+
256
+ decoded = tokenizer.batch_decode(generated_ids)
257
+ del encodings, generated_ids
258
+ torch.cuda.empty_cache()
259
+ return decoded
260
+
261
+
262
+ # StreamModel is borrowed from basaran project
263
+ # please find more info about it -> https://github.com/hyperonym/basaran
264
+ class StreamModel:
265
+ """StreamModel wraps around a language model to provide stream decoding."""
266
+
267
+ def __init__(self, model, tokenizer):
268
+ super().__init__()
269
+ self.model = model
270
+ self.tokenizer = tokenizer
271
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
272
+
273
+ self.processor = LogitsProcessorList()
274
+ self.processor.append(TemperatureLogitsWarper(0.9))
275
+ self.processor.append(TopPLogitsWarper(0.75))
276
+
277
+
278
+ def __call__(
279
+ self,
280
+ prompt,
281
+ min_tokens=0,
282
+ max_tokens=16,
283
+ temperature=1.0,
284
+ top_p=1.0,
285
+ n=1,
286
+ logprobs=0,
287
+ ):
288
+ """Create a completion stream for the provided prompt."""
289
+ input_ids = self.tokenize(prompt)
290
+ logprobs = max(logprobs, 0)
291
+
292
+ # bigger than 1
293
+ chunk_size = 2
294
+ chunk_count = 0
295
+
296
+ # Generate completion tokens.
297
+ final_tokens = torch.empty(0)
298
+
299
+ for tokens in self.generate(
300
+ input_ids[None, :].repeat(n, 1),
301
+ logprobs=logprobs,
302
+ min_new_tokens=min_tokens,
303
+ max_new_tokens=max_tokens,
304
+ temperature=temperature,
305
+ top_p=top_p,
306
+ ):
307
+ if chunk_count < chunk_size:
308
+ chunk_count = chunk_count + 1
309
+
310
+ final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
311
+
312
+ if chunk_count == chunk_size-1:
313
+ chunk_count = 0
314
+ yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
315
+
316
+ if chunk_count > 0:
317
+ yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
318
+
319
+ del final_tokens, input_ids
320
+ if self.device == "cuda":
321
+ torch.cuda.empty_cache()
322
+
323
+ def _infer(self, model_fn, **kwargs):
324
+ with torch.inference_mode():
325
+ return model_fn(**kwargs)
326
+
327
+ def tokenize(self, text):
328
+ """Tokenize a string into a tensor of token IDs."""
329
+ batch = self.tokenizer.encode(text, return_tensors="pt")
330
+ return batch[0].to(self.device)
331
+
332
+ def generate(self, input_ids, logprobs=0, **kwargs):
333
+ """Generate a stream of predicted tokens using the language model."""
334
+
335
+ # Store the original batch size and input length.
336
+ batch_size = input_ids.shape[0]
337
+ input_length = input_ids.shape[-1]
338
+
339
+ # Separate model arguments from generation config.
340
+ config = self.model.generation_config
341
+ config = copy.deepcopy(config)
342
+ kwargs = config.update(**kwargs)
343
+ kwargs["output_attentions"] = False
344
+ kwargs["output_hidden_states"] = False
345
+ kwargs["use_cache"] = True
346
+
347
+ # Collect special token IDs.
348
+ pad_token_id = config.pad_token_id
349
+ bos_token_id = config.bos_token_id
350
+ eos_token_id = config.eos_token_id
351
+ if isinstance(eos_token_id, int):
352
+ eos_token_id = [eos_token_id]
353
+ if pad_token_id is None and eos_token_id is not None:
354
+ pad_token_id = eos_token_id[0]
355
+
356
+ # Generate from eos if no input is specified.
357
+ if input_length == 0:
358
+ input_ids = input_ids.new_ones((batch_size, 1)).long()
359
+ if eos_token_id is not None:
360
+ input_ids = input_ids * eos_token_id[0]
361
+ input_length = 1
362
+
363
+ # Keep track of which sequences are already finished.
364
+ unfinished = input_ids.new_ones(batch_size)
365
+
366
+ # Start auto-regressive generation.
367
+ while True:
368
+ inputs = self.model.prepare_inputs_for_generation(
369
+ input_ids, **kwargs
370
+ ) # noqa: E501
371
+
372
+ outputs = self._infer(
373
+ self.model,
374
+ **inputs,
375
+ # return_dict=True,
376
+ output_attentions=False,
377
+ output_hidden_states=False,
378
+ )
379
+
380
+ # Pre-process the probability distribution of the next tokens.
381
+ logits = outputs.logits[:, -1, :]
382
+ with torch.inference_mode():
383
+ logits = self.processor(input_ids, logits)
384
+ probs = torch.nn.functional.softmax(logits, dim=-1)
385
+
386
+ # Select deterministic or stochastic decoding strategy.
387
+ if (config.top_p is not None and config.top_p <= 0) or (
388
+ config.temperature is not None and config.temperature <= 0
389
+ ):
390
+ tokens = torch.argmax(probs, dim=-1)[:, None]
391
+ else:
392
+ tokens = torch.multinomial(probs, num_samples=1)
393
+
394
+ tokens = tokens.squeeze(1)
395
+
396
+ # Finished sequences should have their next token be a padding.
397
+ if pad_token_id is not None:
398
+ tokens = tokens * unfinished + pad_token_id * (1 - unfinished)
399
+
400
+ # Append selected tokens to the inputs.
401
+ input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)
402
+
403
+ # Mark sequences with eos tokens as finished.
404
+ if eos_token_id is not None:
405
+ not_eos = sum(tokens != i for i in eos_token_id)
406
+ unfinished = unfinished.mul(not_eos.long())
407
+
408
+ # Set status to -1 if exceeded the max length.
409
+ status = unfinished.clone()
410
+ if input_ids.shape[-1] - input_length >= config.max_new_tokens:
411
+ status = 0 - status
412
+
413
+ # Yield predictions and status.
414
+ yield tokens
415
+
416
+ # Stop when finished or exceeded the max length.
417
+ if status.max() <= 0:
418
+ break
419
+
420
+ generation_config = get_generation_config(
421
+ "./generation_config_default.yaml"
422
+ )
423
+
424
+ model, tokenizer = load_model(
425
+ # base="decapoda-research/llama-13b-hf",
426
+ # finetuned="mchl-labs/stambecco-13b-plus",
427
+ )
428
+
429
+ stream_model = StreamModel(model, tokenizer)
430
+
431
+ def chat_stream(
432
+ context,
433
+ instruction,
434
+ state_chatbot,
435
+ ):
436
+ if len(context) > 1000 or len(instruction) > 300:
437
+ raise gr.Error("Context or prompt is too long!")
438
+
439
+ bot_summarized_response = ''
440
+ # user input should be appropriately formatted (don't be confused by the function name)
441
+ instruction_display = instruction
442
+ instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)
443
+
444
+ if conv_length > num_of_characters_to_keep:
445
+ instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]
446
+
447
+ state_chatbot = state_chatbot + [
448
+ (
449
+ None,
450
+ "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) Conversazione troppo lunga, sto riassumendo..."
451
+ )
452
+ ]
453
+ yield (state_chatbot, state_chatbot, context)
454
+
455
+ bot_summarized_response = get_output_batch(
456
+ model, tokenizer, [instruction_prompt], generation_config
457
+ )[0]
458
+ bot_summarized_response = bot_summarized_response.split("### Risposta:")[-1].strip()
459
+
460
+ state_chatbot[-1] = (
461
+ None,
462
+ "✅ Riepilogo della conversazione effettuato e impostato come contesto"
463
+ )
464
+ print(f"bot_summarized_response: {bot_summarized_response}")
465
+ yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
466
+
467
+ instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
468
+
469
+ bot_response = stream_model(
470
+ instruction_prompt,
471
+ max_tokens=256,
472
+ temperature=1,
473
+ top_p=0.9
474
+ )
475
+
476
+ instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
477
+ state_chatbot = state_chatbot + [(instruction_display, None)]
478
+ yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
479
+
480
+ prev_index = 0
481
+ agg_tokens = ""
482
+ cutoff_idx = 0
483
+ for tokens in bot_response:
484
+ tokens = tokens.strip()
485
+ cur_token = tokens[prev_index:]
486
+
487
+ if "#" in cur_token and agg_tokens == "":
488
+ cutoff_idx = tokens.find("#")
489
+ agg_tokens = tokens[cutoff_idx:]
490
+
491
+ if agg_tokens != "":
492
+ if len(agg_tokens) < len("### Istruzione:") :
493
+ agg_tokens = agg_tokens + cur_token
494
+ elif len(agg_tokens) >= len("### Istruzione:"):
495
+ if tokens.find("### Istruzione:") > -1:
496
+ processed_response, _ = post_process_stream(tokens[:tokens.find("### Istruzione:")].strip())
497
+
498
+ state_chatbot[-1] = (
499
+ instruction_display,
500
+ processed_response
501
+ )
502
+ yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
503
+ break
504
+ else:
505
+ agg_tokens = ""
506
+ cutoff_idx = 0
507
+
508
+ if agg_tokens == "":
509
+ processed_response, to_exit = post_process_stream(tokens)
510
+ state_chatbot[-1] = (instruction_display, processed_response)
511
+ yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
512
+
513
+ if to_exit:
514
+ break
515
+
516
+ prev_index = len(tokens)
517
+
518
+ yield (
519
+ state_chatbot,
520
+ state_chatbot,
521
+ f"{context} {bot_summarized_response}".strip()
522
+ )
523
+
524
+
525
+ def chat_batch(
526
+ contexts,
527
+ instructions,
528
+ state_chatbots,
529
+ ):
530
+ state_results = []
531
+ ctx_results = []
532
+
533
+ instruct_prompts = [
534
+ generate_prompt(instruct, histories, ctx)
535
+ for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
536
+ ]
537
+
538
+ bot_responses = get_output_batch(
539
+ model, tokenizer, instruct_prompts, generation_config
540
+ )
541
+ bot_responses = post_processes_batch(bot_responses)
542
+
543
+ for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
544
+ new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
545
+ ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
546
+ state_results.append(new_state_chatbot)
547
+
548
+ return (state_results, state_results, ctx_results)
549
+
550
+ def reset_textbox():
551
+ return gr.Textbox.update(value='')
552
+
553
+ def reset_everything(
554
+ context_txtbox,
555
+ instruction_txtbox,
556
+ state_chatbot):
557
+
558
+ state_chatbot = []
559
+
560
+ return (
561
+ state_chatbot,
562
+ state_chatbot,
563
+ gr.Textbox.update(value=''),
564
+ gr.Textbox.update(value=''),
565
+ )
566
+
567
+ with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
568
+ state_chatbot = gr.State([])
569
+
570
+ with gr.Column(elem_id='col_container'):
571
+ gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")
572
+
573
+ with gr.Accordion("Context Setting", open=False):
574
+ context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
575
+ hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)
576
+
577
+ chatbot = gr.Chatbot(elem_id='chatbot', label="Stambecco")
578
+ instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
579
+ with gr.Row():
580
+ cancel_btn = gr.Button(value="Cancel")
581
+ reset_btn = gr.Button(value="Reset")
582
+
583
+ with gr.Accordion("Helper Buttons", open=False):
584
+ gr.Markdown(f"`Continue` lets AI to complete the previous incomplete answers. `Summarize` lets AI to summarize the conversations so far.")
585
+ continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
586
+ summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)
587
+
588
+ continue_btn = gr.Button(value="Continue")
589
+ summarize_btn = gr.Button(value="Summarize")
590
+
591
+ gr.Markdown("#### Examples")
592
+ for _, (category, examples) in enumerate(DEFAULT_EXAMPLES.items()):
593
+ with gr.Accordion(category, open=False):
594
+ if category == "Identity":
595
+ for item in examples:
596
+ with gr.Accordion(item["title"], open=False):
597
+ gr.Examples(
598
+ examples=item["examples"],
599
+ inputs=[
600
+ hidden_txtbox, context_txtbox, instruction_txtbox
601
+ ],
602
+ label=None
603
+ )
604
+ else:
605
+ for item in examples:
606
+ with gr.Accordion(item["title"], open=False):
607
+ gr.Examples(
608
+ examples=item["examples"],
609
+ inputs=[
610
+ hidden_txtbox, instruction_txtbox
611
+ ],
612
+ label=None
613
+ )
614
+
615
+ gr.Markdown(f"{BOTTOM_LINE}")
616
+
617
+
618
+ send_event = instruction_txtbox.submit(
619
+ chat_stream,
620
+ [context_txtbox, instruction_txtbox, state_chatbot],
621
+ [state_chatbot, chatbot, context_txtbox],
622
+ )
623
+ reset_event = instruction_txtbox.submit(
624
+ reset_textbox,
625
+ [],
626
+ [instruction_txtbox],
627
+ )
628
+
629
+ continue_event = continue_btn.click(
630
+ chat_stream,
631
+ [context_txtbox, continue_txtbox, state_chatbot],
632
+ [state_chatbot, chatbot, context_txtbox],
633
+ )
634
+ reset_continue_event = continue_btn.click(
635
+ reset_textbox,
636
+ [],
637
+ [instruction_txtbox],
638
+ )
639
+
640
+ summarize_event = summarize_btn.click(
641
+ chat_stream,
642
+ [context_txtbox, summrize_txtbox, state_chatbot],
643
+ [state_chatbot, chatbot, context_txtbox],
644
+ )
645
+ summarize_reset_event = summarize_btn.click(
646
+ reset_textbox,
647
+ [],
648
+ [instruction_txtbox],
649
+ )
650
+
651
+ cancel_btn.click(
652
+ None, None, None,
653
+ cancels=[
654
+ send_event, continue_event, summarize_event
655
+ ]
656
+ )
657
+
658
+ reset_btn.click(
659
+ reset_everything,
660
+ [context_txtbox, instruction_txtbox, state_chatbot],
661
+ [state_chatbot, chatbot, context_txtbox, instruction_txtbox],
662
+ cancels=[
663
+ send_event, continue_event, summarize_event
664
+ ]
665
+ )
666
+
667
+ demo.queue(
668
+ concurrency_count=1,
669
+ max_size=100,
670
+ ).launch(
671
+ max_threads=5,
672
+ server_name="0.0.0.0",
673
+ share=True
674
+ )