hugo1234 committed
Commit f6d5624
1 Parent(s): e64ca11

Update app.py

Files changed (1)
  1. app.py +200 -625
app.py CHANGED
@@ -1,645 +1,220 @@
1
- import os
2
- os.system('pip install bitsandbytes')
3
- os.system('pip install -q datasets loralib sentencepiece accelerate')
4
- # os.system('pip install -q git+https://github.com/zphang/transformers@c3dc391')
5
- # os.system('pip install -q git+https://github.com/huggingface/transformers')
6
- os.system('pip install -q git+https://github.com/mbehm/transformers')
7
- os.system('pip install -q git+https://github.com/huggingface/peft.git')
8
- # os.system('pip install gradio')
9
- # os.system('pip install torch')
10
- # os.system('pip install peft')
11
- # os.system('pip install transformers')
12
- os.system('pip install tenacity')
13
- os.system('pip install scipy')
14
- # os.system('pip install sentencepiece')
15
 
16
- import re
17
- import yaml
18
- import gc
19
- import copy
20
- import time
21
- from tenacity import RetryError
22
- from tenacity import retry, stop_after_attempt, wait_fixed
23
  import gradio as gr
24
- # import torch
25
- from peft import PeftModel
26
- from transformers import (
27
- LLaMATokenizer,
28
- LlamaForCausalLM,
29
- GenerationConfig,
30
- AutoModelForCausalLM,
31
- AutoModelForSeq2SeqLM,
32
- AutoTokenizer,
33
- LogitsProcessorList,
34
- MinNewTokensLengthLogitsProcessor,
35
- TemperatureLogitsWarper,
36
- TopPLogitsWarper,
37
- MinLengthLogitsProcessor
38
- )
39
-
40
- # assert torch.cuda.is_available(), "Change the runtime type to GPU"
41
-
42
- # constants
43
- num_of_characters_to_keep = 1000
44
-
45
- # regex
46
- html_tag_pattern = re.compile(r"<.*?>")
47
- multi_line_pattern = re.compile(r"\n+")
48
- multi_space_pattern = re.compile(r"( )")
49
- multi_br_tag_pattern = re.compile(re.compile(r'<br>\s*(<br>\s*)*'))
50
-
51
- # repl is short for replacement
52
- repl_linebreak = "\n"
53
- repl_empty_str = ""
54
-
55
- TITLE = "Galileo"
56
-
57
- ABSTRACT = """
58
- Stambecco is an Italian instruction-following model based on the [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) model. It comes in two sizes: 7B and 13B parameters. It is trained on an Italian version of the [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset, a dataset of `GPT-4`-generated instruction-following data.
59
- This demo is intended to show and evaluate the conversational capabilities of the model.
60
- For more information, please visit [the project's website](https://github.com/mchl-labs/stambecco).
61
- NOTE: Inputs that are too long are not allowed. Please keep the context < 500 and the instruction < 150 characters.
62
- """
63
-
64
- BOTTOM_LINE = """
65
- By default, this demo runs in streaming mode, but you can also run it with dynamic batch generation.
66
- Stambecco is built on the same concept as the Stanford Alpaca project, but LoRA lets us train and run inference for the 7B version on smaller GPUs such as an RTX 4090. We can also build very small checkpoints on top of the base models thanks to the [🤗 transformers](https://huggingface.co/docs/transformers/index), [🤗 peft](https://github.com/huggingface/peft), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes/tree/main) libraries.
67
- This demo currently runs the 8-bit 7B version of the model.
68
- """
69
-
70
- DEFAULT_EXAMPLES = {
71
- "Typical Questions": [
72
- {
73
- "title": "Parlami di Giulio Cesare.",
74
- "examples": [
75
- ["1", "Scrivi un articolo su Giulio Cesare"],
76
- ["2", "Davvero?"],
77
- ["3", "Quanto era ricco Giulio Cesare?"],
78
- ["4", "Chi è stato il suo successore?"],
79
- ]
80
- },
81
- {
82
- "title": "Parigi",
83
- "examples": [
84
- ["1", "Scrivi un tema sulla città di Parigi"],
85
- ["2", "Fai un elenco di 5 posti da visitare assolutamente"],
86
- ["3", "Quali eventi importanti della Storia sono avvenuti a Parigi?"],
87
- ["4", "Quale è il periodo migliore per visitare Parigi?"],
88
- ]
89
- },
90
- {
91
- "title": "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci",
92
- "examples": [
93
- ["1", "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci"],
94
- ["2", "Potresti spiegarmi come funziona il codice?"],
95
- ["3", "Cos'è la ricorsione?"],
96
- ]
97
- }
98
- ],
99
- }
100
-
101
- SPECIAL_STRS = {
102
- "continue": "continua",
103
- "summarize": "Di cosa abbiamo discusso finora? Descrivi nella user's view."
104
- }
105
-
106
- PARENT_BLOCK_CSS = """
107
- #col_container {
108
- width: 95%;
109
- margin-left: auto;
110
- margin-right: auto;
111
- }
112
- #chatbot {
113
- height: 500px;
114
- overflow: auto;
115
- }
116
- """
117
-
118
- def load_model(
119
- base="decapoda-research/llama-7b-hf",
120
- finetuned="mchl-labs/stambecco-7b-plus",
121
- ):
122
- tokenizer = LLaMATokenizer.from_pretrained(base)
123
- tokenizer.pad_token_id = 0
124
- tokenizer.padding_side = "left"
125
-
126
- model = LlamaForCausalLM.from_pretrained(
127
- base,
128
- load_in_8bit=True,
129
- device_map="from_pretrained",
130
- # load_in_8bit_fp32_cpu_offload=True
131
- )
132
- # model = PeftModel.from_pretrained(model, finetuned, device_map={'': 0})
133
-
134
- model = PeftModel.from_pretrained(model, finetuned)
135
- return model, tokenizer
136
-
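
Note: `device_map="from_pretrained"` above is not one of the device-map values the mainline `transformers` loader documents (e.g. `"auto"` or an explicit index map), and the `LLaMATokenizer` spelling comes from the patched fork installed at the top of the file. A minimal sketch of the same 8-bit-base-plus-LoRA loading against mainline `transformers`/`peft`, assuming `device_map="auto"` and the same checkpoint names:

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

def load_model_sketch(base="decapoda-research/llama-7b-hf",
                      finetuned="mchl-labs/stambecco-7b-plus"):
    # Left padding so batched generation lines up on the last token.
    tokenizer = LlamaTokenizer.from_pretrained(base)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    # Load the quantized base model and let accelerate place the layers.
    model = LlamaForCausalLM.from_pretrained(
        base,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Attach the LoRA adapter on top of the 8-bit base model.
    model = PeftModel.from_pretrained(model, finetuned)
    model.eval()
    return model, tokenizer
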
137
- def get_generation_config(path):
138
- with open(path, 'rb') as f:
139
- generation_config = yaml.safe_load(f.read())
140
-
141
- return GenerationConfig(**generation_config["generation_config"])
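
The `generation_config_default.yaml` file referenced here is not part of this commit; the loader only requires a top-level `generation_config` mapping whose keys are valid `GenerationConfig` fields. A hypothetical example of such a file, inlined as a string (the field values are assumptions):

import yaml
from transformers import GenerationConfig

EXAMPLE_YAML = """
generation_config:
  temperature: 0.9
  top_p: 0.75
  num_beams: 1
"""

config = GenerationConfig(**yaml.safe_load(EXAMPLE_YAML)["generation_config"])
print(config.temperature, config.top_p)  # 0.9 0.75
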
142
-
143
- def generate_prompt(prompt, histories, ctx=None, partial=False):
144
- convs = f"""Di seguito è riportata una cronologia delle istruzioni che descrivono le tasks, abbinate a un input che fornisce ulteriore contesto. Scrivi una risposta che completi adeguatamente la richiesta ricordando la cronologia della conversazione.
145
-
146
- """
147
-
148
- if ctx is not None:
149
- convs = f"""### Input: {ctx}
150
- """
151
-
152
- sub_convs = ""
153
- start_idx = 0
154
-
155
- for idx, history in enumerate(histories):
156
- history_prompt = history[0]
157
- history_response = history[1]
158
- if history_response == "✅ Riepilogo della conversazione effettuato e impostato come contesto" or history_prompt == SPECIAL_STRS["summarize"]:
159
- start_idx = idx
160
-
161
- # drop the previous conversations if user has summarized
162
- for history in histories[start_idx if start_idx == 0 else start_idx+1:]:
163
- history_prompt = history[0]
164
- history_response = history[1]
165
-
166
- history_response = history_response.replace("<br>", "\n")
167
- history_response = re.sub(
168
- html_tag_pattern, repl_empty_str, history_response
169
- )
170
-
171
- sub_convs = sub_convs + f"""### Istruzione: {history_prompt}
172
- ### Risposta: {history_response}
173
- """
174
-
175
- sub_convs = sub_convs + f"""### Istruzione: {prompt}
176
- ### Risposta:"""
177
-
178
- convs = convs + sub_convs
179
- return sub_convs if partial else convs, len(sub_convs)
180
-
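
To make the prompt layout concrete, a short usage sketch of `generate_prompt` with a made-up chat history (the history strings are illustrative only):

history = [
    ["Scrivi un articolo su Giulio Cesare", "Giulio Cesare fu un generale e statista romano..."],
]
prompt, conv_length = generate_prompt("Davvero?", history, ctx=None)
print(prompt)
# Roughly:
#   <preamble describing the conversation history>
#   ### Istruzione: Scrivi un articolo su Giulio Cesare
#   ### Risposta: Giulio Cesare fu un generale e statista romano...
#   ### Istruzione: Davvero?
#   ### Risposta:
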
181
- def common_post_process(original_str):
182
- original_str = re.sub(
183
- multi_line_pattern, repl_linebreak, original_str
184
- )
185
- return original_str
186
-
187
- def post_process_stream(bot_response):
188
- # sometimes model spits out text containing
189
- # "### Risposta:" and "### Istruzione:" -> in this case, we want to stop generating
190
- if "### Risposta:" in bot_response or "### Input:" in bot_response:
191
- bot_response = bot_response.replace("### Risposta:", '').replace("### Input:", '').strip()
192
- return bot_response, True
193
-
194
- return common_post_process(bot_response), False
195
-
196
- def post_process_batch(bot_response):
197
- bot_response = bot_response.split("### Risposta:")[-1].strip()
198
- return common_post_process(bot_response)
199
-
200
- def post_processes_batch(bot_responses):
201
- return [post_process_batch(r) for r in bot_responses]
202
-
203
- def get_output_batch(
204
- model, tokenizer, prompts, generation_config
205
- ):
206
- if len(prompts) == 1:
207
- encoding = tokenizer(prompts, return_tensors="pt")
208
- input_ids = encoding["input_ids"].cuda()
209
- generated_id = model.generate(
210
- input_ids=input_ids,
211
- generation_config=generation_config,
212
- max_new_tokens=256
213
- )
214
-
215
- decoded = tokenizer.batch_decode(generated_id)
216
- del input_ids, generated_id
217
- torch.cuda.empty_cache()
218
- return decoded
219
  else:
220
- encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
221
- generated_ids = model.generate(
222
- **encodings,
223
- generation_config=generation_config,
224
- max_new_tokens=256
225
- )
226
-
227
- decoded = tokenizer.batch_decode(generated_ids)
228
- del encodings, generated_ids
229
- torch.cuda.empty_cache()
230
- return decoded
231
-
232
-
233
- # StreamModel is borrowed from basaran project
234
- # please find more info about it -> https://github.com/hyperonym/basaran
235
- class StreamModel:
236
- """StreamModel wraps around a language model to provide stream decoding."""
237
-
238
- def __init__(self, model, tokenizer):
239
- super().__init__()
240
- self.model = model
241
- self.tokenizer = tokenizer
242
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
243
-
244
- self.processor = LogitsProcessorList()
245
- self.processor.append(TemperatureLogitsWarper(0.9))
246
- self.processor.append(TopPLogitsWarper(0.75))
247
-
248
-
249
- def __call__(
250
- self,
251
- prompt,
252
- min_tokens=0,
253
- max_tokens=16,
254
- temperature=1.0,
255
- top_p=1.0,
256
- n=1,
257
- logprobs=0,
258
- ):
259
- """Create a completion stream for the provided prompt."""
260
- input_ids = self.tokenize(prompt)
261
- logprobs = max(logprobs, 0)
262
-
263
- # bigger than 1
264
- chunk_size = 2
265
- chunk_count = 0
266
-
267
- # Generate completion tokens.
268
- final_tokens = torch.empty(0)
269
 
270
- for tokens in self.generate(
271
- input_ids[None, :].repeat(n, 1),
272
- logprobs=logprobs,
273
- min_new_tokens=min_tokens,
274
- max_new_tokens=max_tokens,
275
- temperature=temperature,
276
- top_p=top_p,
277
- ):
278
- if chunk_count < chunk_size:
279
- chunk_count = chunk_count + 1
280
-
281
- final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
282
-
283
- if chunk_count == chunk_size-1:
284
- chunk_count = 0
285
- yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
286
-
287
- if chunk_count > 0:
288
- yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
289
-
290
- del final_tokens, input_ids
291
- if self.device == "cuda":
292
- torch.cuda.empty_cache()
293
-
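
Each value yielded by `__call__` is the full completion decoded so far (the prompt itself is not included), so a caller only needs to keep the latest string. A usage sketch, assuming the `model`/`tokenizer` pair loaded above:

demo_stream = StreamModel(model, tokenizer)
latest = ""
for partial_text in demo_stream("### Istruzione: Ciao!\n### Risposta:",
                                max_tokens=64, temperature=0.7, top_p=0.9):
    latest = partial_text  # progressively longer decoded text
print(latest)
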
294
- def _infer(self, model_fn, **kwargs):
295
- with torch.inference_mode():
296
- return model_fn(**kwargs)
297
-
298
- def tokenize(self, text):
299
- """Tokenize a string into a tensor of token IDs."""
300
- batch = self.tokenizer.encode(text, return_tensors="pt")
301
- return batch[0].to(self.device)
302
-
303
- def generate(self, input_ids, logprobs=0, **kwargs):
304
- """Generate a stream of predicted tokens using the language model."""
305
-
306
- # Store the original batch size and input length.
307
- batch_size = input_ids.shape[0]
308
- input_length = input_ids.shape[-1]
309
-
310
- # Separate model arguments from generation config.
311
- config = self.model.generation_config
312
- config = copy.deepcopy(config)
313
- kwargs = config.update(**kwargs)
314
- kwargs["output_attentions"] = False
315
- kwargs["output_hidden_states"] = False
316
- kwargs["use_cache"] = True
317
-
318
- # Collect special token IDs.
319
- pad_token_id = config.pad_token_id
320
- bos_token_id = config.bos_token_id
321
- eos_token_id = config.eos_token_id
322
- if isinstance(eos_token_id, int):
323
- eos_token_id = [eos_token_id]
324
- if pad_token_id is None and eos_token_id is not None:
325
- pad_token_id = eos_token_id[0]
326
-
327
- # Generate from eos if no input is specified.
328
- if input_length == 0:
329
- input_ids = input_ids.new_ones((batch_size, 1)).long()
330
- if eos_token_id is not None:
331
- input_ids = input_ids * eos_token_id[0]
332
- input_length = 1
333
-
334
- # Keep track of which sequences are already finished.
335
- unfinished = input_ids.new_ones(batch_size)
336
-
337
- # Start auto-regressive generation.
338
- while True:
339
- inputs = self.model.prepare_inputs_for_generation(
340
- input_ids, **kwargs
341
- ) # noqa: E501
342
-
343
- outputs = self._infer(
344
- self.model,
345
- **inputs,
346
- # return_dict=True,
347
- output_attentions=False,
348
- output_hidden_states=False,
349
- )
350
-
351
- # Pre-process the probability distribution of the next tokens.
352
- logits = outputs.logits[:, -1, :]
353
- with torch.inference_mode():
354
- logits = self.processor(input_ids, logits)
355
- probs = torch.nn.functional.softmax(logits, dim=-1)
356
-
357
- # Select deterministic or stochastic decoding strategy.
358
- if (config.top_p is not None and config.top_p <= 0) or (
359
- config.temperature is not None and config.temperature <= 0
360
- ):
361
- tokens = torch.argmax(probs, dim=-1)[:, None]
362
- else:
363
- tokens = torch.multinomial(probs, num_samples=1)
364
-
365
- tokens = tokens.squeeze(1)
366
-
367
- # Finished sequences should have their next token be a padding.
368
- if pad_token_id is not None:
369
- tokens = tokens * unfinished + pad_token_id * (1 - unfinished)
370
-
371
- # Append selected tokens to the inputs.
372
- input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)
373
-
374
- # Mark sequences with eos tokens as finished.
375
- if eos_token_id is not None:
376
- not_eos = sum(tokens != i for i in eos_token_id)
377
- unfinished = unfinished.mul(not_eos.long())
378
-
379
- # Set status to -1 if exceeded the max length.
380
- status = unfinished.clone()
381
- if input_ids.shape[-1] - input_length >= config.max_new_tokens:
382
- status = 0 - status
383
-
384
- # Yield predictions and status.
385
- yield tokens
386
-
387
- # Stop when finished or exceeded the max length.
388
- if status.max() <= 0:
389
- break
390
-
391
- generation_config = get_generation_config(
392
- "./generation_config_default.yaml"
393
- )
394
-
395
- model, tokenizer = load_model(
396
- # base="decapoda-research/llama-13b-hf",
397
- # finetuned="mchl-labs/stambecco-13b-plus",
398
- )
399
 
400
- stream_model = StreamModel(model, tokenizer)
401
 
402
- def chat_stream(
403
- context,
404
- instruction,
405
- state_chatbot,
406
- ):
407
- if len(context) > 1000 or len(instruction) > 300:
408
- raise gr.Error("Context or prompt is too long!")
409
-
410
- bot_summarized_response = ''
411
- # user input should be appropriately formatted (don't be confused by the function name)
412
- instruction_display = instruction
413
- instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)
414
 
415
- if conv_length > num_of_characters_to_keep:
416
- instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]
417
-
418
- state_chatbot = state_chatbot + [
419
- (
420
- None,
421
- "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) Conversazione troppo lunga, sto riassumendo..."
422
- )
423
- ]
424
- yield (state_chatbot, state_chatbot, context)
425
-
426
- bot_summarized_response = get_output_batch(
427
- model, tokenizer, [instruction_prompt], generation_config
428
- )[0]
429
- bot_summarized_response = bot_summarized_response.split("### Risposta:")[-1].strip()
430
-
431
- state_chatbot[-1] = (
432
- None,
433
- "✅ Riepilogo della conversazione effettuato e impostato come contesto"
434
- )
435
- print(f"bot_summarized_response: {bot_summarized_response}")
436
- yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
437
-
438
- instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
439
-
440
- bot_response = stream_model(
441
- instruction_prompt,
442
- max_tokens=256,
443
- temperature=1,
444
- top_p=0.9
445
- )
446
 
447
- instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
448
- state_chatbot = state_chatbot + [(instruction_display, None)]
449
- yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
 
450
 
451
- prev_index = 0
452
- agg_tokens = ""
453
- cutoff_idx = 0
454
- for tokens in bot_response:
455
- tokens = tokens.strip()
456
- cur_token = tokens[prev_index:]
457
-
458
- if "#" in cur_token and agg_tokens == "":
459
- cutoff_idx = tokens.find("#")
460
- agg_tokens = tokens[cutoff_idx:]
461
-
462
- if agg_tokens != "":
463
- if len(agg_tokens) < len("### Istruzione:") :
464
- agg_tokens = agg_tokens + cur_token
465
- elif len(agg_tokens) >= len("### Istruzione:"):
466
- if tokens.find("### Istruzione:") > -1:
467
- processed_response, _ = post_process_stream(tokens[:tokens.find("### Istruzione:")].strip())
468
-
469
- state_chatbot[-1] = (
470
- instruction_display,
471
- processed_response
472
- )
473
- yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
474
- break
475
- else:
476
- agg_tokens = ""
477
- cutoff_idx = 0
478
-
479
- if agg_tokens == "":
480
- processed_response, to_exit = post_process_stream(tokens)
481
- state_chatbot[-1] = (instruction_display, processed_response)
482
- yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
483
 
484
- if to_exit:
485
- break
486
 
487
- prev_index = len(tokens)
488
-
489
- yield (
490
- state_chatbot,
491
- state_chatbot,
492
- f"{context} {bot_summarized_response}".strip()
493
- )
494
-
495
-
496
- def chat_batch(
497
- contexts,
498
- instructions,
499
- state_chatbots,
500
- ):
501
- state_results = []
502
- ctx_results = []
503
-
504
- instruct_prompts = [
505
- generate_prompt(instruct, histories, ctx)
506
- for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
507
- ]
508
-
509
- bot_responses = get_output_batch(
510
- model, tokenizer, instruct_prompts, generation_config
511
- )
512
- bot_responses = post_processes_batch(bot_responses)
513
-
514
- for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
515
- new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
516
- ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
517
- state_results.append(new_state_chatbot)
518
-
519
- return (state_results, state_results, ctx_results)
520
-
521
- def reset_textbox():
522
- return gr.Textbox.update(value='')
523
-
524
- def reset_everything(
525
- context_txtbox,
526
- instruction_txtbox,
527
- state_chatbot):
528
-
529
- state_chatbot = []
530
 
531
- return (
532
- state_chatbot,
533
- state_chatbot,
534
- gr.Textbox.update(value=''),
535
- gr.Textbox.update(value=''),
536
- )
537
-
538
- with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
539
- state_chatbot = gr.State([])
540
-
541
- with gr.Column(elem_id='col_container'):
542
- gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")
543
-
544
- with gr.Accordion("Context Setting", open=False):
545
- context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
546
- hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)
547
-
548
- chatbot = gr.Chatbot(elem_id='chatbot', label="Stambecco")
549
- instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
550
- with gr.Row():
551
- cancel_btn = gr.Button(value="Cancel")
552
- reset_btn = gr.Button(value="Reset")
553
-
554
- with gr.Accordion("Helper Buttons", open=False):
555
- gr.Markdown(f"`Continue` lets the AI complete its previous, incomplete answer. `Summarize` lets the AI summarize the conversation so far.")
556
- continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
557
- summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)
558
-
559
- continue_btn = gr.Button(value="Continue")
560
- summarize_btn = gr.Button(value="Summarize")
561
-
562
- gr.Markdown("#### Examples")
563
- for _, (category, examples) in enumerate(DEFAULT_EXAMPLES.items()):
564
- with gr.Accordion(category, open=False):
565
- if category == "Identity":
566
- for item in examples:
567
- with gr.Accordion(item["title"], open=False):
568
- gr.Examples(
569
- examples=item["examples"],
570
- inputs=[
571
- hidden_txtbox, context_txtbox, instruction_txtbox
572
- ],
573
- label=None
574
  )
575
- else:
576
- for item in examples:
577
- with gr.Accordion(item["title"], open=False):
578
- gr.Examples(
579
- examples=item["examples"],
580
- inputs=[
581
- hidden_txtbox, instruction_txtbox
582
- ],
583
- label=None
584
  )
585
-
586
- gr.Markdown(f"{BOTTOM_LINE}")
587
-
588
-
589
- send_event = instruction_txtbox.submit(
590
- chat_stream,
591
- [context_txtbox, instruction_txtbox, state_chatbot],
592
- [state_chatbot, chatbot, context_txtbox],
593
- )
594
- reset_event = instruction_txtbox.submit(
595
- reset_textbox,
596
- [],
597
- [instruction_txtbox],
598
- )
599
 
600
- continue_event = continue_btn.click(
601
- chat_stream,
602
- [context_txtbox, continue_txtbox, state_chatbot],
603
- [state_chatbot, chatbot, context_txtbox],
604
- )
605
- reset_continue_event = continue_btn.click(
606
- reset_textbox,
607
- [],
608
- [instruction_txtbox],
609
  )
610
-
611
- summarize_event = summarize_btn.click(
612
- chat_stream,
613
- [context_txtbox, summrize_txtbox, state_chatbot],
614
- [state_chatbot, chatbot, context_txtbox],
615
  )
616
- summarize_reset_event = summarize_btn.click(
617
- reset_textbox,
618
- [],
619
- [instruction_txtbox],
620
  )
621
-
622
- cancel_btn.click(
623
- None, None, None,
624
- cancels=[
625
- send_event, continue_event, summarize_event
626
- ]
627
  )
 
628
 
629
- reset_btn.click(
630
- reset_everything,
631
- [context_txtbox, instruction_txtbox, state_chatbot],
632
- [state_chatbot, chatbot, context_txtbox, instruction_txtbox],
633
- cancels=[
634
- send_event, continue_event, summarize_event
635
- ]
636
- )
637
-
638
- demo.queue(
639
- concurrency_count=1,
640
- max_size=100,
641
- ).launch(
642
- max_threads=5,
643
- server_name="0.0.0.0",
644
- share=True
645
- )
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
 
4
  import gradio as gr
5
+ #from transformers import pipeline
6
+ import torch
7
+ from utils import *
8
+ from presets import *
9
+
10
+ #antwort=""
11
+ ######################################################################
12
+ # Models and tokenizers
13
+
14
+ # Use Hugging Chat
15
+ # Create a chatbot connection
16
+ #chatbot = hugchat.ChatBot(cookie_path="cookies.json")
17
+
18
+ # Alternatively, with any other model:
19
+ #base_model = "project-baize/baize-v2-7b"
20
+ base_model = "EleutherAI/gpt-neo-1.3B"
21
+ tokenizer,model,device = load_tokenizer_and_model(base_model)
22
+
23
+
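
`load_tokenizer_and_model` (like the other helpers used below) comes from the local `utils` module, which is not included in this commit. A rough stand-in for what such a helper typically does, written against plain `transformers` (the body below is an assumption, not the Space's actual utils.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_tokenizer_and_model_sketch(base_model: str):
    # Pair an AutoTokenizer/AutoModelForCausalLM with the best available device.
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()
    return tokenizer, model, device
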
24
+ ########################################################################
25
+ # Use the chat AI to generate text...
26
+ def predict(text,
27
+ chatbotGr,
28
+ history,
29
+ top_p,
30
+ temperature,
31
+ max_length_tokens,
32
+ max_context_length_tokens,):
33
+ if text=="":
34
+ yield chatbotGr,history,"Empty context."
35
+ return
36
+ try:
37
+ model
38
+ except:
39
+ yield [[text,"No Model Found"]],[],"No Model Found"
40
+ return
41
+
42
+ inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
43
+ if inputs is None:
44
+ yield chatbotGr,history,"Input too long."
45
+ return
46
  else:
47
+ prompt,inputs=inputs
48
+ begin_length = len(prompt)
49
 
50
+ input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
51
+ torch.cuda.empty_cache()
52
+
53
+ # torch.no_grad() means that no gradients are computed for the tensors involved during backpropagation
54
+ # the network itself should not change here (backprop is not needed), since these are inference prompts
55
+ with torch.no_grad():
56
+ # all previous prompts are stored in history as pairs, labelled 'Human' and 'AI' - these labels therefore also serve as the stop words that mark the next entry
57
+ for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
58
+ if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
59
+ if "[|Human|]" in x:
60
+ x = x[:x.index("[|Human|]")].strip()
61
+ if "[|AI|]" in x:
62
+ x = x[:x.index("[|AI|]")].strip()
63
+ x = x.strip()
64
+ a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
65
+ yield a, b, "Generating..."
66
+ if shared_state.interrupted:
67
+ shared_state.recover()
68
+ try:
69
+ yield a, b, "Stop: Success"
70
+ return
71
+ except:
72
+ pass
73
+ del input_ids
74
+ gc.collect()
75
+ torch.cuda.empty_cache()
76
+
77
+ try:
78
+ yield a,b,"Generate: Success"
79
+ except:
80
+ pass
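
The loop above treats `history` as a list of `[user_text, bot_text]` pairs and uses `[|Human|]`/`[|AI|]` as stop markers; `greedy_search` and `is_stop_word_or_prefix` also live in the missing `utils` module. A hypothetical sketch of the stop-word check under that assumption (not the actual utils implementation):

def is_stop_word_or_prefix_sketch(text, stop_words):
    # True if the text ends with a stop word, or with a prefix of one,
    # so streaming can hold back output that might turn into a marker.
    for stop_word in stop_words:
        if text.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if text.endswith(stop_word[:i]):
                return True
    return False

print(is_stop_word_or_prefix_sketch("...ciao [|Hu", ["[|Human|]", "[|AI|]"]))  # True
print(is_stop_word_or_prefix_sketch("...ciao.", ["[|Human|]", "[|AI|]"]))      # False
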
81
 
 
82
 
83
+ def reset_chat():
84
+ #id_new = chatbot.new_conversation()
85
+ #chatbot.change_conversation(id_new)
86
+ reset_textbox()
87
 
88
 
89
+ ##########################################################
90
+ # Use the translation AI
91
+ def translate():
92
+ return "Kommt noch!"
93
 
94
+ # Code-generation AI
95
+ def coding():
96
+ return "Kommt noch!"
97
 
98
+ #######################################################################
99
+ # UI layout with Gradio
100
 
101
+ with open("custom.css", "r", encoding="utf-8") as f:
102
+ customCSS = f.read()
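
`customCSS` is read here, but the `gr.Blocks` call below only passes `theme=small_and_beautiful_theme`. If the stylesheet is meant to be applied as well, it would go through the `css` argument; a sketch under that assumption, reusing the `gr`, `small_and_beautiful_theme`, and `customCSS` names from this file:

# Sketch: apply both the preset theme and the loaded stylesheet.
with gr.Blocks(theme=small_and_beautiful_theme, css=customCSS) as demo_sketch:
    gr.Markdown("Platzhalter")
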
103
 
104
+ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
105
+ history = gr.State([])
106
+ user_question = gr.State("")
107
+ gr.Markdown("KIs am LI - wähle aus, was du bzgl. KI-Bots ausprobieren möchtest!")
108
+ with gr.Tabs():
109
+ with gr.TabItem("LI-Chat"):
110
+ with gr.Row():
111
+ gr.HTML(title)
112
+ status_display = gr.Markdown("Erfolg", elem_id="status_display")
113
+ gr.Markdown(description_top)
114
+ with gr.Row(scale=1).style(equal_height=True):
115
+ with gr.Column(scale=5):
116
+ with gr.Row(scale=1):
117
+ chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
118
+ with gr.Row(scale=1):
119
+ with gr.Column(scale=12):
120
+ user_input = gr.Textbox(
121
+ show_label=False, placeholder="Gib deinen Text / Frage ein."
122
+ ).style(container=False)
123
+ with gr.Column(min_width=100, scale=1):
124
+ submitBtn = gr.Button("Absenden")
125
+ with gr.Column(min_width=100, scale=1):
126
+ cancelBtn = gr.Button("Stoppen")
127
+ with gr.Row(scale=1):
128
+ emptyBtn = gr.Button(
129
+ "🧹 Neuer Chat",
130
+ )
131
+ with gr.Column():
132
+ with gr.Column(min_width=50, scale=1):
133
+ with gr.Tab(label="Parameter zum Model"):
134
+ gr.Markdown("# Parameters")
135
+ top_p = gr.Slider(
136
+ minimum=-0,
137
+ maximum=1.0,
138
+ value=0.95,
139
+ step=0.05,
140
+ interactive=True,
141
+ label="Top-p",
142
  )
143
+ temperature = gr.Slider(
144
+ minimum=0.1,
145
+ maximum=2.0,
146
+ value=1,
147
+ step=0.1,
148
+ interactive=True,
149
+ label="Temperature",
150
  )
151
+ max_length_tokens = gr.Slider(
152
+ minimum=0,
153
+ maximum=512,
154
+ value=512,
155
+ step=8,
156
+ interactive=True,
157
+ label="Max Generation Tokens",
158
+ )
159
+ max_context_length_tokens = gr.Slider(
160
+ minimum=0,
161
+ maximum=4096,
162
+ value=2048,
163
+ step=128,
164
+ interactive=True,
165
+ label="Max History Tokens",
166
+ )
167
+ gr.Markdown(description)
168
+
169
+ with gr.TabItem("Übersetzungen"):
170
+ with gr.Row():
171
+ gr.Textbox(
172
+ show_label=False, placeholder="Ist noch in Arbeit..."
173
+ ).style(container=False)
174
+ with gr.TabItem("Code-Generierungen"):
175
+ with gr.Row():
176
+ gr.Textbox(
177
+ show_label=False, placeholder="Ist noch in Arbeit..."
178
+ ).style(container=False)
179
 
180
+ predict_args = dict(
181
+ fn=predict,
182
+ inputs=[
183
+ user_question,
184
+ chatbotGr,
185
+ history,
186
+ top_p,
187
+ temperature,
188
+ max_length_tokens,
189
+ max_context_length_tokens,
190
+ ],
191
+ outputs=[chatbotGr, history, status_display],
192
+ show_progress=True,
193
  )
194
+
195
+ # new chat
196
+ reset_args = dict(
197
+ #fn=reset_chat, inputs=[], outputs=[user_input, status_display]
198
+ fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
199
  )
200
+
201
+ # Chatbot
202
+ transfer_input_args = dict(
203
+ fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
204
  )
205
+
206
+ # Listeners for clicking the start button or pressing Return
207
+ predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
208
+ predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
209
+
210
+ # Listener for reset...
211
+ emptyBtn.click(
212
+ reset_state,
213
+ outputs=[chatbotGr, history, status_display],
214
+ show_progress=True,
215
  )
216
+ emptyBtn.click(**reset_args)
217
 
218
+ demo.title = "LI Chat"
219
+ #demo.queue(concurrency_count=1).launch(share=True)
220
+ demo.queue(concurrency_count=1).launch(debug=True)
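
`reset_state`, `reset_textbox`, and `transfer_input` are likewise imported from the missing `utils` module. Plausible sketches of what they return for the outputs they are wired to above (hypothetical, inferred only from the listener signatures):

import gradio as gr

def reset_textbox_sketch():
    # outputs=[user_input, status_display]
    return gr.update(value=""), "Reset"

def reset_state_sketch():
    # outputs=[chatbotGr, history, status_display]
    return [], [], "Neuer Chat gestartet"

def transfer_input_sketch(user_input_text):
    # outputs=[user_question, user_input, submitBtn]
    # Move the typed text into state, clear the textbox, disable the button while generating.
    return user_input_text, gr.update(value=""), gr.update(interactive=False)
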