import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint ids for the three GPT-2 multiplication variants compared in this demo.
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'

# Load the models in the same order as the ids above.
implicit_cot_model, no_cot_model, explicit_cot_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint)
    for checkpoint in (implicit_cot_model_name, no_cot_model_name, explicit_cot_model_name)
)

# One tokenizer, taken from the implicit-CoT checkpoint, is used for every model.
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

# Registry keyed by short variant name; the same keys index the per-model constants below.
models = {
    'implicit': implicit_cot_model,
    'no': no_cot_model,
    'explicit': explicit_cot_model,
}

# Constants
# Maximum number of decoding steps per model. The explicit-CoT model emits intermediate
# steps before the '=' that precedes its product, so it gets a much larger budget.
MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 900}

def preprocess(num):
    """Convert a number into the models' input format.

    The value is stringified, surrounding whitespace and any internal spaces
    are dropped, and the remaining characters are emitted in reverse order
    separated by single spaces (least-significant digit first).

    >>> preprocess(' 12 3 ')
    '3 2 1'
    """
    digits = str(num).strip().replace(' ', '')
    return ' '.join(reversed(digits))

def postprocess(raw_output):
    """Invert the model output format: remove all spaces and restore normal digit order.

    >>> postprocess('3 2 1')
    '123'
    """
    compact = raw_output.replace(' ', '')
    return ''.join(reversed(compact))

def _annotate_prediction(model_name, predicted_digits_reversed, ground_truth_digits_reversed):
    """Label one model's decoded tokens for display in ``gr.HighlightedText``.

    Args:
        model_name: key into ``models``. For ``'explicit'``, every token up to and
            including the first ``'='`` is an intermediate step and is shown unlabeled.
            If no ``'='`` has been produced yet, all tokens are shown unlabeled.
        predicted_digits_reversed: the model's output split on single spaces,
            least-significant digit first.
        ground_truth_digits_reversed: digits of the true product, least-significant
            digit first. An empty list means the ground truth is unavailable; the
            product digits are then shown unlabeled instead of being scored.

    Returns:
        List of ``(token, category)`` tuples in display order (reverse of the input
        order), where category is ``None``, ``"correct"`` or ``"wrong"``.
    """
    annotations = []
    if model_name == 'explicit':
        if '=' not in predicted_digits_reversed:
            # Still emitting intermediate steps: there is no product to score yet.
            annotations = [(token, None) for token in predicted_digits_reversed]
            predicted_digits_reversed = []
        else:
            equal_sign_position = predicted_digits_reversed.index('=')
            annotations = [(token, None) for token in predicted_digits_reversed[:equal_sign_position + 1]]
            predicted_digits_reversed = predicted_digits_reversed[equal_sign_position + 1:]

    has_ground_truth = bool(ground_truth_digits_reversed)
    is_correct_sofar = True
    for i, predicted_digit in enumerate(predicted_digits_reversed):
        if not has_ground_truth:
            annotations.append((predicted_digit, None))
            continue
        if i < len(ground_truth_digits_reversed):
            is_correct_digit = predicted_digit == ground_truth_digits_reversed[i]
        else:
            # Positions beyond the true product's length are leading zeros; a '0' there
            # only counts as correct if every lower-order digit was correct.
            is_correct_digit = predicted_digit == '0' and is_correct_sofar
        is_correct_sofar = is_correct_sofar and is_correct_digit
        annotations.append((predicted_digit, "correct" if is_correct_digit else "wrong"))
    return annotations[::-1]


@spaces.GPU
def predict_product(num1, num2):
    """Stream the three models' predictions of ``num1 * num2``, one token per step.

    Each input is normalized the same way the models see it: stringified, stripped,
    and with internal spaces removed. The ground truth is computed from those same
    digits. If either input is not an integer (or the product cannot be converted to
    a string), the ground truth is left empty and the predictions are shown unlabeled.

    Yields:
        ``(ground_truth, implicit, no_cot, explicit)`` annotation lists for the four
        ``gr.HighlightedText`` outputs. One tuple is yielded per decoding step until
        no output can change any more.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for model in models.values():
        model.to(device)

    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    input_ids = inputs['input_ids']
    input_len = input_ids.shape[-1]

    # Use the same normalization as preprocess() so the ground truth matches the model input.
    try:
        product = int(str(num1).strip().replace(' ', '')) * int(str(num2).strip().replace(' ', ''))
        ground_truth_digits_reversed = list(str(product))[::-1]
    except ValueError:
        ground_truth_digits_reversed = []

    generated_ids_per_model = {model_name: input_ids.clone() for model_name in models}
    finished_per_model = {model_name: False for model_name in models}
    past_key_values_per_model = {model_name: None for model_name in models}
    # Pre-seeded so a model that emits EOS on its very first step still has an entry.
    predicted_annotations_per_model = {model_name: [] for model_name in models}

    # Hard cap on decoding steps so the loop always terminates.
    for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())):
        active_models = [
            model_name for model_name in models
            if not finished_per_model[model_name] and step < MAX_PRODUCT_DIGITS_PER_MODEL[model_name]
        ]
        # Once every model is done and the previous yield already showed the full ground
        # truth, further yields would be identical: stop instead of spinning.
        if not active_models and step >= len(ground_truth_digits_reversed):
            break

        # Reveal one more ground-truth digit per step; display most-significant first.
        ground_truth_annotations = [
            (ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step + 1]
        ][::-1]

        for model_name in active_models:
            model = models[model_name]
            generation_kwargs = {
                'input_ids': generated_ids_per_model[model_name],
                'max_new_tokens': 1,
                'do_sample': False,
                'return_dict_in_generate': True,
                'use_cache': True,
            }
            # From the second step on, hand back the cache returned by the previous step.
            if step > 0:
                generation_kwargs['past_key_values'] = past_key_values_per_model[model_name]
            outputs = model.generate(**generation_kwargs)
            generated_ids = outputs.sequences

            if generated_ids[0, -1].item() == tokenizer.eos_token_id:
                # Model is done: keep its last annotations and stop generating for it.
                finished_per_model[model_name] = True
                continue

            generated_ids_per_model[model_name] = generated_ids
            past_key_values_per_model[model_name] = outputs.past_key_values

            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
            predicted_annotations_per_model[model_name] = _annotate_prediction(
                model_name, output_text.strip().split(' '), ground_truth_digits_reversed
            )

        yield (
            ground_truth_annotations,
            predicted_annotations_per_model['implicit'],
            predicted_annotations_per_model['no'],
            predicted_annotations_per_model['explicit'],
        )

# Colour for each annotation category produced by predict_product.
color_map = {"correct": "green", "wrong": "red"}


def _highlighted_output(label, **extra):
    """Build a HighlightedText output with the demo's shared per-token display settings."""
    return gr.HighlightedText(
        label=label,
        combine_adjacent=False,
        show_legend=False,
        color_map=color_map,
        **extra,
    )


number_inputs = [
    gr.Textbox(label='First Number (up to 15 digits)', value='123456789'),
    gr.Textbox(label='Second Number (up to 15 digits)', value='987654321'),
]

# Output order must match the order of the tuples yielded by predict_product.
product_outputs = [_highlighted_output('Ground Truth Product')] + [
    _highlighted_output(label, show_inline_category=False)
    for label in (
        'Implicit CoT Predicted Product',
        'No CoT Predicted Product',
        'Explicit CoT Predicted Intermediate Steps & Product',
    )
]

demo = gr.Interface(
    fn=predict_product,
    inputs=number_inputs,
    outputs=product_outputs,
    title='Can GPT2 Predict Multiplication of Two Numbers Without Intermediate Steps?',
    description='This demo demonstrates GPT2\'s ability to directly predict the product of two large numbers without intermediate reasoning steps. The GPT2 has been finetuned to internalize chain-of-thought (CoT) reasoning within its hidden states through our stepwise internalization approach. The results demonstrate the effectiveness of implicit CoT (our approach, accurate and fast), compared to no CoT (fast but inaccurate) and explicit CoT (accurate but slow).',
    article="""
    - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """,
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    # Presumably limits the demo to one generation at a time (e.g. one GPU job) — confirm.
    concurrency_limit=1,
)
# Enable the request queue (at most 20 pending requests) and start the server.
demo.queue(max_size=20).launch()