import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load models
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication-20-digits'
implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)

explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication-20-digits'
explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)

models = {'implicit': implicit_cot_model, 'no': no_cot_model, 'explicit': explicit_cot_model}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
for model in models.values():
    model.to(device)
    model.eval()

# Constants
MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 1070}

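# The models consume and produce digits in reversed order (least-significant
# digit first), separated by spaces, matching the format they were trained on.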
def preprocess(num):
    num = str(num).strip().replace(' ', '')
    reversed_num = ' '.join(num[::-1])
    return reversed_num

def postprocess(raw_output):
    prediction = raw_output.replace(' ', '')[::-1]
    return prediction

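# On Hugging Face Spaces, @spaces.GPU requests a GPU for the duration of each
# call (required on ZeroGPU hardware).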
@spaces.GPU
def predict_product(num1, num2):
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    for model in models.values():
        model.to(device)

    input_ids = inputs['input_ids']
    input_len = input_ids.shape[-1]
    prediction = ""
    ground_truth_product = ""
    valid_input = True

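    # Compute the ground-truth product up front so each predicted digit can be graded.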
    try:
        num1_int = int(num1)
        num2_int = int(num2)
        ground_truth_product = str(num1_int * num2_int)
        ground_truth_digits_reversed = list(ground_truth_product)[::-1]
    except ValueError:
        valid_input = False

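    # Per-model decoding state: running token ids, finished flags, KV caches, and annotations.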
    generated_ids_per_model = {model_name: input_ids.clone() for model_name in models}
    finished_per_model = {model_name: False for model_name in models}
    past_key_values_per_model = {model_name: None for model_name in models}
    predicted_annotations_per_model = {model_name: [] for model_name in models}
    try:
        for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())):  # Set a maximum limit to prevent infinite loops
            # Ground Truth
            if not valid_input:
                ground_truth_annotations = [('Invalid Input!', None)]
            else:
                ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
                ground_truth_annotations = ground_truth_annotations[::-1]
            # Predicted
            for model_name in models:
                model = models[model_name]
                if finished_per_model[model_name]:
                    continue
                if step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
                    continue
                generation_kwargs = {
                    'input_ids': generated_ids_per_model[model_name],
                    'max_new_tokens': 1,
                    'do_sample': False,
                    'past_key_values': past_key_values_per_model[model_name],
                    'return_dict_in_generate': True,
                    'use_cache': True
                }
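                # No KV cache exists yet on the first step; let generate() build one.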
                if step == 0:
                    del generation_kwargs['past_key_values']
                outputs = model.generate(**generation_kwargs)
                generated_ids = outputs.sequences
                next_token_id = generated_ids[0, -1]

                # EOS: this model is done. If it stopped before covering every
                # ground-truth digit, prepend a blank cell marked wrong (U+2800
                # braille blank, so the red highlight stays visible).
                if next_token_id.item() == tokenizer.eos_token_id:
                    finished_per_model[model_name] = True
                    if valid_input:
                        num_digits_predicted = len([item for item in predicted_annotations_per_model[model_name] if item[1] is not None])
                        if num_digits_predicted < len(ground_truth_digits_reversed):
                            predicted_annotations_per_model[model_name].insert(0, ('⠀', 'wrong'))
                    continue
                    
                generated_ids_per_model[model_name] = generated_ids
                past_key_values_per_model[model_name] = outputs.past_key_values
                
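                # Decode everything generated beyond the prompt and split it into digit tokens.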
                output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
                predicted_digits_reversed = output_text.strip().split(' ')
            
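                # The explicit-CoT model first emits intermediate reasoning steps;
                # only the digits after the first generated '=' count as its answer.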
                predicted_annotations = []
                is_correct_sofar = True
                if model_name == 'explicit':
                    if '=' not in predicted_digits_reversed:
                        predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
                        predicted_digits_reversed = []
                    else:
                        equal_sign_position = predicted_digits_reversed.index('=')
                        predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:equal_sign_position+1]]
                        predicted_digits_reversed = predicted_digits_reversed[equal_sign_position+1:]

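                # Grade each digit against the reversed ground truth; digits beyond its
                # length are correct only as redundant leading zeros on an otherwise
                # correct prediction.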
                for i in range(len(predicted_digits_reversed)):
                    predicted_digit = predicted_digits_reversed[i]
                    if not valid_input:
                        is_correct_digit = None
                    elif i >= len(ground_truth_digits_reversed):
                        if predicted_digit == '0' and is_correct_sofar:
                            is_correct_digit = True
                        else:
                            is_correct_digit = False
                    else:
                        ground_truth_digit = ground_truth_digits_reversed[i]
                        if predicted_digit == ground_truth_digit:
                            is_correct_digit = True
                        else:
                            is_correct_digit = False
                    if not is_correct_digit:
                        is_correct_sofar = False
                    if is_correct_digit is None:
                        predicted_annotations.append((predicted_digit, None))
                    elif is_correct_digit:
                        predicted_annotations.append((predicted_digit, "correct"))
                    else:
                        predicted_annotations.append((predicted_digit, "wrong"))
                predicted_annotations = predicted_annotations[::-1]
                predicted_annotations_per_model[model_name] = predicted_annotations

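            # Stream the current state of all four output panels after each step.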
            predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
            predicted_annotations_nocot = predicted_annotations_per_model['no']
            predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']

            yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot
    except Exception:
        # Fail quietly so the UI simply stops streaming if generation errors out.
        pass

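# Highlight colors for the HighlightedText outputs below.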
color_map = {"correct": "green", "wrong": "red"}

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 20 digits)', value='12345678912345678912'),
        gr.Textbox(label='Second Number (up to 20 digits)', value='98765432198765432198'),
    ],
    outputs=[
        gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map=color_map),
        gr.HighlightedText(label='Implicit CoT Prediction (Ours)', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
        gr.HighlightedText(label='No CoT Prediction', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
        gr.HighlightedText(label='Explicit CoT Steps & Prediction', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
    ],
    title='Predicting Multiplication with GPT-2: Implicit vs. Explicit CoT',
    description='This demo showcases GPT-2\'s ability to directly predict the product of two large numbers without intermediate steps, using our stepwise internalization method. Compare the performance of implicit CoT (our method), no CoT, and explicit CoT. Implicit CoT offers accuracy and speed, while explicit CoT provides detailed reasoning but is slower.',
    article="""
    - [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
    - [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """,
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    concurrency_limit=1
)
demo.queue(max_size=20).launch()