import numpy as np
from openai import OpenAI
import os
import tiktoken
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") # use the GPT-3.5 tokenizer to control token counts, so we don't need to load the actual tokenizer for API models

NUM_LOGPROBS = {
    'top_prob': 1,
}

MODEL_MAPPING = {
    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
    #  Nudging models below 
    "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    "Llama-2-13B-chat": "meta-llama/Llama-2-13b-chat-hf",
    "Gemma-2-2B-it": "google/gemma-2b-it",
}

def apply_instruct_template(model_name, system_prompt, instruct_prompt, response_prompt, add_bos=False):
    model_name = model_name.lower()
    if "chat" in model_name and "llama" in model_name and "2" in model_name:
        return llama_2_chat_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "llama" in model_name and "3" in model_name:
        if "3.1" in model_name: # for llama-3.1 models, add knowledge cut in system prompmt
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos, add_knowledge_cut=True)
        else:
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "it" in model_name and "gemma" in model_name:
        return gemma_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "olmo" in model_name:
        return olmo_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos) 
    elif "instruct" in model_name and "mistral" in model_name:
        return mistral_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=True)
    else:
        return f"{system_prompt}\n{instruct_prompt}\n{response_prompt}" # non-instruct model or models with unknown template

def mistral_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=True):
    """
    Format the prompt with the chat template used to train the Mistral instruct models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] {system_prompt}\n{instruct_prompt} [/INST] {response_prompt}"

def llama_2_chat_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Format the prompt with the chat template used to train the Llama-2 chat models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruct_prompt} [/INST] {response_prompt.lstrip()}"  # for most servers that add <s> automatically so we don't need to add it here

def llama_3_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False, add_knowledge_cut=False):
    """
    Format the prompt with the chat template used to train the Llama-3 instruct models.
    """
    # print("applying llama-3 instruct template")
    prefix = "<|begin_of_text|>" if add_bos else ""
    if add_knowledge_cut:
        system_prompt = f"Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"+ system_prompt
    return prefix + f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruct_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response_prompt}"

def gemma_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Format the prompt with the chat template used to train the Gemma instruct models, e.g.:
    <bos><start_of_turn>user
    Write a hello world program<end_of_turn>
    <start_of_turn>model
    """
    prefix = "<bos>" if add_bos else ""
    return prefix + f"<start_of_turn>user\n{system_prompt}\n{instruct_prompt}<end_of_turn>\n<start_of_turn>model\n{response_prompt}"

def olmo_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Format the prompt with the chat template used to train the OLMo instruct models.
    """
    return f"<|endoftext|><|user|>\n{system_prompt}\n{instruct_prompt}\n<|assistant|>\n{response_prompt}"

def find_longest_repeated_suffix(s):
    # Helper: check whether the suffix of the given length immediately repeats
    def has_repeated(s, length):
        if length < 30:     # ignore suffixes shorter than 30 characters
            return False
        suffix = s[-length:]
        # Check whether the suffix occurs again directly before itself
        return s[:-length].endswith(suffix)

    left, right = 0, len(s)
    result = 0

    # Binary search for the longest repeated suffix (a heuristic: it assumes
    # has_repeated is monotone in the length, so some repetitions may be missed)
    while left <= right:
        mid = (left + right) // 2
        if has_repeated(s, mid):
            result = mid  # Store the longest length found
            left = mid + 1  # Try for a longer suffix
        else:
            right = mid - 1  # Try for a shorter suffix

    # Return the longest repeated suffix, or None if no repetition was found
    if result > 0:
        return s[-result:]
    return None

def remove_redundant_repetitions(s):
    s = s.strip()
    # Find the longest repeated suffix
    longest_repeated_suffix = find_longest_repeated_suffix(s)
    while longest_repeated_suffix:
        # Remove the longest repeated suffix
        s = s[:-len(longest_repeated_suffix)]
        # Find the longest repeated suffix again
        longest_repeated_suffix = find_longest_repeated_suffix(s)
    return s
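
# Hedged sanity check (illustrative, not part of the original code): when
# degenerate decoding doubles a long tail verbatim, the duplicate is stripped;
# suffixes shorter than 30 characters are deliberately left alone.
#   block = "The answer is 42 because six times seven is forty-two."
#   assert remove_redundant_repetitions(block + block) == block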

def repetition_check(new_completion, full_prefix, subseq_len=5):
    # Flag a completion longer than subseq_len words that already appears verbatim in the prefix
    words = new_completion.split(" ")
    return len(words) > subseq_len and new_completion in full_prefix

def convert_token_logprobs_to_top_logprobs(token_logprobs, tokens):
    """
    Together AI now returns only token logprobs; convert them to the top-logprobs format: [{token: logprob}, ...]
    """
    top_logprobs = [{token: logprob} for token, logprob in zip(tokens, token_logprobs)]
    return top_logprobs
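
# Shape example (hedged, with invented logprob values):
#   convert_token_logprobs_to_top_logprobs([-0.1, -2.3], ["The", " cat"])
#   -> [{"The": -0.1}, {" cat": -2.3}]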

def check_need_nudging(nudging_method,
                        base_token_id,
                        current_base_info, 
                        thresholds,
):
    if nudging_method == 'top_prob':
        # nudge when the base model's top-token probability falls below the threshold
        base_top_prob = np.exp(max(current_base_info["top_logprobs"][base_token_id].values()))
        need_nudging = base_top_prob < thresholds['top_prob']
    else:
        raise ValueError(f"Unknown nudging method {nudging_method}")
    return need_nudging
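
# Worked example for the top_prob criterion (hedged: invented numbers). With
# thresholds = {'top_prob': 0.3}, a base token whose top logprob is log(0.25)
# has top probability 0.25 < 0.3 and is flagged for nudging, while log(0.9)
# (probability 0.9) would be accepted:
#   info = {"top_logprobs": [{" the": np.log(0.25)}]}
#   check_need_nudging('top_prob', 0, info, {'top_prob': 0.3})  # -> True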

def complete_with_base(nudging_method='top_prob',
                        base_model="davinci-002",
                        full_prefix_base="",
                        output="",
                        current_base_info=None,
                        max_completion_token=256,
                        completion_token_num=16,
                        client_base=None,
                        thresholds=None,
                        temperature=0.0,
                        top_p=0.9,
                        ):
    completion_base = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]   # accept the first token from the 1st round which is the acc token from the first stage
    completion_all = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]    # completion_all records all the tokens from the base model including the tokens that are not accepted in the last round, for debugging and visualization
    found_nudging_token = False
    response = None
    has_acc_token_stage_1 = True if len(current_base_info["completion"]) > 0 else False                     # if the current_base_info["completion"] is not empty, it means the first token in base completion is accepted from the 1st stage
    EMPTY_INFO_DICT = {
        "completion": "",
        "tokens": [],
        "top_logprobs": [],
        "stop_reason": None, 
        "num_logprobs": NUM_LOGPROBS[nudging_method],
    }
    next_nudging_info = EMPTY_INFO_DICT     # some nudging methods compute the next round's nudging info during base completion; currently unused for top_prob nudging
    while len(encoding.encode(completion_base)) < max_completion_token and not found_nudging_token:
       
        if current_base_info["completion"] == "":
            # complete the sentence using the base model
            response = client_base.completions.create(
                model=base_model,
                prompt=full_prefix_base + output + completion_base,
                max_tokens=completion_token_num,
                temperature=temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
                )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text

        if has_acc_token_stage_1:
            # pop the first token from the 1st round as it is already accepted from stage 1
            current_base_info["tokens"] = current_base_info["tokens"][1:]
            current_base_info["top_logprobs"] = current_base_info["top_logprobs"][1:]
            current_base_info["completion"] = "".join(current_base_info["tokens"])
            has_acc_token_stage_1 = False

        completion = current_base_info["completion"]
        tokens = current_base_info["tokens"]

        if completion in completion_base:
            break   # repeated completion, break

        nudging_position = -1

        # find the first token that violates the nudging criteria
        for base_idx in range(len(tokens)):
            found_nudging_token = check_need_nudging(nudging_method=nudging_method, base_token_id=base_idx, current_base_info=current_base_info, thresholds=thresholds)
            if found_nudging_token:
                nudging_position = base_idx
                break
        
        if nudging_position == -1:
            new_completion = "".join(tokens)
        else:
            new_completion = "".join(tokens[:nudging_position])   # keep tokens up to (but not including) the first token that needs nudging
        # avoid repetition in answer
        if repetition_check(new_completion, output + completion_base):
            break
        else:
            completion_base += new_completion

        if found_nudging_token: # nudging token found: record the full last base completion in completion_all before the loop exits
            completion_all += completion
        else:
            completion_all += new_completion

        next_nudging_info = EMPTY_INFO_DICT
        if response is not None and response.choices[0].finish_reason == "stop":
            break

        # reset the current_base_info
        current_base_info['completion'] = ""
        current_base_info['tokens'] = []
        current_base_info['top_logprobs'] = []

    return completion_base, completion_all, next_nudging_info

def completion_with_nudging(
        base_model="davinci-002",
        nudging_model="gpt-3.5-turbo",
        system_prompt_base="Answer the question by walking through the reasoning step by step.",
        system_prompt_nudging="Answer the question by walking through the reasoning step by step.",
        question="",
        context="",
        question_prompt="Question: ",
        answer_start_prompt_base="Answer: ",
        answer_start_prompt_nudging="Answer: ",
        completion_token_num=16,
        completion_token_num_nudging=16,
        max_token_total=256,
        print_intermediate_output=False,
        client=None,                # default client
        client_base=None,
        client_nudging=None,
        max_round=150,
        nudging_temperature=0.0,    # deterministic for nudging
        base_temperature=0.0,       # deterministic for base model
        nudging_method='top_prob',
        top_prob_thres=0.3,
        top_p=0.9,
        ):
    if client_base is None:
        client_base = client
    if client_nudging is None:
        client_nudging = client

    if nudging_method not in NUM_LOGPROBS.keys():
        raise ValueError(f"nudging method {nudging_method} number of logprobs not defined")

    full_prefix_base = apply_instruct_template(base_model, system_prompt_base, context + question_prompt + question, answer_start_prompt_base)  # for base model this function just adds newlines
    full_prefix_nudging = apply_instruct_template(nudging_model, system_prompt_nudging, context + question_prompt + question, answer_start_prompt_nudging)

    thresholds = {
        'top_prob': top_prob_thres,
    }

    output = ""
    nudging_round = 0
    all_nudging_words = []
    all_nudging_and_completions = []
    current_nudging_info = {
        "completion": "",
        "tokens": [],
        "top_logprobs": [],
        "stop_reason": None,
        "num_logprobs": NUM_LOGPROBS[nudging_method],
    }
    stop_reason = None
    repeat_nudging_word = 0
    last_nudging_word = ""
    while len(encoding.encode(output)) < max_token_total and nudging_round < max_round:    # use the GPT-3.5 token count to approximately control the output length
        nudging_round += 1
        if current_nudging_info["completion"] == "":
            response = client_nudging.completions.create(
                model=nudging_model,
                prompt=full_prefix_nudging + output,
                max_tokens=completion_token_num_nudging,
                temperature=nudging_temperature,
                logprobs=current_nudging_info["num_logprobs"],
                )
            current_nudging_info["completion"] = response.choices[0].text
            current_nudging_info["tokens"] = response.choices[0].logprobs.tokens
            current_nudging_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_nudging_info["top_logprobs"] is None:
                current_nudging_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_nudging_info["tokens"])
            current_nudging_info["stop_reason"] = response.choices[0].finish_reason

        # if finish_reason is stop, break the loop, also handles nudging completion from previous round
        if current_nudging_info["stop_reason"] == "stop":
            stop_reason = "nudging_model_stop"
            if len(current_nudging_info["completion"]) > 0:
                all_nudging_words.append(current_nudging_info["completion"])
                all_nudging_and_completions.append(current_nudging_info["completion"])
                output += current_nudging_info["completion"]
            break

        # ===================================================================
        # Stage 1: use base model to find the first token that violates the nudging criteria (no need to nudge)
        # ===================================================================
        found_acc_token = False
        current_base_info = {   # will be passed to the next stage
            "completion": "",
            "tokens": [],
            "top_logprobs": [],
            "num_logprobs": NUM_LOGPROBS[nudging_method],
        }
        nudging_text = current_nudging_info["completion"]
        num_whitespaces = len(nudging_text) - len(nudging_text.lstrip(" "))
        space_prefix = " " * num_whitespaces
        current_nudging_words = nudging_text.lstrip(" ").split(" ")     # operating on raw tokens leads to unexpected behaviors, so work with whitespace-separated nudging words
        nudging_word_id = 0 if len(current_nudging_words) > 1 else 1    # with a single word, skip the loop (found_acc_token stays False), accept the word, and go to the next round
        while not found_acc_token and nudging_word_id < len(current_nudging_words) - 1:
            nudging_word_id += 1                # always accept the first word
            nudging_gen_prefix = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
            current_nudging_word = " " + current_nudging_words[nudging_word_id]  # add a leading space to the current nudging word since the nudging words a split by space
            if current_nudging_word == " ":     # skip empty words produced by consecutive spaces
                continue
            prefix = full_prefix_base + output + nudging_gen_prefix
            response = client_base.completions.create(
                model=base_model,
                prompt=prefix,
                max_tokens=completion_token_num,
                temperature=base_temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
                )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text

            # look for the first token that meets the nudging criteria
            first_base_token = current_base_info["tokens"][0]            
            if current_nudging_word.startswith(first_base_token): # accept if the first base token equals, or is a prefix of, the current nudging word
                found_acc_token = True
            else: 
                found_acc_token = not check_need_nudging(nudging_method,    # accept the base token if it does not meet the nudging criteria
                                                         base_token_id=0,
                                                         current_base_info=current_base_info, 
                                                         thresholds=thresholds)
                
        # At this point either the loop exhausted the nudging words (no base token
        # could be accepted, so the whole nudging completion is used), or
        # found_acc_token is True (the accepted word prefix becomes this round's
        # nudging words and the base model takes over)
        
        nudging_words = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
        
        # Heuristic: if the same nudging words repeat for three consecutive rounds, break the loop
        if nudging_words == last_nudging_word:
            repeat_nudging_word += 1
            if repeat_nudging_word >= 3:
                stop_reason = "repeated_nudging_words"
                break
        else:
            last_nudging_word = nudging_words
            repeat_nudging_word = 0
        all_nudging_words.append(nudging_words)
        output += nudging_words

        if not found_acc_token: # if no base token can be accepted, use the current nudging completion and go to the next round
            all_nudging_and_completions.append(nudging_words)
            # reset the current nudging info and continue to the next round
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "top_logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue
        if current_base_info["completion"] == "":   # the base model thinks the completion is done, go to the next round. Make sure current_base_info["completion"] is not empty if proceed to the next stage
            all_nudging_and_completions.append(nudging_words)
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "top_logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue

        # ===================================================================
        # Stage 2: use nudging model to find the first token that meets the nudging criteria (need to nudge)
        # ===================================================================
        max_completion_token = max_token_total - len(encoding.encode(output))
        completion_base, completion_base_all, current_nudging_info = complete_with_base(nudging_method=nudging_method,
                                                                                        base_model=base_model,
                                                                                        full_prefix_base=full_prefix_base,
                                                                                        output=output,
                                                                                        current_base_info=current_base_info,
                                                                                        max_completion_token=max_completion_token,
                                                                                        completion_token_num=completion_token_num,
                                                                                        client_base=client_base,
                                                                                        thresholds=thresholds,
                                                                                        temperature=base_temperature,
                                                                                        top_p=top_p,
                                                                                        )
        # print(f"next_nudging_info: {current_nudging_info}") # debug

        output += completion_base
        all_nudging_and_completions.append(nudging_words + completion_base) # the tokens generated in this round; concatenating all entries yields the final output
        if print_intermediate_output:
            print(f"************nudging round {nudging_round}************")
            print(f"****nudging words from {nudging_model}****: {nudging_words}")
            print(f"****nudging text****: {nudging_text}")
            print(f"****completion from {base_model}****: {completion_base}")
            print(f"****all completion from {base_model}****: {completion_base_all}")
            print(f"****output****: {output}")
    
    if nudging_round >= max_round and not stop_reason:
        stop_reason = "round"
    if len(encoding.encode(output)) >= max_token_total and not stop_reason:
        stop_reason = "length"
    output = remove_redundant_repetitions(output)
    if print_intermediate_output:
        print(f"************final output************")
        print(f"****output****: {output}")

    all_info = {
        "question": question,
        "context": context,
        "raw_answer": output,
        "all_nudging_words": all_nudging_words,
        "all_completions": all_nudging_and_completions,
        "stop_reason": stop_reason,
        "system_prompt_base": system_prompt_base,
        "system_prompt_nudging": system_prompt_nudging,
        "full_prefix_base": full_prefix_base,
        "full_prefix_nudging": full_prefix_nudging,
    }
    return all_info


def get_nudging_answer(base_model,
                       nudging_model,
                       system_prompt,
                       question,
                       context="",
                       question_prompt="",
                       answer_start_prompt_base="",
                       answer_start_prompt_nudging="",
                       completion_token_num=16,
                       completion_token_num_nudging=16,
                       max_token_total=256,
                       max_round=150,
                       nudging_temperature=0.0,
                       base_temperature=0.0,
                       nudging_method='top_prob',
                       top_prob_thres=0.3,
                       ):
    base_model = MODEL_MAPPING[base_model]
    nudging_model = MODEL_MAPPING[nudging_model]
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
        )
    return completion_with_nudging(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt_base=system_prompt,
        system_prompt_nudging=system_prompt,
        question=question,
        context=context,
        question_prompt=question_prompt,
        answer_start_prompt_base=answer_start_prompt_base,
        answer_start_prompt_nudging=answer_start_prompt_nudging,
        completion_token_num=completion_token_num,
        completion_token_num_nudging=completion_token_num_nudging,
        max_token_total=max_token_total,
        print_intermediate_output=False,
        client_base=client,
        client_nudging=client,
        max_round=max_round,
        nudging_temperature=nudging_temperature,
        base_temperature=base_temperature,
        nudging_method=nudging_method,
        top_prob_thres=top_prob_thres,
    )

def get_base_answer(base_model,
                    system_prompt,
                    question,
                    max_tokens=256,):
    base_model = MODEL_MAPPING[base_model]
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
        )
    response = client.completions.create(
        model=base_model,
        prompt=system_prompt + "\n" + question,
        max_tokens=max_tokens,
        temperature=0.0,
        logprobs=1,
        )
    return response.choices[0].text
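

if __name__ == "__main__":
    # Minimal usage sketch (hedged: the question and prompt strings are invented
    # examples, not from the original code; requires TOGETHERAI_API_KEY to be set
    # and the mapped models to be served by the Together AI endpoint).
    result = get_nudging_answer(
        base_model="Llama-2-70B",
        nudging_model="Llama-2-13B-chat",
        system_prompt="Answer the question by walking through the reasoning step by step.",
        question="What is 17 * 24?",
        question_prompt="Question: ",
        answer_start_prompt_base="Answer: ",
        answer_start_prompt_nudging="Answer: ",
    )
    print(result["stop_reason"])
    print(result["raw_answer"])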