File size: 2,009 Bytes
32b4fdd
 
 
a249507
32b4fdd
4f53770
32b4fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a249507
 
 
 
 
 
32b4fdd
a249507
32b4fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4718ba4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import AutoModelForCausalLM
import torch
import gradio as gr
import re

# Load the causal language model identified by "Manuel2011/addition_model".
# This runs once at import time; generate_solution captures this object as the
# default value of its `model` parameter.
model = AutoModelForCausalLM.from_pretrained("Manuel2011/addition_model")

class NumberTokenizer:
    """Whitespace tokenizer over a fixed arithmetic vocabulary.

    The vocabulary is, in id order: ``'+'``, ``'='``, the padding token
    ``'-1'`` and the digits ``'0'`` to ``'9'`` (13 symbols, ids 0-12).

    Attributes are plain and public:
      * ``numbers_qty``  - value passed to the constructor; it is only stored
        and does not influence the vocabulary.
      * ``pad_token`` / ``pad_token_id`` - the padding symbol and its id.
      * ``encoder`` - symbol -> id mapping.
      * ``decoder`` - id -> symbol mapping.
    """

    def __init__(self, numbers_qty=10):
        """Build the symbol/id lookup tables."""
        symbols = ['+', '=', '-1'] + [str(digit) for digit in range(10)]

        self.numbers_qty = numbers_qty
        self.pad_token = '-1'

        # Fill both directions of the mapping in a single pass.
        self.encoder = {}
        self.decoder = {}
        for token_id, symbol in enumerate(symbols):
            self.encoder[symbol] = token_id
            self.decoder[token_id] = symbol

        self.pad_token_id = self.encoder[self.pad_token]

    def decode(self, token_ids):
        """Return the symbols for ``token_ids`` joined by single spaces.

        Raises ``KeyError`` for an id that is not in the vocabulary.
        """
        symbols = [self.decoder[token_id] for token_id in token_ids]
        return ' '.join(symbols)

    def __call__(self, text):
        """Split ``text`` on whitespace and return the list of token ids.

        Raises ``KeyError`` for a piece that is not in the vocabulary.
        """
        ids = []
        for piece in text.split():
            ids.append(self.encoder[piece])
        return ids

# Shared tokenizer instance used by generate_solution. The argument (13) is
# only stored as `numbers_qty`; the vocabulary itself is hard-coded inside
# NumberTokenizer.__init__.
tokenizer = NumberTokenizer(13)

def generate_solution(input, solution_length=6, model=model):
    """Greedily decode the model's answer to a single-digit addition.

    Args:
        input: Text containing an exercise such as ``"1 + 2 ="``. The first
            ``<digit> + <digit>`` occurrence is used; whitespace around the
            ``+`` is optional. Operands that are part of a longer number
            (e.g. ``"12 + 34"``) are rejected instead of being silently
            truncated to one of their digits. The name shadows the builtin
            but is kept so keyword callers keep working.
        solution_length: Number of tokens to generate.
        model: Object called as ``model(token_ids)`` whose result exposes
            ``.logits``. Defaults to the module-level model, bound when this
            function is defined.

    Returns:
        The generated tokens decoded by the module-level ``tokenizer``
        (symbols joined by single spaces), or the string ``'Invalid input'``
        when ``input`` is not a string or contains no valid exercise.
    """
    if not isinstance(input, str):
        return 'Invalid input'
    # The lookarounds keep each operand from being a slice of a longer number.
    exercise = re.search(r'(?<!\d)(\d)\s*\+\s*(\d)(?!\d)', input)
    if exercise is None:
        return 'Invalid input'
    first_number = int(exercise.group(1))
    second_number = int(exercise.group(2))

    model.eval()
    # Re-render the exercise in the space-separated form the tokenizer splits on.
    prompt = f'{first_number} + {second_number} ='
    token_ids = torch.tensor(tokenizer(prompt))
    solution = []
    # Inference only: skip autograd bookkeeping while generating.
    with torch.no_grad():
        for _ in range(solution_length):
            output = model(token_ids)
            # Assumes logits are indexed by sequence position first, so [-1]
            # is the distribution for the next token - TODO confirm for this model.
            predicted = output.logits[-1].argmax()
            token_ids = torch.cat((token_ids, predicted.unsqueeze(0)), dim=0)
            solution.append(predicted.cpu().item())
    return tokenizer.decode(solution)

def solve(input):
    """Gradio callback: answer the addition exercise typed by the user.

    Delegates to ``generate_solution`` and asks for exactly two generated
    tokens (presumably because the largest single-digit sum, 9 + 9 = 18,
    has two digits). Returns whatever string ``generate_solution`` returns.
    """
    answer = generate_solution(input, solution_length=2)
    return answer

# Gradio UI: one text box in, one text box out, wired to solve().
demo = gr.Interface(
    fn=solve,
    inputs=[
        gr.Textbox(
            label="Addition exercise",
            lines=1,
            info="The input must be of the form '1 + 2 =', with a single space between each character, and only single-digit numbers are allowed.",
        ),
    ],
    outputs=[
        gr.Textbox(label="Result", lines=1),
    ],
    title="Simple addition with a GPT-like model",
    description="Perform addition of two single-digit numbers using a GPT-like model trained on a small dataset.",
    examples=["1 + 2 =", "5 + 7 ="],
)

# Launched unconditionally (there is no __main__ guard), so importing this
# module also starts the app.
demo.launch()