Spaces:
Sleeping
Sleeping
File size: 2,009 Bytes
32b4fdd a249507 32b4fdd 4f53770 32b4fdd a249507 32b4fdd a249507 32b4fdd 4718ba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from transformers import AutoModelForCausalLM
import torch
import gradio as gr
import re
# Repo id (or local path) handed to ``from_pretrained`` to load the trained
# addition model.
MODEL_REPO_ID = "Manuel2011/addition_model"
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID)
class NumberTokenizer:
    """Whitespace tokenizer for single-digit addition prompts.

    The vocabulary is fixed: ``'+'``, ``'='``, the pad token ``'-1'`` and the
    digits ``'0'``-``'9'`` (13 entries), mapped to ids 0-12 in that order.
    """

    def __init__(self, numbers_qty=10):
        """Build the token <-> id lookup tables.

        Args:
            numbers_qty: Stored on the instance as ``numbers_qty``; it does not
                influence the fixed vocabulary built here.
        """
        vocab = ['+', '=', '-1', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.numbers_qty = numbers_qty
        self.pad_token = '-1'
        self.encoder = {str(v): i for i, v in enumerate(vocab)}
        self.decoder = {i: str(v) for i, v in enumerate(vocab)}
        self.pad_token_id = self.encoder[self.pad_token]

    def decode(self, token_ids):
        """Return the tokens for ``token_ids`` joined by single spaces.

        Each id is coerced with ``int()`` so that plain ints, NumPy integers
        and single-element torch tensors (e.g. the elements obtained by
        iterating a 1-D id tensor) are all accepted. Without the coercion a
        tensor element hashes by identity and misses the int-keyed table.

        Raises:
            KeyError: if an id is outside the vocabulary.
        """
        return ' '.join(self.decoder[int(t)] for t in token_ids)

    def __call__(self, text):
        """Encode whitespace-separated tokens of ``text`` to ids.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return [self.encoder[t] for t in text.split()]
# NumberTokenizer only stores numbers_qty; 13 equals the size of its fixed
# vocabulary ('+', '=', pad '-1', digits 0-9).
tokenizer = NumberTokenizer(numbers_qty=13)
def generate_solution(input, solution_length=6, model=model):
    """Greedily generate answer tokens for a single-digit addition prompt.

    The first ``<digit> + <digit>`` pair in ``input`` is re-encoded in the
    canonical ``"a + b ="`` form. The model is then run autoregressively for
    ``solution_length`` steps, appending the arg-max token each time.

    Args:
        input: Free-form user text, e.g. ``"1 + 2 ="``.
        solution_length: Number of tokens to generate.
        model: Causal LM whose output exposes ``.logits``; defaults to the
            module-level ``model``.

    Returns:
        The generated tokens decoded as a space-separated string, or
        ``'Invalid input'`` when ``input`` is not a string or contains no
        single-digit addition.
    """
    # Digit look-arounds reject multi-digit operands; without them "12 + 34"
    # would silently be read as "2 + 3".
    match = (
        re.search(r'(?<!\d)(\d)\s*\+\s*(\d)(?!\d)', input)
        if isinstance(input, str)
        else None
    )
    if match is None:
        return 'Invalid input'
    first_number = int(match.group(1))
    second_number = int(match.group(2))

    model.eval()
    prompt = f'{first_number} + {second_number} ='
    # HF causal LMs take input_ids shaped (batch, seq): add a batch axis of 1.
    ids = torch.tensor([tokenizer(prompt)])
    solution = []
    # Pure inference: don't build an autograd graph across generation steps.
    with torch.no_grad():
        for _ in range(solution_length):
            logits = model(ids).logits
            next_token = logits[0, -1].argmax()
            ids = torch.cat((ids, next_token.view(1, 1)), dim=1)
            solution.append(next_token.item())
    return tokenizer.decode(solution)
def solve(input):
    """Gradio callback: return the model's answer for an addition prompt.

    Delegates to ``generate_solution`` with exactly two generated tokens --
    presumably one per digit of the (at most two-digit) sum; confirm against
    the model's training format.
    """
    answer_token_count = 2
    return generate_solution(input, solution_length=answer_token_count)
# Input/output widgets for the web UI.
exercise_box = gr.Textbox(
    label="Addition exercise",
    lines=1,
    info="The input must be of the form '1 + 2 =', with a single space between each character, and only single-digit numbers are allowed.",
)
result_box = gr.Textbox(label="Result", lines=1)

# Interface wrapping `solve`; launched when this module runs.
demo = gr.Interface(
    fn=solve,
    inputs=[exercise_box],
    outputs=[result_box],
    title="Simple addition with a GPT-like model",
    description="Perform addition of two single-digit numbers using a GPT-like model trained on a small dataset.",
    examples=["1 + 2 =", "5 + 7 ="],
)
demo.launch()