koalazf99
committed on
Commit
•
f2a0886
1
Parent(s):
7737af5
init commit
Browse files- .gitignore +2 -0
- app.py +39 -0
- llm_src/demo.py +68 -0
- llm_src/utils/__init__.py +0 -0
- llm_src/utils/cot/__init__.py +0 -0
- llm_src/utils/cot/get_prompt.py +17 -0
- llm_src/utils/decoder.py +87 -0
- llm_src/utils/fp_substitution.py +77 -0
- llm_src/utils/logger.py +0 -0
- llm_src/utils/solis/__init__.py +0 -0
- llm_src/utils/solis/helper.py +304 -0
- llm_src/utils/solis/solis_solver.py +97 -0
- md_src/demo_description.md +1 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.txt
|
2 |
+
__pycache__
|
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
# Configure the Streamlit page before any other Streamlit call is made.
st.set_page_config(
    page_title="Solis Demo",
    # NOTE(review): the icon's bytes were garbled in extraction (only a lone
    # Greek letter survived); a sun emoji is assumed here - confirm.
    page_icon="🌞",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': "Welcome to check Reflection-Of-Thought [website](https://reflection-of-thought.github.io/)!"
    }
)

# Example questions; the first entry is used as the default question text.
test_examples = [
    "Nancy uploaded 41 pictures to Facebook. She put 37 pics into one album and put the rest into 2 different albums. How many pictures were in each album?",
]
|
17 |
+
|
18 |
+
def read_markdown(path):
    """Render the Markdown file at *path* into the Streamlit page.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding. Raw HTML inside the file is passed
    through (``unsafe_allow_html=True``), so only trusted files should be
    rendered with this helper.

    Args:
        path: Path of the Markdown file to read.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
# Main Demo
st.markdown("# Solis Demo")

# Demo description
read_markdown('md_src/demo_description.md')

# Free-text question box, pre-filled with the first example question.
question = st.text_input(
    "Ask a question which requires numerical reasoning:",
    value=test_examples[0]
)

# Halt this script run until the user presses the button.
button = st.button("Run Solis")
if not button:
    st.stop()
|
39 |
+
|
llm_src/demo.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
|
4 |
+
from utils.cot.get_prompt import get_prompt
|
5 |
+
from utils.decoder import Decoder, answer_cleansing
|
6 |
+
from utils.fp_substitution import fp_substitute, get_nums_from_passage
|
7 |
+
from utils.solis.solis_solver import try_search
|
8 |
+
from utils.solis.helper import *
|
9 |
+
|
10 |
+
|
11 |
+
def demo(decoder: Decoder, x: str, CNT_SUM):
    """Run the Solis pipeline on a single question.

    Steps: query the model on the original question, build operand-substituted
    variants of the question, query the model on each variant, then search
    for the arithmetic expression that explains the variant answers.

    Args:
        decoder: Object whose ``decode(args, prompt, CNT_SUM)`` returns the
            raw model completion.
        x: The question text.
        CNT_SUM: Passed through unchanged to ``decoder.decode``.

    Returns:
        ``"Too many operands!"`` when more than three numbers are found in
        *x*; ``"Too Frequent!"`` when any model call or answer cleansing
        raises; otherwise whatever ``try_search`` returns.
    """
    args = get_default_argument()
    # Seed from the parsed arguments (default 123) rather than a literal, so
    # ``--seed`` is honoured. Seeding happens before the prompt is shuffled.
    random.seed(args.seed)
    prompt_x = get_prompt()

    orig_nums, _ = get_nums_from_passage(x)
    if len(orig_nums) > 3:
        return "Too many operands!"

    def _answer(question):
        # One few-shot query followed by answer extraction.
        completion = decoder.decode(args, prompt_x + f"Q: {question}\nA:", CNT_SUM)
        return answer_cleansing(args, completion)

    # step 0, original predict
    try:
        orig_z = _answer(x)
    except Exception as e:
        print(e)
        return "Too Frequent!"

    # step 1, #TODO skip operand proposal
    # step 2, substitute
    fp_results = []
    for fp_data in fp_substitute(x, args.substitute_time):
        try:
            fp_z = _answer(fp_data['Question'])
        except Exception as e:
            print(e)
            return "Too Frequent!"
        fp_results.append({
            "fp_nums": fp_data["Alignments"],
            "fp_z": fp_z,
        })

    # step 3, arith relationship inversion
    solis_ret = try_search(args, orig_nums, fp_results)
    print(solis_ret)
    return solis_ret
|
50 |
+
|
51 |
+
def get_default_argument():
    """Build the argument namespace used by the demo.

    Options are read from ``sys.argv``. Arguments this parser does not
    define are ignored instead of terminating the process, so the function
    also works when the host program (a web runner, a test runner, ...) owns
    the command line.

    Returns:
        argparse.Namespace with ``seed``, ``api_time_interval``,
        ``max_length``, ``substitute_time``, ``dataset`` and
        ``direct_answer_trigger_for_fewshot``.
    """
    parser = argparse.ArgumentParser(description="Solis")
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--api_time_interval", type=float, default=2)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--substitute_time", type=int, default=5)
    parser.add_argument("--dataset", type=str, default="multiarith")
    parser.add_argument("--direct_answer_trigger_for_fewshot", type=str, default="The answer is")
    # parse_known_args tolerates options that belong to the host program.
    args, _unknown = parser.parse_known_args()
    return args
|
61 |
+
|
62 |
+
if __name__ == "__main__":
    # Command-line smoke run: push each sample question through the pipeline
    # with a fresh decoder; demo() prints its own result.
    test_examples = [
        "Nancy uploaded 41 pictures to Facebook. She put 37 pics into one album and put the rest into 2 different albums. How many pictures were in each album?",
    ]
    decoder = Decoder()
    for test_example in test_examples:
        demo(decoder, test_example, 0)
|
llm_src/utils/__init__.py
ADDED
File without changes
|
llm_src/utils/cot/__init__.py
ADDED
File without changes
|
llm_src/utils/cot/get_prompt.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
def get_prompt(num_cases=8):
    """Build a few-shot chain-of-thought prompt.

    The eight worked examples are shuffled with the module-level ``random``
    generator (seed it beforehand for reproducible output) and the first
    *num_cases* of them are joined with blank lines.

    Args:
        num_cases: How many worked examples to include (at most eight are
            available).

    Returns:
        The selected examples separated by ``"\\n\\n"``, followed by a
        trailing ``"\\n\\n"``.
    """
    # Each entry is one "Q: ...\nA: ..." worked example.
    cot_examples = (
        "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.",
        "Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.",
        "Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.",
        "Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.",
        "Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nA: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.",
        "Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nA: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.",
        "Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nA: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.",
        "Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.",
    )

    # Shuffle a private copy so the tuple above is never modified.
    COT_prompt_list = list(cot_examples)
    random.shuffle(COT_prompt_list)
    COT_prompt = '\n\n'.join(COT_prompt_list[:num_cases]) + '\n\n'

    return COT_prompt
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == "__main__":
    # Manual check: show one freshly shuffled prompt.
    print(get_prompt(), end="\n")
|
llm_src/utils/decoder.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import openai
|
3 |
+
import time
|
4 |
+
|
5 |
+
|
6 |
+
class Decoder():
    """Stateless front-end that forwards decode requests to the GPT-3 helper."""

    def __init__(self):
        # No state to set up.
        pass

    def decode(self, args, input, CNT_SUM):
        """Return the completion text for *input*.

        All three arguments are handed unchanged to ``decoder_for_gpt3``.
        """
        return decoder_for_gpt3(args, input, CNT_SUM)
|
13 |
+
|
14 |
+
# Sentence Generator (Decoder) for GPT-3 ...
|
15 |
+
def decoder_for_gpt3(args, input, CNT_SUM):
    """Request one greedy completion for *input* and return its text.

    Args:
        args: Namespace providing ``api_time_interval`` (seconds to sleep
            before the request) and ``max_length`` (``max_tokens`` of the
            completion).
        input: The prompt text.
        CNT_SUM: Accepted for interface compatibility; not used here.

    Returns:
        The ``text`` field of the first returned choice.
    """
    import os  # local import: only needed to look up the API key

    # Throttle requests; the original note says the API allows each user
    # about 60 calls per minute.
    time.sleep(args.api_time_interval)

    # https://beta.openai.com/account/api-keys
    # Take the key from the environment and never overwrite an already
    # configured key with an empty string.
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if api_key:
        openai.api_key = api_key
    engine = "code-davinci-002"

    response = openai.Completion.create(
        engine=engine,
        prompt=input,
        max_tokens=args.max_length,
        temperature=0,
        stop=['--', '\n\n', '#'],
    )

    return response["choices"][0]["text"]
|
33 |
+
|
34 |
+
# ver 0.2
|
35 |
+
def answer_cleansing(args, pred):
    """Extract the final answer token from a raw model completion.

    The completion is split on ``args.direct_answer_trigger_for_fewshot``
    and the text after the last trigger is scanned with a pattern chosen by
    ``args.dataset``. When the trigger was present the first candidate is
    used, otherwise the last one. A trailing period is removed.

    Args:
        args: Namespace with ``dataset`` and
            ``direct_answer_trigger_for_fewshot``.
        pred: Raw completion text.

    Returns:
        The extracted answer as a string, or ``""`` when no candidate exists.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    print("pred_before : " + pred)

    preds = pred.split(args.direct_answer_trigger_for_fewshot)
    answer_flag = len(preds) > 1
    pred = preds[-1]

    if args.dataset in ("aqua", "commonsensqa"):
        pred = re.findall(r'A|B|C|D|E', pred)
    elif args.dataset == "bigbench_date":
        pred = re.findall(r'A|B|C|D|E|F', pred)
    # Exact comparison: ``in ("object_tracking")`` would be a substring test
    # on a plain string and would also accept e.g. "object" or "".
    elif args.dataset == "object_tracking":
        pred = re.findall(r'A|B|C', pred)
    elif args.dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq"):
        pred = pred.replace(",", "")
        pred = re.findall(r'-?\d+\.?\d*', pred)
    elif args.dataset in ("strategyqa", "coin_flip"):
        pred = pred.lower()
        # Quotes, periods, colons, commas and any whitespace become spaces.
        pred = re.sub(r"""["'.\s:,]""", " ", pred)
        pred = pred.split(" ")
        pred = [i for i in pred if i in ("yes", "no")]
    elif args.dataset == "last_letters":
        # Quotes, periods and any whitespace are dropped.
        pred = re.sub(r"""["'.\s]""", "", pred)
        pred = [pred]
    else:
        raise ValueError("dataset is not properly defined ...")

    # If there is no candidate in list, null is set.
    if len(pred) == 0:
        pred = ""
    elif answer_flag:
        # choose the first element in list ...
        pred = pred[0]
    else:
        # choose the last element in list ...
        pred = pred[-1]

    # (For arithmetic tasks) if a word ends with period, it will be omitted ...
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]

    print("pred_after : " + pred)
    return pred
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == "__main__":
    from types import SimpleNamespace

    test_str = "/* Create a JavaScript dictionary of 5 countries and capitals: */\n"
    # decoder_for_gpt3 takes (args, input, CNT_SUM) and reads
    # args.api_time_interval and args.max_length, so a minimal stand-in
    # namespace is enough for this smoke test.
    smoke_args = SimpleNamespace(api_time_interval=0, max_length=256)
    z = decoder_for_gpt3(smoke_args, test_str, 0)
    print(z)
|
llm_src/utils/fp_substitution.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def get_nums_from_passage(p):
    """Find the numeric words of a passage.

    The passage is split on whitespace. ``,`` and ``$`` are removed from each
    word, and the word counts as a number when what remains is a numeric
    literal (``41``, ``3.5``, ``5.``) or simple arithmetic over such literals
    using ``+ - * /`` (``3/4``).

    The passage may come straight from a user, so words are parsed into a
    syntax tree and evaluated by a small whitelist evaluator; nothing is ever
    passed to ``eval``. Names, calls, booleans and any other construct are
    treated as non-numeric.

    Args:
        p: The passage text.

    Returns:
        ``(nums, nums_id)``: the numeric values in order of appearance and
        the positions of the corresponding words in ``p.strip().split()``.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _number(node):
        # Evaluate a whitelisted numeric expression node, else raise ValueError.
        if isinstance(node, ast.Expression):
            return _number(node.body)
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int but "True"/"False" are not numbers.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError("not a real number")
            return value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_number(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_number(node.left), _number(node.right))
        raise ValueError("unsupported expression")

    nums = []
    nums_id = []
    for idx, word in enumerate(p.strip().split()):
        token = word.replace(',', '').replace('$', '')
        try:
            num = _number(ast.parse(token, mode="eval"))
        except (ValueError, SyntaxError, TypeError, ArithmeticError,
                RecursionError, MemoryError):
            # Not a number (or not safely computable): leave the word alone.
            continue
        nums.append(num)
        nums_id.append(idx)
    return nums, nums_id
|
18 |
+
|
19 |
+
|
20 |
+
def generate_fp_nums_multiarith(orig_nums, repeat_time):
    """Draw replacement operand lists that keep the ordering of *orig_nums*.

    Candidate triples are the multiplier patterns (1,2,3), (1,2,5), (1,2,7)
    and (1,3,5) scaled by 1..5. For each of *repeat_time* rounds one
    candidate is chosen with ``random.randint`` and its values are handed
    out from smallest to largest following the rank order of *orig_nums*.
    The next larger value is taken whenever the next original number differs
    from the previous one or is below 10; equal originals of 10 or more
    share a replacement.

    Args:
        orig_nums: The original operands.
        repeat_time: Number of replacement lists to produce.

    Returns:
        A list of *repeat_time* lists, each as long as *orig_nums*.
    """
    # erase cannot be divided exactly
    multipliers = [[1, 2, 3], [1, 2, 5], [1, 2, 7], [1, 3, 5]]
    candidates = [
        [factor * scale for factor in pattern]
        for scale in [1, 2, 3, 4, 5]
        for pattern in multipliers
    ]

    total = len(orig_nums)
    replacements = []
    for _round in range(repeat_time):
        chosen = sorted(candidates[random.randint(0, len(candidates) - 1)])
        rank_order = np.argsort(orig_nums).tolist()
        assigned = [-1] * total
        cursor = 0
        for rank, position in enumerate(rank_order):
            if rank > 0:
                previous = rank_order[rank - 1]
                # Advance to the next candidate value before assigning.
                if orig_nums[position] != orig_nums[previous] or orig_nums[position] < 10:
                    cursor += 1
            assigned[position] = chosen[cursor]
        replacements.append(assigned)

    return replacements
|
51 |
+
|
52 |
+
|
53 |
+
def fp_substitute(p, s_time)-> list:
    """Create *s_time* variants of passage *p* with its numbers replaced.

    Numeric words are located with ``get_nums_from_passage`` and replacement
    values come from ``generate_fp_nums_multiarith``. A replaced word keeps a
    trailing period if the original word ended with one.

    Returns:
        A list of dicts with keys ``"Question"`` (the rewritten passage,
        words joined by single spaces) and ``"Alignments"`` (the replacement
        numbers, in order of appearance).
    """
    orig_nums, nums_id = get_nums_from_passage(p)
    variants = generate_fp_nums_multiarith(orig_nums, s_time)
    words = p.strip().split()

    data_fp_list = []
    for fp_nums in variants:
        pending = iter(fp_nums)
        rewritten = []
        for idx, word in enumerate(words):
            if idx not in nums_id:
                rewritten.append(word)
                continue
            piece = str(next(pending))
            if word[-1] == '.':
                piece += '.'
            rewritten.append(piece)
        data_fp_list.append({
            "Question": " ".join(rewritten),
            "Alignments": fp_nums,
        })

    return data_fp_list
|
llm_src/utils/logger.py
ADDED
File without changes
|
llm_src/utils/solis/__init__.py
ADDED
File without changes
|
llm_src/utils/solis/helper.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sympy import simplify
|
2 |
+
|
3 |
+
|
4 |
+
def generate_func_set(num_x, num_op, show_funx=False, op_sets=['+', '-', '*', '/'], rep=True):
    """
    Enumerate candidate arithmetic expressions over variables x0..x{num_x-1}.

    Every operator sequence of length ``num_op`` drawn from ``op_sets`` is
    combined with every variable sequence of length ``num_op + 1`` that uses
    exactly ``num_x`` distinct variables. When ``rep`` is true, parenthesised
    variants are generated as well. Each candidate is passed to sympy
    ``simplify`` and only the first candidate per simplified form is kept.

    ``show_funx`` is not used in this function.

    NOTE(review): indentation was reconstructed from a whitespace-mangled
    copy; the statements after ``symps.append(symp)`` are assumed to belong
    to the ``if symp not in symps`` branch - confirm against the repository.

    Returns ``(op_lists, x_lists, exprs)``: the operator sequences, the
    variable-index sequences and the kept expression strings (no spaces).
    """
    def dfs_op(tmp_op_list, max_num):
        # Collect every operator sequence of exactly max_num operators.
        if len(tmp_op_list) == max_num:
            if tmp_op_list not in op_lists:
                op_lists.append(tmp_op_list)
            return

        for op in op_sets:
            dfs_op(tmp_op_list + [op], max_num)
        return

    def dfs_x(tmp_x_list, max_var_num, max_num):
        # Collect variable sequences of length max_num that contain exactly
        # max_var_num distinct variables.
        if len(tmp_x_list) == max_num:
            if len(set(tmp_x_list)) == max_var_num:
                x_lists.append(tmp_x_list)
            return

        for x in x_sets:
            # Without rep, a variable may appear at most once per sequence.
            if rep or (not rep and x not in tmp_x_list):
                dfs_x(tmp_x_list + [x], max_var_num, max_num)
        return

    def dfs_parenthesis(expr_list, left, right):
        # Split the token range at each operator, recurse on both sides and
        # wrap any side that holds more than one variable in parentheses.
        if left > right:
            return []
        ans = []
        for i in range(left, right + 1):
            if 'x' in expr_list[i]:
                continue
            ans_left = dfs_parenthesis(expr_list, left, i - 1)
            ans_right = dfs_parenthesis(expr_list, i + 1, right)

            for a in ans_left:
                for b in ans_right:
                    if a.count('x') > 1:
                        a = '(' + a + ')'
                    if b.count('x') > 1:
                        b = '(' + b + ')'
                    cur = ""
                    if expr_list[i] == "+":
                        cur = a + '+' + b
                    elif expr_list[i] == "-":
                        cur = a + '-' + b
                    elif expr_list[i] == '*':
                        cur = a + '*' + b
                    elif expr_list[i] == '/':
                        cur = a + '/' + b
                    ans.append(cur)
        # No operator in range: the range is a single variable token.
        if not ans:
            ans.append(expr_list[left])
        return ans

    x_sets = [_ for _ in range(num_x)]
    n_op = len(op_sets)

    op_lists = []
    x_lists = []

    # Generate all OP combinations, all X combinations
    dfs_op([], max_num=num_op)
    dfs_x([], max_var_num=num_x, max_num=num_op+1)

    exprs = []
    symps = []

    # Make up all operations, add parentheses, remove repeated cases
    for op_list in op_lists:
        for x_list in x_lists:
            if len(x_list) != len(op_list) + 1:
                continue
            expr = ""
            for i in range(len(op_list)):
                expr = expr + "x" + str(x_list[i])# + ' '
                expr = expr + op_list[i]# + ' '
            # Make up expression without ()
            expr = expr + "x" + str(x_list[-1])

            # Make up expression with ()
            e_list = expr.replace(' ', '').replace("+", " + ").replace("-", " - ").replace("*", " * ").replace("/", " / ").split(' ')

            if rep:
                exprs_parenth = dfs_parenthesis(e_list, 0, len(e_list) - 1)
                exprs_parenth.insert(0, expr)
            else:
                exprs_parenth = [expr]

            # remove repeat
            for expr in exprs_parenth:
                symp = simplify(expr)
                if all(op not in str(symp) for op in op_sets): # remove out of func space
                    continue
                if symp not in symps:
                    symps.append(symp)

                    e_list = str(symp).replace(' ', '').replace("+", " + ").replace("-", " - ").replace("*", " * ").replace("/", " / ").replace("(", " ( ").replace(")", " ) ").replace("**", " ** ").split(' ')

                    if any(e.isnumeric() for e in e_list): # do not allow const in a expression, only to add manually
                        expr_simple = expr
                    else:
                        expr_simple = str(symp)

                    # Distinct variable tokens left after simplification.
                    x_sets_ = list(set([e if 'x' in e else '' for e in e_list]))
                    x_sets_.remove('')
                    # Drop forms in which simplification cancelled a variable.
                    if x_sets_ is None or len(x_sets_) < len(x_sets):
                        continue

                    exprs.append(expr_simple.replace(' ',''))

    return op_lists, x_lists, exprs
|
132 |
+
|
133 |
+
|
134 |
+
def generate_func_set_all(num_x, num_op, show_funx=False, op_sets=['+', '-', '*', '/'], rep=True):
    """
    Enumerate candidate arithmetic expressions with up to ``num_op`` operators.

    Works like ``generate_func_set`` but the two searches also record every
    shorter non-empty prefix, so expressions with fewer operators are
    produced too. Full-length variable sequences must still contain exactly
    ``num_x`` distinct variables. Kept expressions are not filtered by the
    number of variables that remain after simplification.

    ``show_funx`` is not used in this function.

    NOTE(review): indentation was reconstructed from a whitespace-mangled
    copy; the statements after ``symps.append(symp)`` are assumed to belong
    to the ``if symp not in symps`` branch - confirm against the repository.

    Returns ``(op_lists, x_lists, exprs)``: the operator sequences, the
    variable-index sequences and the kept expression strings (no spaces).
    """
    def dfs_op(tmp_op_list, max_num):
        # Collect operator sequences of every length from 1 to max_num.
        if len(tmp_op_list) == max_num:
            if tmp_op_list not in op_lists:
                op_lists.append(tmp_op_list)
            return

        elif len(tmp_op_list) < max_num and len(tmp_op_list) > 0:
            if tmp_op_list not in op_lists:
                op_lists.append(tmp_op_list)

        for op in op_sets:
            dfs_op(tmp_op_list + [op], max_num)
        return

    def dfs_x(tmp_x_list, max_var_num, max_num):
        # Full-length sequences need exactly max_var_num distinct variables;
        # shorter non-empty prefixes are recorded once each.
        if len(tmp_x_list) == max_num:
            if len(set(tmp_x_list)) == max_var_num:
                x_lists.append(tmp_x_list)
            return

        elif len(tmp_x_list) < max_num and len(tmp_x_list) > 0:
            if tmp_x_list not in x_lists:
                x_lists.append(tmp_x_list)

        for x in x_sets:
            # Without rep, a variable may appear at most once per sequence.
            if rep or (not rep and x not in tmp_x_list):
                dfs_x(tmp_x_list + [x], max_var_num, max_num)
        return

    def dfs_parenthesis(expr_list, left, right):
        # Split the token range at each operator, recurse on both sides and
        # wrap any side that holds more than one variable in parentheses.
        if left > right:
            return []
        ans = []
        for i in range(left, right + 1):
            if 'x' in expr_list[i]:
                continue
            ans_left = dfs_parenthesis(expr_list, left, i - 1)
            ans_right = dfs_parenthesis(expr_list, i + 1, right)

            for a in ans_left:
                for b in ans_right:
                    if a.count('x') > 1:
                        a = '(' + a + ')'
                    if b.count('x') > 1:
                        b = '(' + b + ')'
                    cur = ""
                    if expr_list[i] == "+":
                        cur = a + '+' + b
                    elif expr_list[i] == "-":
                        cur = a + '-' + b
                    elif expr_list[i] == '*':
                        cur = a + '*' + b
                    elif expr_list[i] == '/':
                        cur = a + '/' + b
                    ans.append(cur)
        # No operator in range: the range is a single variable token.
        if not ans:
            ans.append(expr_list[left])
        return ans

    x_sets = [_ for _ in range(num_x)]
    n_op = len(op_sets)

    op_lists = []
    x_lists = []

    # Generate all OP combinations, all X combinations
    dfs_op([], max_num=num_op)
    dfs_x([], max_var_num=num_x, max_num=num_op+1)

    exprs = []
    symps = []

    # Make up all operations, add parentheses, remove repeated cases
    for op_list in op_lists:
        for x_list in x_lists:
            # Pair each operator sequence only with variable sequences that
            # are exactly one element longer.
            if len(x_list) != len(op_list) + 1:
                continue
            expr = ""
            for i in range(len(op_list)):
                expr = expr + "x" + str(x_list[i])# + ' '
                expr = expr + op_list[i]# + ' '
            # Make up expression without ()
            expr = expr + "x" + str(x_list[-1])

            # Make up expression with ()
            e_list = expr.replace(' ', '').replace("+", " + ").replace("-", " - ").replace("*", " * ").replace("/", " / ").split(' ')

            if rep:
                exprs_parenth = dfs_parenthesis(e_list, 0, len(e_list) - 1)
                exprs_parenth.insert(0, expr)
            else:
                exprs_parenth = [expr]

            # remove repeat
            for expr in exprs_parenth:
                symp = simplify(expr)
                if all(op not in str(symp) for op in op_sets): # remove out of func space
                    continue
                if symp not in symps:
                    symps.append(symp)

                    e_list = str(symp).replace(' ', '').replace("+", " + ").replace("-", " - ").replace("*", " * ").replace("/", " / ").replace("(", " ( ").replace(")", " ) ").replace("**", " ** ").split(' ')

                    if any(e.isnumeric() for e in e_list): # do not allow const in a expression, only to add manually
                        expr_simple = expr
                    else:
                        expr_simple = str(symp)

                    # Distinct variable tokens left after simplification.
                    x_sets_ = list(set([e if 'x' in e else '' for e in e_list]))
                    x_sets_.remove('')
                    if x_sets_ is None:
                        continue

                    exprs.append(expr_simple.replace(' ',''))

    return op_lists, x_lists, exprs
|
274 |
+
|
275 |
+
|
276 |
+
def is_func(str_1, rets, vars_sets, prec=1e-1, return_loss=False):
    """Check whether expression *str_1* reproduces the expected results.

    For each variable assignment in *vars_sets* the expression is evaluated
    with ``x0, x1, ...`` bound to the assignment's values and compared with
    the entry of *rets* at the same position.

    Args:
        str_1: Expression text over the variables ``x0, x1, ...``.
        rets: Expected results, as numeric strings (e.g. ``"5"``) or numbers.
            The list is not modified.
        vars_sets: One sequence of variable values per expected result.
        prec: Largest absolute difference that still counts as a match.
        return_loss: When true, return ``(matched, loss)`` instead of just
            ``matched``.

    Returns:
        ``matched`` is False as soon as one assignment fails. ``loss`` is the
        absolute difference of the last compared assignment, ``1e20`` when
        the expression or the expected value could not be evaluated, and the
        absolute expression value when the expected value is not a real
        number.
    """
    import ast

    def _result(matched, loss):
        return (matched, loss) if return_loss else matched

    # The expression comes from the generated function set, so only the
    # variables are exposed to it; builtins are withheld.
    scope = {"__builtins__": {}}
    loss = 0.0
    for pos, var_set in enumerate(vars_sets):
        for i, var in enumerate(var_set):
            scope[f"x{i}"] = var
        try:
            ret_1 = eval(str_1, scope)
            expected = rets[pos]
            if isinstance(expected, str):
                # Expected values originate from model output: parse them as
                # literals only, never execute them.
                expected = ast.literal_eval(expected)
            if not isinstance(expected, (int, float)):
                return _result(False, abs(ret_1))
        except Exception:
            # Division by zero, missing variable, unparsable answer, ...
            return _result(False, 1e20)
        loss = abs(ret_1 - expected)
        if loss > prec:
            return _result(False, loss)
    return _result(True, loss)
|
llm_src/utils/solis/solis_solver.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from utils.solis.helper import generate_func_set, generate_func_set_all, is_func
|
3 |
+
|
4 |
+
def try_search(args, orig_nums, fp_results, func_set=None):
    """Pick the expression that best explains the substituted predictions
    and evaluate it on the original operands.

    Args:
        args: Namespace with ``dataset``; a name containing ``"multiarith"``
            or ``"addsub"`` selects the built-in candidate set and the
            rounding rule.
        orig_nums: Operands of the original question, bound to
            ``x0, x1, ...`` for the final evaluation.
        fp_results: Dicts with ``"fp_nums"`` (substituted operands) and
            ``"fp_z"`` (the model's answer for them).
        func_set: Optional candidate expressions; generated when omitted.

    Returns:
        ``(expression, calibrated_prediction)``, or ``(None, None)`` when no
        candidate set is available, no expression is selected, or the final
        evaluation fails.
    """
    if func_set is None:
        if "multiarith" in args.dataset:
            _, _, func_set = generate_func_set(
                num_x=3,
                num_op=2,
                show_funx=False,
                op_sets=['+', '-', '*', '/'],
                rep=True,
            )
        elif "addsub" in args.dataset:
            if len(orig_nums) == 3:
                _, _, func_set = generate_func_set_all(
                    num_x=3,
                    num_op=2,
                    show_funx=False,
                    op_sets=['+', '-'],
                    rep=True,
                )
            elif len(orig_nums) == 2:
                _, _, func_set = generate_func_set(
                    num_x=2,
                    num_op=1,
                    show_funx=False,
                    op_sets=['+', '-'],
                    rep=True,
                )
            else:
                return None, None
        else:
            # No built-in candidate set exists for this dataset.
            return None, None

    if not func_set:
        # Nothing to choose from.
        return None, None

    # Per candidate: number of mismatching variants and accumulated loss.
    errors_cnt = [0] * len(func_set)
    losses_cnt = [0] * len(func_set)
    for k, expr in enumerate(func_set):
        for fp_result in fp_results:
            pred = fp_result["fp_z"]
            repl_numbers = fp_result["fp_nums"]

            flag_, loss_ = is_func(expr, [str(pred)], [repl_numbers], return_loss=True)
            errors_cnt[k] += (flag_ == False)
            losses_cnt[k] += abs(loss_)

    # First choice: the earliest candidate with the fewest mismatches.
    tmp_min = 10000000000000
    thresh_k = int(len(errors_cnt))
    expr_filter = ""
    for k, cnt in enumerate(errors_cnt):
        if cnt <= thresh_k and cnt < tmp_min:
            expr_filter = func_set[k]
            tmp_min = cnt

    # Fallback: a candidate whose loss is strictly below the first one's.
    if expr_filter == "":
        tmp_min = losses_cnt[0]
        for k, loss in enumerate(losses_cnt):
            if loss < tmp_min:
                expr_filter = func_set[k]
                tmp_min = loss

    if expr_filter == "":
        return None, None

    # try calibration
    # Only the operands are exposed to the expression; builtins are withheld.
    var_dict = {"__builtins__": {}}
    for i_, var in enumerate(orig_nums):
        var_dict[f"x{i_}"] = var
    try:
        cali_pred = eval(expr_filter, var_dict)
        if "multiarith" in args.dataset:
            # Whole results become ints, anything else is rounded up.
            cali_pred = round(cali_pred, 5)
            if int(cali_pred * 10 // 10) == cali_pred:
                cali_pred = int(cali_pred)
            else:
                cali_pred = math.ceil(cali_pred)
        elif "addsub" in args.dataset:
            # Round to the largest number of decimals among the operands.
            bit_max = 0
            for number in orig_nums:
                parts = str(number).split('.')
                bit = 0 if len(parts) == 1 else len(parts[-1])
                bit_max = max(bit, bit_max)
            cali_pred = round(cali_pred, bit_max)
    except Exception:
        return None, None

    return expr_filter, cali_pred
|
md_src/demo_description.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Welcome! This is an interactive demo of [Solis](http://reflection-of-thought.github.io/).
|