da03 committed on
Commit
3f861c3
1 Parent(s): 695328d
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -18,6 +18,13 @@ def postprocess(raw_output):
18
 
19
  @spaces.GPU
20
  def predict_product(num1, num2):
 
 
 
 
 
 
 
21
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
22
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
@@ -26,13 +33,6 @@ def predict_product(num1, num2):
26
  raw_output = tokenizer.decode(output, skip_special_tokens=True)
27
  prediction = postprocess(raw_output)
28
 
29
- try:
30
- num1_int = int(num1)
31
- num2_int = int(num2)
32
- valid_input = True
33
- except ValueError:
34
- valid_input = False
35
-
36
  if valid_input:
37
  correct_product = str(num1_int * num2_int)
38
  is_correct = (prediction == correct_product)
@@ -42,33 +42,29 @@ def predict_product(num1, num2):
42
  result_color = "black"
43
  result_message = "Invalid input. Could not evaluate correctness."
44
 
45
- return input_text, raw_output, prediction, result_message, result_color
46
-
47
- def output_component(value, color):
48
- return gr.HTML.update(value=f"<div style='color: {color};'>{value}</div>")
49
 
50
  demo = gr.Interface(
51
  fn=predict_product,
52
- inputs=[gr.Textbox(label='First Number (up to 9 digits)', value='12345'), gr.Textbox(label='Second Number (up to 9 digits)', value='67890')],
 
 
 
53
  outputs=[
54
  gr.Textbox(label='Raw Input to GPT-2'),
55
  gr.Textbox(label='Raw Output from GPT-2'),
56
  gr.Textbox(label='Predicted Product'),
57
  gr.HTML(label='Result Message')
58
  ],
59
- title='GPT-2 Multiplication Predictor',
60
- description='Enter two numbers up to 9 digits each and get the predicted product.',
61
  article="""
62
  ### Additional Resources
63
  - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
64
  - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
65
  - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
66
  """,
67
- css="""
68
- .output-html {
69
- font-size: 1.5em;
70
- }
71
- """
72
  )
73
 
74
  demo.launch()
 
18
 
19
  @spaces.GPU
20
  def predict_product(num1, num2):
21
+ try:
22
+ num1_int = int(num1)
23
+ num2_int = int(num2)
24
+ valid_input = True
25
+ except ValueError:
26
+ valid_input = False
27
+
28
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
29
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
30
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
33
  raw_output = tokenizer.decode(output, skip_special_tokens=True)
34
  prediction = postprocess(raw_output)
35
 
 
 
 
 
 
 
 
36
  if valid_input:
37
  correct_product = str(num1_int * num2_int)
38
  is_correct = (prediction == correct_product)
 
42
  result_color = "black"
43
  result_message = "Invalid input. Could not evaluate correctness."
44
 
45
+ return input_text, raw_output, prediction, result_message
 
 
 
46
 
47
  demo = gr.Interface(
48
  fn=predict_product,
49
+ inputs=[
50
+ gr.Textbox(label='First Number (up to 9 digits)', value='12345'),
51
+ gr.Textbox(label='Second Number (up to 9 digits)', value='67890'),
52
+ ],
53
  outputs=[
54
  gr.Textbox(label='Raw Input to GPT-2'),
55
  gr.Textbox(label='Raw Output from GPT-2'),
56
  gr.Textbox(label='Predicted Product'),
57
  gr.HTML(label='Result Message')
58
  ],
59
+ title='GPT-2 Multiplication Calculator',
60
+ description='This demo uses GPT-2 to directly predict the product of two numbers without using any intermediate steps.',
61
  article="""
62
  ### Additional Resources
63
  - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
64
  - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
65
  - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
66
  """,
67
+ live=False
 
 
 
 
68
  )
69
 
70
  demo.launch()