Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -18,6 +18,11 @@ import filelock
 import glob
 import json
 import time

 from gradio_client.documentation import document, set_documentation_group

@@ -278,455 +283,6 @@ path_markdown = """



-
-def custom_hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
-    from vllm.model_executor.weight_utils import Disabledtqdm
-    # Prepare file lock directory to prevent multiple processes from
-    # downloading the same model weights at the same time.
-    lock_dir = cache_dir if cache_dir is not None else "/tmp"
-    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
-    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
-
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    if not is_local:
-        with lock:
-            hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
-                                          cache_dir=cache_dir,
-                                          local_files_only=True,
-                                          tqdm_class=Disabledtqdm)
-    else:
-        hf_folder = model_name_or_path
-
-    hf_bin_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
-        if not x.endswith("training_args.bin")
-    ]
-    hf_safetensors_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
-        if not x.endswith("training_args.bin")
-    ]
-
-    if use_np_cache:
-        # Convert the model weights from torch tensors to numpy arrays for
-        # faster loading.
-        np_folder = os.path.join(hf_folder, "np")
-        os.makedirs(np_folder, exist_ok=True)
-        weight_names_file = os.path.join(np_folder, "weight_names.json")
-        with lock:
-            if not os.path.exists(weight_names_file):
-                weight_names = []
-                for bin_file in hf_bin_files:
-                    state = torch.load(bin_file, map_location="cpu")
-                    for name, param in state.items():
-                        param_path = os.path.join(np_folder, name)
-                        with open(param_path, "wb") as f:
-                            np.save(f, param.cpu().detach().numpy())
-                        weight_names.append(name)
-                with open(weight_names_file, "w") as f:
-                    json.dump(weight_names, f)
-
-        with open(weight_names_file, "r") as f:
-            weight_names = json.load(f)
-
-        for name in weight_names:
-            param_path = os.path.join(np_folder, name)
-            with open(param_path, "rb") as f:
-                param = np.load(f)
-            yield name, torch.from_numpy(param)
-    else:
-        if len(hf_bin_files) > 0:
-            print(F'Load bin files: {hf_bin_files}')
-            for bin_file in hf_bin_files:
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in state.items():
-                    yield name, param
-                del state
-                torch.cuda.empty_cache()
-        elif len(hf_safetensors_files) > 0:
-            print(F'Load safetensor files: {hf_safetensors_files}')
-            from safetensors.torch import load_file
-            for safe_file in hf_safetensors_files:
-                # state = torch.load(bin_file, map_location="cpu")
-                state = load_file(safe_file)
-                for name, param in state.items():
-                    yield name, param
-                del state
-                torch.cuda.empty_cache()
-        else:
-            raise ValueError(f'no files available either bin or safe')
-
-
-def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
-    """convert PySafeSlice object from safetensors to torch.Tensor
-
-    PySafeSlice object supports indexing, which is done before loading the
-    actual tensor and can reduce the amount of memory being read into the
-    memory. However, it does not support more advanced functionalities
-    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
-    tensor with these more complicated operators, we need to convert to
-    tensor first.
-    """
-    if not isinstance(x, torch.Tensor):
-        x = x[:]
-    return x
-
-
-def load_padded_tensor_parallel_vocab(
-    param: torch.Tensor,
-    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
-    tensor_model_parallel_rank: int,
-) -> None:
-    shard_size = param.shape[0]
-    start_idx = tensor_model_parallel_rank * shard_size
-    end_idx = (tensor_model_parallel_rank + 1) * shard_size
-    loaded_weight = loaded_weight[start_idx:end_idx]
-    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-    param[:loaded_weight.shape[0]].copy_(loaded_weight)
-
-
-def llama_load_weights(
-    self,
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-    load_format: str = "auto",
-    revision: Optional[str] = None
-):
-    # if use vllm==0.1.4
-    from vllm.model_executor.weight_utils import (
-        load_tensor_parallel_weights
-    )
-    from vllm.model_executor.parallel_utils.parallel_state import (
-        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-    tp_size = get_tensor_model_parallel_world_size()
-    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-
-    q_proj_shard_size = (self.config.hidden_size // tp_size)
-    kv_proj_shard_size = (self.config.hidden_size //
-                          self.config.num_attention_heads *
-                          getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
-    attention_weight_specs = [
-        # (weight_name, shard_size, offset)
-        ("q_proj", q_proj_shard_size, 0),
-        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-        ("v_proj", kv_proj_shard_size,
-         q_proj_shard_size + kv_proj_shard_size),
-    ]
-    state_dict = self.state_dict()
-    need_to_load = len(state_dict)
-    loaded = 0
-    iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
-
-    for name, loaded_weight in iterator:
-        if "rotary_emb.inv_freq" in name:
-            continue
-
-        if "embed_tokens" in name or "lm_head" in name:
-            param = state_dict[name]
-            # Consider padding in the vocab size.
-            padded_vocab_size = (param.shape[0] * tp_size)
-            # num_extra_rows = padded_vocab_size - self.config.vocab_size
-            num_extra_rows = padded_vocab_size - loaded_weight.size(0)
-            load_size = loaded_weight.size()
-            extra_rows = torch.empty(num_extra_rows,
-                                     loaded_weight.shape[1])
-            extra_rows = extra_rows.to(loaded_weight)
-            loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-            if num_extra_rows > 0:
-                print(f'Add empty to {num_extra_rows} extra row for {name}')
-            print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
-
-        is_attention_weight = False
-        for weight_name, shard_size, offset in attention_weight_specs:
-            if weight_name not in name or "qkv_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-            loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank:shard_size *
-                (tensor_model_parallel_rank + 1)]
-            param_slice = param.data[offset:offset + shard_size]
-            assert param_slice.shape == loaded_weight.shape
-
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 3
-            is_attention_weight = True
-            break
-        if is_attention_weight:
-            continue
-
-        # ! qkv_proj is sharded differently if concatenated into qkv
-        # qkv: qqqq kkkk vvvv
-        # lweight: qq0qq1 kk0kk1 vv0vv1
-        # q_shard_size: hidden_size // tp_size = qq
-        # qkv_s0: qq0_kk0_vv0
-        # qkv_s1: qq1_kk1_vv1
-        if "qkv_proj" in name:
-            param = state_dict[name]
-            # loaded_weight
-            qsize = self.config.hidden_size
-            kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
-            q_offsets = (
-                q_proj_shard_size * tensor_model_parallel_rank,
-                q_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            k_offsets = (
-                qsize + kv_proj_shard_size * tensor_model_parallel_rank,
-                qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            v_offsets = (
-                qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
-                qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
-            )
-            _loaded_weight = torch.cat(
-                [
-                    loaded_weight[q_offsets[0]:q_offsets[1]],
-                    loaded_weight[k_offsets[0]:k_offsets[1]],
-                    loaded_weight[v_offsets[0]:v_offsets[1]],
-                ], 0
-            )
-            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
-            param.data.copy_(_loaded_weight)
-            loaded += 1.0
-            is_attention_weight = True
-        if is_attention_weight:
-            continue
-
-
-        is_gate_up_weight = False
-        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-            if weight_name not in name or "gate_up_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "gate_up_proj")]
-            shard_size = param.shape[0] // 2
-            loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank:shard_size *
-                (tensor_model_parallel_rank + 1)]
-            param_slice = param.data[shard_size * stride_id:shard_size *
-                                     (stride_id + 1)]
-            assert param_slice.shape == loaded_weight.shape
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 2
-            is_gate_up_weight = True
-            break
-        if is_gate_up_weight:
-            continue
-
-        if "gate_up_proj" in name:
-            param = state_dict[name]
-            shard_size = param.shape[0] // 2
-            intermediate_size = self.config.intermediate_size
-            g_offsets = (
-                shard_size * tensor_model_parallel_rank,
-                shard_size * (tensor_model_parallel_rank + 1)
-            )
-            u_offsets = (
-                intermediate_size + shard_size * tensor_model_parallel_rank,
-                intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
-            )
-            _loaded_weight = torch.cat(
-                [
-                    loaded_weight[g_offsets[0]:g_offsets[1]],
-                    loaded_weight[u_offsets[0]:u_offsets[1]],
-                ], 0
-            )
-            assert param.shape == _loaded_weight.shape
-            param.data.copy_(_loaded_weight)
-            loaded += 1.0
-            is_gate_up_weight = True
-        if is_gate_up_weight:
-            continue
-
-
-        param = state_dict[name]
-        load_tensor_parallel_weights(param, loaded_weight, name,
-                                     self._column_parallel_weights,
-                                     self._row_parallel_weights,
-                                     tensor_model_parallel_rank)
-        loaded += 1
-
-    if np.abs(loaded - need_to_load) < 0.01:
-        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
-    else:
-        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
-
-
-def new_llama_load_weights(
-    self,
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    load_format: str = "auto",
-    revision: Optional[str] = None
-):
-    # If use newest vllm, not been thoroughly tested yet.
-    from vllm.model_executor.weight_utils import (
-        load_tensor_parallel_weights, hf_model_weights_iterator
-    )
-    from vllm.model_executor.parallel_utils.parallel_state import (
-        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-
-    if self.quant_config is None:
-        weight_suffixes = ["weight"]
-    else:
-        weight_suffixes = self.quant_config.get_tp_tensor_names()
-
-    column_parallel_weights: List[str] = []
-    for layer in self._column_parallel_layers:
-        for suffix in weight_suffixes:
-            column_parallel_weights.append(f"{layer}.{suffix}")
-    row_parallel_weights: List[str] = []
-    for layer in self._row_parallel_layers:
-        for suffix in weight_suffixes:
-            row_parallel_weights.append(f"{layer}.{suffix}")
-
-    tp_size = get_tensor_model_parallel_world_size()
-    tp_rank = get_tensor_model_parallel_rank()
-    assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
-    q_proj_shard_size = (self.config.hidden_size // tp_size)
-    num_kv_heads_replicas = max(1,
-                                tp_size // self.config.num_key_value_heads)
-    num_kv_heads_per_gpu = max(1,
-                               self.config.num_key_value_heads // tp_size)
-    kv_proj_shard_size = (self.config.hidden_size //
-                          self.config.num_attention_heads *
-                          num_kv_heads_per_gpu)
-    attention_weight_specs = [
-        # (weight_name, shard_size, offset)
-        ("q_proj", q_proj_shard_size, 0),
-        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-        ("v_proj", kv_proj_shard_size,
-         q_proj_shard_size + kv_proj_shard_size),
-    ]
-    state_dict = self.state_dict()
-    need_to_load = len(state_dict)
-    loaded = 0
-
-    for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision):
-        if "rotary_emb.inv_freq" in name:
-            continue
-
-        is_packed = False
-        is_transposed = False
-        if self.quant_config is not None:
-            is_packed = self.quant_config.is_packed(name)
-            is_transposed = self.quant_config.is_transposed(name)
-        if is_transposed:
-            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-            loaded_weight = loaded_weight.T
-
-        is_attention_weight = False
-        for weight_name, shard_size, offset in attention_weight_specs:
-            if weight_name not in name or "qkv_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "qkv_proj")]
-            if is_transposed:
-                param = param.T
-
-            if is_packed:
-                shard_size //= self.quant_config.pack_factor
-                offset //= self.quant_config.pack_factor
-
-            if weight_name in ["k_proj", "v_proj"]:
-                shard_id = tp_rank // num_kv_heads_replicas
-            else:
-                shard_id = tp_rank
-            loaded_weight = loaded_weight[shard_size *
-                                          shard_id:shard_size *
-                                          (shard_id + 1)]
-            param_slice = param.data[offset:offset + shard_size]
-            assert param_slice.shape == loaded_weight.shape
-
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 3
-            is_attention_weight = True
-            break
-        if is_attention_weight:
-            continue
-
-        # TODO: need to figure out to do sharding with qkv_proj fused
-
-        is_gate_up_weight = False
-        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-            if weight_name not in name or "gate_up_proj" in name:
-                continue
-            param = state_dict[name.replace(weight_name, "gate_up_proj")]
-            if is_transposed:
-                param = param.T
-
-            shard_size = param.shape[0] // 2
-            loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
-                                          (tp_rank + 1)]
-            param_slice = param.data[shard_size * stride_id:shard_size *
-                                     (stride_id + 1)]
-            assert param_slice.shape == loaded_weight.shape
-            param_slice.copy_(loaded_weight)
-            loaded += 1.0 / 2
-            is_gate_up_weight = True
-            break
-        if is_gate_up_weight:
-            continue
-
-        # TODO: need to figure out to do sharding with gate_up_proj fused
-
-        param = state_dict[name]
-        if is_transposed:
-            param = param.T
-
-        if "embed_tokens" in name or "lm_head" in name:
-            load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                              tp_rank)
-            loaded += 1
-            continue
-
-        load_tensor_parallel_weights(param, loaded_weight, name,
-                                     column_parallel_weights,
-                                     row_parallel_weights, tp_rank)
-        loaded += 1
-
-    if np.abs(loaded - need_to_load) < 0.01:
-        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
-    else:
-        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
-
-
-# Reassign LlamaForCausalLM.load_weights with llama_load_weights
-if not DEBUG:
-
-    try:
-        import vllm
-        from vllm.model_executor.model_loader import _MODEL_REGISTRY
-        from vllm.model_executor.models import LlamaForCausalLM
-
-        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
-        if vllm.__version__ == "0.1.4":
-            LlamaForCausalLM.load_weights = llama_load_weights
-        else:
-            LlamaForCausalLM.load_weights = new_llama_load_weights
-
-        if DTYPE == "bfloat16":
-            try:
-                compute_capability = torch.cuda.get_device_capability()
-                if compute_capability[0] < 8:
-                    gpu_name = torch.cuda.get_device_name()
-                    print(
-                        "Bfloat16 is only supported on GPUs with compute capability "
-                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
-                    DTYPE = "float16"
-            except Exception as e:
-                print(f'Unable to obtain compute_capability: {e}')
-    except Exception as e:
-        print(f'Failing import and reconfigure VLLM: {str(e)}')
-
-
 # ! ==================================================================

 set_documentation_group("component")
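A minimal sketch (not part of this commit) of the fused-QKV re-sharding that the removed llama_load_weights describes in its "qkv: qqqq kkkk vvvv" comment; the dimensions below are toy values chosen only for illustration:

    import torch

    # A checkpoint stores the fused qkv_proj as [q | k | v] along dim 0, while each
    # tensor-parallel rank needs only its own q/k/v slices, re-concatenated.
    hidden_size, kv_size, tp_size = 8, 4, 2          # toy sizes, not real model dimensions
    q_shard, kv_shard = hidden_size // tp_size, kv_size // tp_size

    fused = torch.arange(hidden_size + 2 * kv_size).float().unsqueeze(1)  # rows: [q | k | v]

    def shard_for_rank(rank: int) -> torch.Tensor:
        q = fused[q_shard * rank: q_shard * (rank + 1)]
        k = fused[hidden_size + kv_shard * rank: hidden_size + kv_shard * (rank + 1)]
        v = fused[hidden_size + kv_size + kv_shard * rank: hidden_size + kv_size + kv_shard * (rank + 1)]
        return torch.cat([q, k, v], dim=0)           # qq_r | kk_r | vv_r for this rank

    print(shard_for_rank(0).squeeze(1))  # rows 0-3, 8-9, 12-13
    print(shard_for_rank(1).squeeze(1))  # rows 4-7, 10-11, 14-15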
@@ -734,41 +290,6 @@ set_documentation_group("component")

 RES_PRINTED = False

-def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
-    return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"
-
-
-def llama_chat_multiturn_sys_input_seq_constructor(
-    message: str,
-    history: List[Tuple[str, str]],
-    sys_prompt=SYSTEM_PROMPT_1,
-    bos_token=BOS_TOKEN,
-    eos_token=EOS_TOKEN,
-    include_end_instruct=True,
-):
-    """
-    ```
-    <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
-    <bos>[INST] Prompt [/INST] Answer <eos>
-    <bos>[INST] Prompt [/INST]
-    ```
-    """
-    text = ''
-    end_instr = f" {E_INST}" if include_end_instruct else ""
-    for i, (prompt, res) in enumerate(history):
-        if i == 0:
-            text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt}{end_instr}"
-        else:
-            text += f"{bos_token}{B_INST} {prompt}{end_instr}"
-
-        if res is not None:
-            text += f" {res} {eos_token} "
-    if len(history) == 0 or text.strip() == '':
-        text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message}{end_instr}"
-    else:
-        text += f"{bos_token}{B_INST} {message}{end_instr}"
-    return text
-

 @document()
 class ChatBot(gr.Chatbot):
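For reference, a hedged sketch (not from app.py) of the prompt string the removed llama_chat_multiturn_sys_input_seq_constructor built; the constants below are assumptions following the usual Llama-2 chat convention, while the real file defines its own BOS_TOKEN, B_INST, B_SYS, and so on:

    # Assumed values for illustration only.
    BOS_TOKEN, EOS_TOKEN = "<s>", "</s>"
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
    SYS = "You are a helpful assistant."

    history = [("Hi there", "Hello! How can I help?")]
    message = "Tell me a joke."

    # First turn carries the system prompt; the final user message has no answer yet.
    text = f"{BOS_TOKEN}{B_INST} {B_SYS} {SYS} {E_SYS} {history[0][0]} {E_INST}"
    text += f" {history[0][1]} {EOS_TOKEN} "
    text += f"{BOS_TOKEN}{B_INST} {message} {E_INST}"
    # -> "<s>[INST] <<SYS>> ... <</SYS>> Hi there [/INST] Hello! How can I help? </s> <s>[INST] Tell me a joke. [/INST]"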
@@ -966,29 +487,63 @@ def _setup_events(self) -> None:
         )

     # Reconfigure clear_btn to stop and clear text box
-    # if self.clear_btn:
-    #     self.clear_btn.click(
-    #         lambda: ([], [], None),
-    #         None,
-    #         [self.chatbot, self.chatbot_state, self.saved_input],
-    #         queue=False,
-    #         api_name=False,
-    #         cancels=submit_event,
-    #     )


 def _display_input(
-    self, message: str, history:
-) ->
     if message is not None and message.strip() != "":
         history.append([message, None])
     return history, history


 # replace
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 gr.ChatInterface._display_input = _display_input


 @document()
@@ -1036,25 +591,6 @@ class CustomTabbedInterface(gr.Blocks):
             interface.render()


-
-# def vllm_abort(self: Any):
-#     sh = self.llm_engine.scheduler
-#     for g in (sh.waiting + sh.running + sh.swapped):
-#         sh.abort_seq_group(g.request_id)
-
-#     from vllm.sequence import SequenceStatus
-#     scheduler = self.llm_engine.scheduler
-#     for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
-#         for seq_group in state_queue:
-#             # if seq_group.request_id == request_id:
-#             # Remove the sequence group from the state queue.
-#             state_queue.remove(seq_group)
-#             for seq in seq_group.seqs:
-#                 if seq.is_finished():
-#                     continue
-#                 scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-
-
 def vllm_abort(self):
     sh = self.llm_engine.scheduler
     for g in (sh.waiting + sh.running + sh.swapped):
@@ -1231,6 +767,14 @@ def chatml_format(message, history=None, system_prompt=None):
     return chatml_chat_convo_format(conversations, True, default_system=system_prompt)


 def chat_response_stream_multiturn(
     message: str,
     history: List[Tuple[str, str]],
@@ -1242,6 +786,9 @@ def chat_response_stream_multiturn(
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
     global LOG_FILE, LOG_PATH
     from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
@@ -1274,16 +821,12 @@ def chat_response_stream_multiturn(

     message_safety = safety_check(message, history=history)
     if message_safety is not None:
-        yield message_safety
-

     # history will be appended with message later on

-    # full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
-    #     message, history, sys_prompt=system_prompt
-    # )
     full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
-    # print(full_prompt)

     if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 4050:
         raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
@@ -1334,6 +877,89 @@ def chat_response_stream_multiturn(
     if message_safety is not None:
         yield message_safety
         return


 def maybe_log_conv_file(current_time, history, message, response, **kwargs):
@@ -1715,6 +1341,48 @@ CHAT_EXAMPLES = [

 # performance items


 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
@@ -1817,7 +1485,7 @@ def launch_demo():
             gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
             gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
             gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
-            gr.Textbox(value="
             gr.Number(value=0, label='current_time', visible=False),
         ],
         outputs=[
@@ -1829,11 +1497,13 @@ def launch_demo():
         description=FILE_UPLOAD_DESCRIPTION,
         allow_flagging=False,
         examples=[
-            ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "
-            ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "
         ],
         cache_examples=False,
     )

     demo_chat = gr.ChatInterface(
         response_fn,
@@ -1869,8 +1539,8 @@ def launch_demo():
     descriptions += f"<br> {path_markdown.format(model_path=model_path)}"

     demo = CustomTabbedInterface(
-        interface_list=[demo_chat, demo_file_upload],
-        tab_names=["Chat Interface", "Batch Inference"],
         title=f"{model_title}",
         description=descriptions,
     )
The updated side of the diff follows (context unmarked, added lines marked with +).

 import glob
 import json
 import time
+from gradio.routes import Request
+from gradio.utils import SyncToAsyncIterator, async_iteration
+from gradio.helpers import special_args
+import anyio
+from typing import AsyncGenerator, Callable, Literal, Union, cast

 from gradio_client.documentation import document, set_documentation_group


 # ! ==================================================================

 set_documentation_group("component")

 RES_PRINTED = False

 @document()
 class ChatBot(gr.Chatbot):
     )

     # Reconfigure clear_btn to stop and clear text box


 def _display_input(
+    self, message: str, history: List[List[Union[str, None]]]
+) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
     if message is not None and message.strip() != "":
         history.append([message, None])
     return history, history


+async def _stream_fn(
+    self,
+    message: str,
+    history_with_input,
+    request: Request,
+    *args,
+) -> AsyncGenerator:
+    history = history_with_input[:-1]
+    inputs, _, _ = special_args(
+        self.fn, inputs=[message, history, *args], request=request
+    )
+
+    if self.is_async:
+        generator = self.fn(*inputs)
+    else:
+        generator = await anyio.to_thread.run_sync(
+            self.fn, *inputs, limiter=self.limiter
+        )
+        generator = SyncToAsyncIterator(generator, self.limiter)
+    try:
+        first_response = await async_iteration(generator)
+        update = history + [[message, first_response]]
+        yield update, update
+    except StopIteration:
+        update = history + [[message, None]]
+        yield update, update
+    try:
+        async for response in generator:
+            update = history + [[message, response]]
+            yield update, update
+    except Exception as e:
+        # if "invalid" in str(e):
+        #     yield history, history
+        #     raise e
+        # else:
+        #     raise e
+        yield history, history
+        raise e
+
+
+
+
 # replace
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 gr.ChatInterface._display_input = _display_input
+gr.ChatInterface._stream_fn = _stream_fn


 @document()
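The patched _stream_fn above wraps whatever fn the ChatInterface was given; a minimal usage sketch (not part of this commit) with a plain synchronous generator, which is the same shape chat_response_stream_multiturn follows:

    import time
    import gradio as gr

    # Any generator that yields the growing partial reply works: the patched
    # _stream_fn runs it through SyncToAsyncIterator and streams
    # [message, partial_reply] pairs back to the Chatbot component.
    def echo_stream(message, history):
        reply = f"You said: {message}"
        for i in range(1, len(reply) + 1):
            time.sleep(0.02)
            yield reply[:i]

    if __name__ == "__main__":
        gr.ChatInterface(echo_stream).queue().launch()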
         interface.render()


 def vllm_abort(self):
     sh = self.llm_engine.scheduler
     for g in (sh.waiting + sh.running + sh.swapped):
     return chatml_chat_convo_format(conversations, True, default_system=system_prompt)


+def debug_chat_response_stream_multiturn(*args, **kwargs):
+    message = "This is a debugging message"
+    for i in range(len(message)):
+        time.sleep(0.05)
+        yield message[:i]
+
+
+
 def chat_response_stream_multiturn(
     message: str,
     history: List[Tuple[str, str]],
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
     global LOG_FILE, LOG_PATH
+    if DEBUG:
+        yield from debug_chat_response_stream_multiturn()
+        return
     from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>

     message_safety = safety_check(message, history=history)
     if message_safety is not None:
+        # yield message_safety
+        raise gr.Error(message_safety)

     # history will be appended with message later on

     full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)

     if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 4050:
         raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
     if message_safety is not None:
         yield message_safety
         return
+
+
+
+def debug_generate_free_form_stream(message):
+    output = " This is a debugging message...."
+    for i in range(len(output)):
+        time.sleep(0.05)
+        yield message + output[:i]
+
+
+def generate_free_form_stream(
+    message: str,
+    temperature: float,
+    max_tokens: int,
+    frequency_penalty: float,
+    presence_penalty: float,
+    current_time: Optional[float] = None,
+    stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
+) -> str:
+    global LOG_FILE, LOG_PATH
+    if DEBUG:
+        yield from debug_generate_free_form_stream(message)
+        return
+    from vllm import LLM, SamplingParams
+    """Build multi turn
+    """
+    global llm, RES_PRINTED
+    assert llm is not None
+    tokenizer = llm.get_tokenizer()
+    # force removing all
+    vllm_abort(llm)
+
+    temperature = float(temperature)
+    frequency_penalty = float(frequency_penalty)
+    max_tokens = int(max_tokens)
+
+    stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
+    stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
+
+    sampling_params = SamplingParams(
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=frequency_penalty,
+        presence_penalty=presence_penalty,
+        stop=stop_strings,
+        # ignore_eos=True,
+    )
+
+    # full_prompt = message
+    if len(message) == 0:
+        raise gr.Error("The message cannot be empty!")
+
+    message_safety = safety_check(message)
+    if message_safety is not None:
+        raise gr.Error(message_safety)
+
+    if len(tokenizer.encode(message, add_special_tokens=False)) >= 4050:
+        raise gr.Error(f"Prompt is too long!")
+
+    cur_out = None
+    for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
+        if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
+            # optionally check safety, and respond
+            if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
+                message_safety = safety_check(cur_out, history=None)
+                if message_safety is not None:
+                    raise gr.Error(message_safety)
+            yield message + cur_out
+        assert len(gen) == 1, f'{gen}'
+        item = next(iter(gen.values()))
+        cur_out = item.outputs[0].text
+        #cur_out = "Our system is under maintenance, will be back soon!"
+        if j >= max_tokens - 2:
+            gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
+
+    if cur_out is not None:
+        yield message + cur_out
+
+    message_safety = safety_check(message + cur_out, history=None)
+    if message_safety is not None:
+        raise gr.Error(message_safety)
+
+


 def maybe_log_conv_file(current_time, history, message, response, **kwargs):
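A small sketch (only the example string is assumed) of how generate_free_form_stream above turns the comma-separated Stop strings textbox value into the stop list handed to SamplingParams:

    raw = "<s>,</s>,<|im_start|>,\\n"
    stop_strings = [x.strip() for x in raw.strip().split(",")]
    stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
    # e.g. {'<s>', '</s>', '<|im_start|>', '\\n'} in arbitrary set order
    print(stop_strings)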

 # performance items

+def create_free_form_generation_demo():
+    global short_model_path
+    max_tokens = MAX_TOKENS
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+
+    introduction = """
+## Free-form:
+Put any context string (like few-shot prompts) and get the model to generate.
+"""
+
+    with gr.Blocks() as demo_free_form:
+        gr.Markdown(introduction)
+
+        with gr.Row():
+            txt = gr.Textbox(
+                scale=4,
+                lines=16,
+                show_label=False,
+                placeholder="Enter any free form text and submit",
+                container=False,
+            )
+        with gr.Row():
+            free_submit_button = gr.Button('Submit')
+        with gr.Row():
+            temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
+            length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
+            freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
+            pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
+            stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
+
+        free_submit_button.click(
+            generate_free_form_stream,
+            [txt, temp, length, freq_pen, pres_pen, stop_strings],
+            txt
+        )
+    return demo_free_form
+
+
+
+

 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
             gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
             gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
             gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
+            gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
             gr.Number(value=0, label='current_time', visible=False),
         ],
         outputs=[
         description=FILE_UPLOAD_DESCRIPTION,
         allow_flagging=False,
         examples=[
+            ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
+            ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
         ],
         cache_examples=False,
     )
+
+    demo_free_form = create_free_form_generation_demo()

     demo_chat = gr.ChatInterface(
         response_fn,
     descriptions += f"<br> {path_markdown.format(model_path=model_path)}"

     demo = CustomTabbedInterface(
+        interface_list=[demo_chat, demo_file_upload, demo_free_form],
+        tab_names=["Chat Interface", "Batch Inference", "Free-form"],
         title=f"{model_title}",
         description=descriptions,
     )