import copy
import torch

from generate import eval_func_param_names, get_score_model, get_model, evaluate, check_locals
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs


def run_cli(  # for local function:
        base_model=None, lora_weights=None, inference_server=None,
        debug=None, chat_context=None,
        examples=None, memory_restriction_level=None,
        # for get_model:
        score_model=None, load_8bit=None, load_4bit=None, load_half=None, infer_devices=None, tokenizer_base_model=None,
        gpu_id=None, local_files_only=None, resume_download=None, use_auth_token=None,
        trust_remote_code=None, offload_folder=None, compile_model=None,
        # for some evaluate args
        stream_output=None, prompt_type=None, prompt_dict=None,
        temperature=None, top_p=None, top_k=None, num_beams=None,
        max_new_tokens=None, min_new_tokens=None, early_stopping=None, max_time=None, repetition_penalty=None,
        num_return_sequences=None, do_sample=None, chat=None,
        langchain_mode=None, document_choice=None, top_k_docs=None, chunk=None, chunk_size=None,
        # for evaluate kwargs
        src_lang=None, tgt_lang=None, concurrency_count=None, save_dir=None, sanitize_bot_response=None,
        model_state0=None,
        max_max_new_tokens=None,
        is_public=None,
        max_max_time=None,
        raise_generate_gpu_exceptions=None, load_db_if_exists=None, dbs=None, user_path=None,
        detect_user_path_changes_every_query=None,
        use_openai_embedding=None, use_openai_model=None, hf_embedding_model=None,
        db_type=None, n_jobs=None, first_para=None, text_limit=None, verbose=None, cli=None, reverse_docs=None,
        use_cache=None,
        auto_reduce_chunks=None, max_chunks=None, model_lock=None, force_langchain_evaluate=None,
        model_state_none=None,
        # unique to this function:
        cli_loop=None,
):
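    """Simple interactive CLI: load the model once, then repeatedly read an
    instruction from stdin, run evaluate(), and print the (optionally streamed)
    response.  Returns the list of generations; type 'exit' to quit."""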
    check_locals(**locals())

    score_model = ""  # FIXME: For now, so user doesn't have to pass
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = 'cpu' if n_gpus == 0 else 'cuda'
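    # pin to a single CUDA device only when exactly one GPU is present; otherwise use a no-op context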
    context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device

    with context_class(device):
        from functools import partial

        # get score model
        smodel, stokenizer, sdevice = get_score_model(reward_type=True,
                                                      **get_kwargs(get_score_model, exclude_names=['reward_type'],
                                                                   **locals()))

        # get main generation model and tokenizer
        model, tokenizer, device = get_model(reward_type=False,
                                             **get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
        model_dict = dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights,
                          inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict)
        model_state = dict(model=model, tokenizer=tokenizer, device=device)
        model_state.update(model_dict)
        my_db_state = [None]
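        # bind model state and all fixed kwargs up front so the loop only has to pass per-request eval parameters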
        fun = partial(evaluate, model_state, my_db_state,
                      **get_kwargs(evaluate, exclude_names=['model_state', 'my_db_state'] + eval_func_param_names,
                                   **locals()))

        example1 = examples[-1]  # pick reference example
        all_generations = []
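        # interactive loop: read instructions until 'exit' (single pass when cli_loop is False)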
        while True:
            clear_torch_cache()
            instruction = input("\nEnter an instruction: ")
            if instruction == "exit":
                break

            # start from the reference example, then overwrite the instruction, input, and context slots
            eval_vars = copy.deepcopy(example1)
            eval_vars[eval_func_param_names.index('instruction')] = \
                eval_vars[eval_func_param_names.index('instruction_nochat')] = instruction
            eval_vars[eval_func_param_names.index('iinput')] = \
                eval_vars[eval_func_param_names.index('iinput_nochat')] = ''  # no input yet
            eval_vars[eval_func_param_names.index('context')] = ''  # no context yet

            # grab other parameters, like langchain_mode
            for k in eval_func_param_names:
                if k in locals():
                    eval_vars[eval_func_param_names.index(k)] = locals()[k]

            gener = fun(*tuple(eval_vars))  # generator yielding dicts with 'response' text and any 'sources'
            outr = ''
            res_old = ''
            for gen_output in gener:
                res = gen_output['response']
                extra = gen_output['sources']
                if base_model not in non_hf_types or base_model in ['llama']:
                    if not stream_output:
                        print(res)
                    else:
                        # evaluate() streams the full output so far on each step (as gradio expects),
                        # so print only the newly generated characters here
                        diff = res[len(res_old):]
                        print(diff, end='', flush=True)
                        res_old = res
                    outr = res  # don't accumulate
                else:
                    outr += res  # non-HF backends return the whole response at once
                    if extra:
                        # show sources at the end, after the model has streamed the rest of the response to stdout
                        print(extra, flush=True)
            all_generations.append(outr + '\n')
            if not cli_loop:
                break
    return all_generations