Spaces:
Build error
Build error
completed eval/analysis
Browse files- data/Llama3.1-70B-Chinese-Chat_metrics.csv +12 -12
- data/Llama3.1-8B-Chinese-Chat_metrics.csv +12 -12
- data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv +12 -12
- data/Qwen2-72B-Instruct_metrics.csv +12 -8
- data/Qwen2-72B-Instruct_results.csv +0 -0
- data/Qwen2-7B-Instruct_metrics.csv +12 -12
- data/internlm2_5-7b-chat-1m_metrics.csv +12 -12
- data/openai_metrics.csv +15 -0
- data/openai_results.csv +0 -0
- llm_toolkit/eval_openai.py +7 -1
- llm_toolkit/logical_reasoning_utils.py +167 -32
- notebooks/00_Data Analysis.ipynb +0 -0
- notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb +0 -0
- notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb +0 -0
- notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb +0 -0
- notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb +0 -0
- notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb +0 -0
- notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb +0 -0
- notebooks/04_Few-shot_Prompting_OpenAI.ipynb +0 -0
- notebooks/04b_OpenAI-Models_analysis.ipynb +0 -0
data/Llama3.1-70B-Chinese-Chat_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
|
3 |
-
0.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
|
4 |
-
0.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
|
5 |
-
0.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
|
6 |
-
0.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
|
7 |
-
1.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
|
8 |
-
1.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
|
9 |
-
1.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
|
10 |
-
1.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
|
11 |
-
1.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
|
12 |
-
2.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
|
3 |
+
0.2,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
|
4 |
+
0.4,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
|
5 |
+
0.6,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
|
6 |
+
0.8,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
|
7 |
+
1.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
|
8 |
+
1.2,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
|
9 |
+
1.4,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
|
10 |
+
1.6,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
|
11 |
+
1.8,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
|
12 |
+
2.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
|
data/Llama3.1-8B-Chinese-Chat_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
|
3 |
-
0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
|
4 |
-
0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
|
5 |
-
0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
|
6 |
-
0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
|
7 |
-
1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
|
8 |
-
1.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
|
9 |
-
1.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
|
10 |
-
1.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
|
11 |
-
1.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
|
12 |
-
2.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
|
3 |
+
0.2,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
|
4 |
+
0.4,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
|
5 |
+
0.6,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
|
6 |
+
0.8,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
|
7 |
+
1.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
|
8 |
+
1.2,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
|
9 |
+
1.4,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
|
10 |
+
1.6,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
|
11 |
+
1.8,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
|
12 |
+
2.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
|
data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
|
3 |
-
0.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
|
4 |
-
0.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
|
5 |
-
0.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
|
6 |
-
0.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
|
7 |
-
1.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
|
8 |
-
1.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
|
9 |
-
1.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
|
10 |
-
1.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
|
11 |
-
1.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
|
12 |
-
2.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
|
3 |
+
0.2,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
|
4 |
+
0.4,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
|
5 |
+
0.6,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
|
6 |
+
0.8,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
|
7 |
+
1.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
|
8 |
+
1.2,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
|
9 |
+
1.4,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
|
10 |
+
1.6,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
|
11 |
+
1.8,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
|
12 |
+
2.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
|
data/Qwen2-72B-Instruct_metrics.csv
CHANGED
@@ -1,8 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.
|
3 |
-
0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
|
4 |
-
0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
|
5 |
-
0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
|
6 |
-
0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
|
7 |
-
1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
|
8 |
-
1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
|
|
|
|
|
|
|
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7516666666666667,0.7949378981748352,0.7516666666666667,0.7572499605227642,0.9773333333333334
|
3 |
+
0.2,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
|
4 |
+
0.4,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
|
5 |
+
0.6,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
|
6 |
+
0.8,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
|
7 |
+
1.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
|
8 |
+
1.2,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
|
9 |
+
1.4,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-245_torch.bfloat16_4bit_lf,0.7656666666666667,0.8288272203240518,0.7656666666666667,0.790627109330698,1.0
|
10 |
+
1.6,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-280_torch.bfloat16_4bit_lf,0.7693333333333333,0.8292798021666021,0.7693333333333333,0.7930169589012503,1.0
|
11 |
+
1.8,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-315_torch.bfloat16_4bit_lf,0.784,0.8354349234761956,0.784,0.804194683154365,1.0
|
12 |
+
2.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-350_torch.bfloat16_4bit_lf,0.7736666666666666,0.8330147983140184,0.7736666666666666,0.7973657072550873,1.0
|
data/Qwen2-72B-Instruct_results.csv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
data/Qwen2-7B-Instruct_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
|
3 |
-
0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
|
4 |
-
0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
|
5 |
-
0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
|
6 |
-
0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
|
7 |
-
1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
|
8 |
-
1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
|
9 |
-
1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
|
10 |
-
1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
|
11 |
-
1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
|
12 |
-
2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
|
3 |
+
0.2,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
|
4 |
+
0.4,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
|
5 |
+
0.6,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
|
6 |
+
0.8,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
|
7 |
+
1.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
|
8 |
+
1.2,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
|
9 |
+
1.4,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
|
10 |
+
1.6,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
|
11 |
+
1.8,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
|
12 |
+
2.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
|
data/internlm2_5-7b-chat-1m_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
-
0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
|
3 |
-
0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
|
4 |
-
0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
|
5 |
-
0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
|
6 |
-
0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
|
7 |
-
1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
|
8 |
-
1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
|
9 |
-
1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
|
10 |
-
1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
|
11 |
-
1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
|
12 |
-
2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
|
|
|
1 |
+
epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
|
3 |
+
0.2,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
|
4 |
+
0.4,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
|
5 |
+
0.6,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
|
6 |
+
0.8,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
|
7 |
+
1.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
|
8 |
+
1.2,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
|
9 |
+
1.4,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
|
10 |
+
1.6,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
|
11 |
+
1.8,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
|
12 |
+
2.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
|
data/openai_metrics.csv
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
shots,model,accuracy,precision,recall,f1,ratio_valid_classifications
|
2 |
+
0,gpt-4o-mini,0.7176666666666667,0.785706730193659,0.7176666666666667,0.7296061848734905,0.9916666666666667
|
3 |
+
5,gpt-4o-mini,0.7176666666666667,0.7767294185987051,0.7176666666666667,0.7181068311028772,0.9996666666666667
|
4 |
+
10,gpt-4o-mini,0.6793333333333333,0.7728086050218999,0.6793333333333333,0.6916749681933937,0.9983333333333333
|
5 |
+
20,gpt-4o-mini,0.6623333333333333,0.7686706009175459,0.6623333333333333,0.6798015109939115,0.998
|
6 |
+
30,gpt-4o-mini,0.6873333333333334,0.7684209723431035,0.6873333333333334,0.6913018667081989,0.999
|
7 |
+
40,gpt-4o-mini,0.6923333333333334,0.7639874967862498,0.6923333333333334,0.6924934068935911,0.9986666666666667
|
8 |
+
50,gpt-4o-mini,0.717,0.7692638634416518,0.717,0.7105227254860433,0.9993333333333333
|
9 |
+
0,gpt-4o,0.782,0.8204048322982596,0.782,0.7953019682198627,0.066
|
10 |
+
5,gpt-4o,0.7873333333333333,0.8230974205170392,0.7873333333333333,0.8000290527498529,0.998
|
11 |
+
10,gpt-4o,0.7916666666666666,0.8227707658360168,0.7916666666666666,0.803614688453356,0.9996666666666667
|
12 |
+
20,gpt-4o,0.7816666666666666,0.8204541793856629,0.7816666666666666,0.7967017169880498,0.9993333333333333
|
13 |
+
30,gpt-4o,0.7886666666666666,0.8260847852316618,0.7886666666666666,0.8030949295928699,0.999
|
14 |
+
40,gpt-4o,0.784,0.8233509309291644,0.784,0.7993336791122846,0.9973333333333333
|
15 |
+
50,gpt-4o,0.787,0.8234800466218334,0.787,0.8013530974301947,0.9993333333333333
|
data/openai_results.csv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
llm_toolkit/eval_openai.py
CHANGED
@@ -57,7 +57,13 @@ def evaluate_model_with_num_shots(
|
|
57 |
for num_shots in range_num_shots:
|
58 |
print(f"*** Evaluating with num_shots: {num_shots}")
|
59 |
|
60 |
-
predictions = eval_openai(
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
model_name_with_shorts = (
|
62 |
result_column_name
|
63 |
if result_column_name
|
|
|
57 |
for num_shots in range_num_shots:
|
58 |
print(f"*** Evaluating with num_shots: {num_shots}")
|
59 |
|
60 |
+
predictions = eval_openai(
|
61 |
+
eval_dataset,
|
62 |
+
model=model_name,
|
63 |
+
max_new_tokens=max_new_tokens,
|
64 |
+
num_shots=num_shots,
|
65 |
+
train_dataset=datasets["train"].to_pandas(),
|
66 |
+
)
|
67 |
model_name_with_shorts = (
|
68 |
result_column_name
|
69 |
if result_column_name
|
llm_toolkit/logical_reasoning_utils.py
CHANGED
@@ -3,14 +3,21 @@ import re
|
|
3 |
from langchain_openai import ChatOpenAI
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
import pandas as pd
|
|
|
6 |
import seaborn as sns
|
7 |
import matplotlib.pyplot as plt
|
8 |
from matplotlib import rcParams
|
9 |
from matplotlib.ticker import MultipleLocator
|
10 |
from datasets import load_dataset
|
11 |
import numpy as np
|
12 |
-
from sklearn.metrics import
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
print(f"loading {__file__}")
|
16 |
|
@@ -61,17 +68,16 @@ P2 = """你是一个情景猜谜游戏的主持人。游戏规则如下:
|
|
61 |
|
62 |
请严格按照这些规则回答参与者提出的问题。
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
**参与者提出的问题:** {}
|
69 |
"""
|
70 |
|
71 |
P2_en = """You are the host of a situational guessing game. The rules of the game are as follows:
|
72 |
|
73 |
1. Participants will receive a riddle that describes a simple yet difficult to understand event.
|
74 |
-
2. The host knows the
|
75 |
3. Participants can ask any closed-ended questions to uncover the truth of the event.
|
76 |
4. For each question, the host will respond with one of the following five options based on the actual situation: Yes, No, Unimportant, Correct answer, or Incorrect questioning. The criteria for each response are as follows:
|
77 |
- If the riddle and answer can provide an answer to the question, respond with: Yes or No
|
@@ -82,14 +88,35 @@ P2_en = """You are the host of a situational guessing game. The rules of the gam
|
|
82 |
|
83 |
Please strictly follow these rules when answering the participant's questions.
|
84 |
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
90 |
"""
|
91 |
|
92 |
-
system_prompt = "You are an expert in logical reasoning."
|
93 |
|
94 |
def get_prompt_template(using_p1=True, chinese_prompt=True):
|
95 |
if using_p1:
|
@@ -98,6 +125,40 @@ def get_prompt_template(using_p1=True, chinese_prompt=True):
|
|
98 |
return P2 if chinese_prompt else P2_en
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
def extract_answer(text, debug=False):
|
102 |
if text and isinstance(text, str):
|
103 |
# Remove the begin and end tokens
|
@@ -121,7 +182,9 @@ def extract_answer(text, debug=False):
|
|
121 |
print("--------\nstep 3:", text)
|
122 |
|
123 |
text = text.split(".")[0].strip()
|
|
|
124 |
text = text.split("。")[0].strip()
|
|
|
125 |
if debug:
|
126 |
print("--------\nstep 4:", text)
|
127 |
|
@@ -186,7 +249,9 @@ def save_results(model_name, results_path, dataset, predictions, debug=False):
|
|
186 |
df = dataset
|
187 |
else:
|
188 |
df = dataset.to_pandas()
|
189 |
-
df.drop(
|
|
|
|
|
190 |
else:
|
191 |
df = pd.read_csv(results_path, on_bad_lines="warn")
|
192 |
|
@@ -329,7 +394,7 @@ def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=N
|
|
329 |
df["backup"] = df[column_name]
|
330 |
df[column_name] = df[column_name].apply(preprocess_func)
|
331 |
|
332 |
-
plt.figure(figsize=(
|
333 |
df[column_name].value_counts().plot(kind="bar")
|
334 |
# add values on top of bars
|
335 |
for i, v in enumerate(df[column_name].value_counts()):
|
@@ -342,6 +407,7 @@ def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=N
|
|
342 |
rcParams["font.family"] = font_family
|
343 |
|
344 |
if preprocess_func:
|
|
|
345 |
df[column_name] = df["backup"]
|
346 |
df.drop(columns=["backup"], inplace=True)
|
347 |
|
@@ -351,16 +417,22 @@ def calc_metrics_for_col(df, col):
|
|
351 |
return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
|
352 |
|
353 |
|
354 |
-
def get_metrics_df(df):
|
355 |
perf_df = pd.DataFrame(
|
356 |
-
columns=["
|
357 |
)
|
358 |
for i, col in enumerate(df.columns[5:]):
|
359 |
metrics = calc_metrics(df["label"], df[col], debug=False)
|
360 |
new_model_metrics = {
|
361 |
-
"epoch"
|
362 |
-
"model": col,
|
|
|
363 |
}
|
|
|
|
|
|
|
|
|
|
|
364 |
new_model_metrics.update(metrics)
|
365 |
|
366 |
# Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
|
@@ -371,51 +443,61 @@ def get_metrics_df(df):
|
|
371 |
return perf_df
|
372 |
|
373 |
|
374 |
-
def plot_metrics(perf_df, model_name):
|
375 |
-
fig, ax = plt.subplots(1, 1, figsize=(
|
|
|
376 |
|
377 |
# Ensure the lengths of perf_df["epoch"], perf_df["accuracy"], and perf_df["f1"] are the same
|
378 |
min_length = min(
|
379 |
-
len(perf_df[
|
380 |
)
|
381 |
perf_df = perf_df.iloc[:min_length]
|
382 |
|
383 |
# Plot accuracy and f1 on the same chart with different markers
|
384 |
-
ax.plot(perf_df["epoch"], perf_df["accuracy"], marker="o", label="Accuracy")
|
385 |
ax.plot(
|
386 |
-
perf_df[
|
|
|
|
|
|
|
387 |
) # Square marker for F1 Score
|
388 |
|
389 |
# Add values on top of points
|
390 |
for i in range(min_length):
|
|
|
391 |
ax.annotate(
|
392 |
f"{perf_df['accuracy'].iloc[i]*100:.2f}%",
|
393 |
-
(perf_df[
|
394 |
ha="center",
|
395 |
va="bottom", # Move accuracy numbers below the points
|
396 |
xytext=(0, -15),
|
397 |
textcoords="offset points",
|
398 |
fontsize=10,
|
|
|
399 |
)
|
400 |
ax.annotate(
|
401 |
f"{perf_df['f1'].iloc[i]*100:.2f}%",
|
402 |
-
(perf_df[
|
403 |
ha="center",
|
404 |
va="top", # Move F1 score numbers above the points
|
405 |
xytext=(0, 15), # Offset by 15 points vertically
|
406 |
textcoords="offset points",
|
407 |
fontsize=10,
|
|
|
408 |
)
|
409 |
|
410 |
# Set y-axis limit
|
411 |
-
|
|
|
412 |
|
413 |
# Add title and labels
|
414 |
-
ax.set_xlabel(
|
|
|
|
|
|
|
|
|
415 |
ax.set_ylabel("Accuracy and F1 Score")
|
416 |
|
417 |
-
|
418 |
-
ax.xaxis.set_major_locator(MultipleLocator(0.2))
|
419 |
ax.set_title(f"Performance Analysis Across Checkpoints for the {model_name} Model")
|
420 |
|
421 |
# Rotate x labels
|
@@ -460,13 +542,66 @@ def reasoning_with_openai(
|
|
460 |
return response.content
|
461 |
|
462 |
|
463 |
-
def eval_openai(
|
464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
total = len(eval_dataset)
|
466 |
predictions = []
|
467 |
|
468 |
for i in tqdm(range(total)):
|
469 |
-
output = reasoning_with_openai(
|
|
|
|
|
470 |
predictions.append(output)
|
471 |
|
472 |
return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain_openai import ChatOpenAI
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
import pandas as pd
|
6 |
+
from tqdm import tqdm
|
7 |
import seaborn as sns
|
8 |
import matplotlib.pyplot as plt
|
9 |
from matplotlib import rcParams
|
10 |
from matplotlib.ticker import MultipleLocator
|
11 |
from datasets import load_dataset
|
12 |
import numpy as np
|
13 |
+
from sklearn.metrics import (
|
14 |
+
accuracy_score,
|
15 |
+
precision_score,
|
16 |
+
recall_score,
|
17 |
+
f1_score,
|
18 |
+
confusion_matrix,
|
19 |
+
)
|
20 |
+
|
21 |
|
22 |
print(f"loading {__file__}")
|
23 |
|
|
|
68 |
|
69 |
请严格按照这些规则回答参与者提出的问题。
|
70 |
|
71 |
+
谜面: {}
|
72 |
+
谜底: {}
|
73 |
+
参与者提出的问题: {}
|
74 |
+
回答:
|
|
|
75 |
"""
|
76 |
|
77 |
P2_en = """You are the host of a situational guessing game. The rules of the game are as follows:
|
78 |
|
79 |
1. Participants will receive a riddle that describes a simple yet difficult to understand event.
|
80 |
+
2. The host knows the truth, which is the solution to the riddle.
|
81 |
3. Participants can ask any closed-ended questions to uncover the truth of the event.
|
82 |
4. For each question, the host will respond with one of the following five options based on the actual situation: Yes, No, Unimportant, Correct answer, or Incorrect questioning. The criteria for each response are as follows:
|
83 |
- If the riddle and answer can provide an answer to the question, respond with: Yes or No
|
|
|
88 |
|
89 |
Please strictly follow these rules when answering the participant's questions.
|
90 |
|
91 |
+
Riddle: {}
|
92 |
+
Truth: {}
|
93 |
+
Participant's question: {}
|
94 |
+
"""
|
95 |
+
|
96 |
+
system_prompt = "You are an expert in logical reasoning."
|
97 |
+
|
98 |
+
P2_few_shot = """你是一个情景猜谜游戏的主持人。游戏规则如下:
|
99 |
+
|
100 |
+
1. 参与者会得到一个谜面,谜面会描述一个简单又难以理解的事件。
|
101 |
+
2. 主持人知道谜底,谜底是谜面的答案。
|
102 |
+
3. 参与者可以询问任何封闭式问题来找寻事件的真相。
|
103 |
+
4. 对于每个问题,主持人将根据实际情况回答以下五个选项之一:是、不是、不重要、回答正确、问法错误。各回答的判断标准如下:
|
104 |
+
- 若谜面和谜底能找到问题的答案,回答:是或者不是
|
105 |
+
- 若谜面和谜底不能直接或者间接推断出问题的答案,回答:不重要
|
106 |
+
- 若参与者提问不是一个封闭式问题或者问题难以理解,回答:问法错误
|
107 |
+
- 若参与者提问基本还原了谜底真相,回答:回答正确
|
108 |
+
5. 回答中不能添加任何其它信息,也不能省略选项中的任何一个字。例如,不可以把“不是”省略成“不”。
|
109 |
|
110 |
+
请严格按照这些规则回答参与者提出的问题。
|
111 |
|
112 |
+
示例输入和输出:
|
113 |
+
{examples}
|
114 |
+
谜面: {}
|
115 |
+
谜底: {}
|
116 |
+
参与者提出的问题: {}
|
117 |
+
回答:
|
118 |
"""
|
119 |
|
|
|
120 |
|
121 |
def get_prompt_template(using_p1=True, chinese_prompt=True):
|
122 |
if using_p1:
|
|
|
125 |
return P2 if chinese_prompt else P2_en
|
126 |
|
127 |
|
128 |
+
def get_few_shot_prompt_template(num_shots, train_dataset, debug=False):
|
129 |
+
if num_shots == 0:
|
130 |
+
return get_prompt_template(using_p1=False, chinese_prompt=True)
|
131 |
+
|
132 |
+
labels = train_dataset["label"].unique()
|
133 |
+
if debug:
|
134 |
+
print("num_shots:", num_shots)
|
135 |
+
print("labels:", labels)
|
136 |
+
|
137 |
+
examples = ""
|
138 |
+
index = 0
|
139 |
+
while num_shots > 0:
|
140 |
+
for label in labels:
|
141 |
+
while train_dataset["label"][index] != label:
|
142 |
+
index += 1
|
143 |
+
|
144 |
+
row = train_dataset.iloc[index]
|
145 |
+
examples += f"""谜面: {row["puzzle"]}
|
146 |
+
谜底: {row["truth"]}
|
147 |
+
参与者提出的问题: {row["text"]}
|
148 |
+
回答: {row["label"]}
|
149 |
+
|
150 |
+
"""
|
151 |
+
num_shots -= 1
|
152 |
+
if num_shots == 0:
|
153 |
+
break
|
154 |
+
|
155 |
+
prompt = P2_few_shot.replace("{examples}", examples)
|
156 |
+
if debug:
|
157 |
+
print("P2_few_shot:", prompt)
|
158 |
+
|
159 |
+
return prompt
|
160 |
+
|
161 |
+
|
162 |
def extract_answer(text, debug=False):
|
163 |
if text and isinstance(text, str):
|
164 |
# Remove the begin and end tokens
|
|
|
182 |
print("--------\nstep 3:", text)
|
183 |
|
184 |
text = text.split(".")[0].strip()
|
185 |
+
text = text.split("\n")[0].strip()
|
186 |
text = text.split("。")[0].strip()
|
187 |
+
text = text.replace("回答: ", "").strip()
|
188 |
if debug:
|
189 |
print("--------\nstep 4:", text)
|
190 |
|
|
|
249 |
df = dataset
|
250 |
else:
|
251 |
df = dataset.to_pandas()
|
252 |
+
df.drop(
|
253 |
+
columns=["answer", "prompt", "train_text"], inplace=True, errors="ignore"
|
254 |
+
)
|
255 |
else:
|
256 |
df = pd.read_csv(results_path, on_bad_lines="warn")
|
257 |
|
|
|
394 |
df["backup"] = df[column_name]
|
395 |
df[column_name] = df[column_name].apply(preprocess_func)
|
396 |
|
397 |
+
plt.figure(figsize=(8, 4))
|
398 |
df[column_name].value_counts().plot(kind="bar")
|
399 |
# add values on top of bars
|
400 |
for i, v in enumerate(df[column_name].value_counts()):
|
|
|
407 |
rcParams["font.family"] = font_family
|
408 |
|
409 |
if preprocess_func:
|
410 |
+
plot_confusion_matrix(df["label"], df[column_name])
|
411 |
df[column_name] = df["backup"]
|
412 |
df.drop(columns=["backup"], inplace=True)
|
413 |
|
|
|
417 |
return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
|
418 |
|
419 |
|
420 |
+
def get_metrics_df(df, variant="epoch"):
|
421 |
perf_df = pd.DataFrame(
|
422 |
+
columns=[variant, "model", "run", "accuracy", "precision", "recall", "f1"]
|
423 |
)
|
424 |
for i, col in enumerate(df.columns[5:]):
|
425 |
metrics = calc_metrics(df["label"], df[col], debug=False)
|
426 |
new_model_metrics = {
|
427 |
+
variant: i / 5 if variant == "epoch" else i + 1,
|
428 |
+
"model": col if "/" not in col else col.split("/")[1].split("_torch")[0],
|
429 |
+
"run": col,
|
430 |
}
|
431 |
+
if variant == "shots":
|
432 |
+
parts = col.split("/shots-")
|
433 |
+
new_model_metrics["shots"] = int(parts[1])
|
434 |
+
new_model_metrics["model"] = parts[0]
|
435 |
+
|
436 |
new_model_metrics.update(metrics)
|
437 |
|
438 |
# Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
|
|
|
443 |
return perf_df
|
444 |
|
445 |
|
446 |
+
def plot_metrics(perf_df, model_name, variant="epoch", offset=0.01):
|
447 |
+
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
|
448 |
+
perf_df = perf_df[perf_df["model"] == model_name]
|
449 |
|
450 |
# Ensure the lengths of perf_df["epoch"], perf_df["accuracy"], and perf_df["f1"] are the same
|
451 |
min_length = min(
|
452 |
+
len(perf_df[variant]), len(perf_df["accuracy"]), len(perf_df["f1"])
|
453 |
)
|
454 |
perf_df = perf_df.iloc[:min_length]
|
455 |
|
456 |
# Plot accuracy and f1 on the same chart with different markers
|
|
|
457 |
ax.plot(
|
458 |
+
perf_df[variant], perf_df["accuracy"], marker="o", label="Accuracy", color="r"
|
459 |
+
)
|
460 |
+
ax.plot(
|
461 |
+
perf_df[variant], perf_df["f1"], marker="s", label="F1 Score", color="b"
|
462 |
) # Square marker for F1 Score
|
463 |
|
464 |
# Add values on top of points
|
465 |
for i in range(min_length):
|
466 |
+
print(f"{perf_df[variant].iloc[i]}: {perf_df['run'].iloc[i]}")
|
467 |
ax.annotate(
|
468 |
f"{perf_df['accuracy'].iloc[i]*100:.2f}%",
|
469 |
+
(perf_df[variant].iloc[i], perf_df["accuracy"].iloc[i]),
|
470 |
ha="center",
|
471 |
va="bottom", # Move accuracy numbers below the points
|
472 |
xytext=(0, -15),
|
473 |
textcoords="offset points",
|
474 |
fontsize=10,
|
475 |
+
color="r",
|
476 |
)
|
477 |
ax.annotate(
|
478 |
f"{perf_df['f1'].iloc[i]*100:.2f}%",
|
479 |
+
(perf_df[variant].iloc[i], perf_df["f1"].iloc[i]),
|
480 |
ha="center",
|
481 |
va="top", # Move F1 score numbers above the points
|
482 |
xytext=(0, 15), # Offset by 15 points vertically
|
483 |
textcoords="offset points",
|
484 |
fontsize=10,
|
485 |
+
color="b",
|
486 |
)
|
487 |
|
488 |
# Set y-axis limit
|
489 |
+
ylimits = ax.get_ylim()
|
490 |
+
ax.set_ylim(ylimits[0] - offset, ylimits[1] + offset)
|
491 |
|
492 |
# Add title and labels
|
493 |
+
ax.set_xlabel(
|
494 |
+
"Epoch (0: base model, 0.2 - 2: fine-tuned models)"
|
495 |
+
if variant == "epoch"
|
496 |
+
else "Number of Shots"
|
497 |
+
)
|
498 |
ax.set_ylabel("Accuracy and F1 Score")
|
499 |
|
500 |
+
ax.xaxis.set_major_locator(MultipleLocator(0.2 if variant == "epoch" else 5))
|
|
|
501 |
ax.set_title(f"Performance Analysis Across Checkpoints for the {model_name} Model")
|
502 |
|
503 |
# Rotate x labels
|
|
|
542 |
return response.content
|
543 |
|
544 |
|
545 |
+
def eval_openai(
|
546 |
+
eval_dataset,
|
547 |
+
model="gpt-4o-mini",
|
548 |
+
max_new_tokens=300,
|
549 |
+
num_shots=0,
|
550 |
+
train_dataset=None,
|
551 |
+
):
|
552 |
+
user_prompt = (
|
553 |
+
get_prompt_template(using_p1=False, chinese_prompt=True)
|
554 |
+
if num_shots == 0
|
555 |
+
else get_few_shot_prompt_template(num_shots, train_dataset)
|
556 |
+
)
|
557 |
+
print("user_prompt:", user_prompt)
|
558 |
total = len(eval_dataset)
|
559 |
predictions = []
|
560 |
|
561 |
for i in tqdm(range(total)):
|
562 |
+
output = reasoning_with_openai(
|
563 |
+
eval_dataset.iloc[i], user_prompt, model=model, max_tokens=max_new_tokens
|
564 |
+
)
|
565 |
predictions.append(output)
|
566 |
|
567 |
return predictions
|
568 |
+
|
569 |
+
|
570 |
+
def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix"):
|
571 |
+
font_family = rcParams["font.family"]
|
572 |
+
# Set the font to SimHei to support Chinese characters
|
573 |
+
rcParams["font.family"] = "STHeiti"
|
574 |
+
rcParams["axes.unicode_minus"] = (
|
575 |
+
False # This is to support the minus sign in Chinese.
|
576 |
+
)
|
577 |
+
|
578 |
+
labels = np.unique(y_true)
|
579 |
+
|
580 |
+
y_pred = [extract_answer(text) for text in y_pred]
|
581 |
+
|
582 |
+
cm = confusion_matrix(y_true, y_pred)
|
583 |
+
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
|
584 |
+
|
585 |
+
fig, ax = plt.subplots(figsize=(8, 8))
|
586 |
+
sns.heatmap(
|
587 |
+
cm,
|
588 |
+
annot=True,
|
589 |
+
fmt=".4f",
|
590 |
+
cmap="Blues",
|
591 |
+
xticklabels=labels,
|
592 |
+
yticklabels=labels,
|
593 |
+
)
|
594 |
+
ax.set_title(title)
|
595 |
+
ax.set_xlabel("Predicted labels")
|
596 |
+
ax.set_ylabel("True labels")
|
597 |
+
plt.show()
|
598 |
+
|
599 |
+
rcParams["font.family"] = font_family
|
600 |
+
|
601 |
+
|
602 |
+
def majority_vote(r1, r2, r3):
|
603 |
+
label = r2
|
604 |
+
if r1 == r3:
|
605 |
+
label = r1
|
606 |
+
|
607 |
+
return label
|
notebooks/00_Data Analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/04_Few-shot_Prompting_OpenAI.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/04b_OpenAI-Models_analysis.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|