open source LLM results almost done

Files changed:
- data/Llama3.1-70B-Chinese-Chat_metrics.csv +12 -0
- data/Llama3.1-8B-Chinese-Chat_metrics.csv +12 -7
- data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv +12 -0
- data/Qwen2-72B-Instruct_metrics.csv +8 -8
- data/Qwen2-7B-Instruct_metrics.csv +12 -12
- data/internlm2_5-7b-chat-1m_metrics.csv +12 -12
- llm_toolkit/logical_reasoning_utils.py +42 -26
- notebooks/00_Data Analysis.ipynb +0 -0
- notebooks/01_internlm2_5-7b-chat-1m_analysis.ipynb +0 -0
- notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb +0 -0
- notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb +0 -0
- notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb +0 -0
- notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb +0 -0
- notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb +0 -0
- notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb +0 -0
data/Llama3.1-70B-Chinese-Chat_metrics.csv
ADDED
@@ -0,0 +1,12 @@
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
+0.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
+0.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
+0.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
+0.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
+1.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
+1.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
+1.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
+1.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
+1.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
+2.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
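All of the metrics files in this commit share the schema above, so they can be loaded and compared directly. Below is a minimal sketch, assuming pandas and matplotlib are installed and the repository layout matches the file list at the top:

import matplotlib.pyplot as plt
import pandas as pd

# One row per evaluated checkpoint: epoch 0.0 is the untuned base model,
# then one checkpoint every 0.2 epochs of fine-tuning.
df = pd.read_csv("data/Llama3.1-70B-Chinese-Chat_metrics.csv")

df.plot(x="epoch", y=["accuracy", "f1", "ratio_valid_classifications"], marker="o")
plt.title("Llama3.1-70B-Chinese-Chat metrics by fine-tuning epoch")
plt.ylabel("score")
plt.show()

(The plotted F1 peak, roughly 0.808 at epoch 1.0, matches the checkpoint-175 row above.)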
data/Llama3.1-8B-Chinese-Chat_metrics.csv
CHANGED
@@ -1,7 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.
-0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.
-0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.
-0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.
-0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.
-1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
+0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
+0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
+0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
+0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
+1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
+1.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
+1.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
+1.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
+1.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
+2.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv
ADDED
@@ -0,0 +1,12 @@
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
+0.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
+0.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
+0.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
+0.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
+1.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
+1.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
+1.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
+1.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
+1.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
+2.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
data/Qwen2-72B-Instruct_metrics.csv
CHANGED
@@ -1,8 +1,8 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7473333333333333,0.804122252986722,0.7473333333333333,0.7607828719113865
-0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442
-0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021
-0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628
-0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173
-1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548
-1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7473333333333333,0.804122252986722,0.7473333333333333,0.7607828719113865,0.9773333333333334
+0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
+0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
+0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
+0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
+1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
+1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
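Note that the 72B table has only eight rows (+8 -8; its run stops at epoch 1.2), so cross-model comparisons are safest keyed off each model's best checkpoint rather than a fixed epoch. A short sketch under the same pandas assumptions as above:

import glob

import pandas as pd

# Rank models by the weighted F1 of their best checkpoint.
for path in sorted(glob.glob("data/*_metrics.csv")):
    df = pd.read_csv(path)
    best = df.loc[df["f1"].idxmax()]
    print(f'{best["f1"]:.4f} at epoch {best["epoch"]}  {best["model"]}')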
data/Qwen2-7B-Instruct_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.
-0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058
-0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183
-0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848
-0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298
-1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772
-1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508
-1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717
-1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867
-1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346
-2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
+0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
+0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
+0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
+0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
+1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
+1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
+1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
+1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
+1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
+2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
data/internlm2_5-7b-chat-1m_metrics.csv
CHANGED
@@ -1,12 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308
-0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659
-0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081
-0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912
-0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301
-1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813
-1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454
-1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925
-1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297
-1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903
-2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
+0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
+0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
+0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
+0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
+1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
+1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
+1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
+1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
+1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
+2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
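Across all six tables, the newly added ratio_valid_classifications column is the one to watch: it records how often the raw model output was already one of the known labels. The untuned models at epoch 0.0 vary wildly (about 0.01 for Llama3.1-70B, 0.004 for Mistral-7B-v0.3, 0.39 for Llama3.1-8B, but 0.98-1.0 for the Qwen2 and internlm2.5 models), while virtually every fine-tuned checkpoint reaches 1.0. A small sketch, again assuming pandas and the data/ layout above, pulls the base-model row out of each file:

import glob

import pandas as pd

# How often did each *untuned* model emit a directly parseable label?
for path in sorted(glob.glob("data/*_metrics.csv")):
    df = pd.read_csv(path)
    base = df.loc[df["epoch"] == 0.0].iloc[0]
    print(f'{base["ratio_valid_classifications"]:.4f}  {base["model"]}')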
llm_toolkit/logical_reasoning_utils.py
CHANGED
@@ -95,7 +95,7 @@ def get_prompt_template(using_p1=True, chinese_prompt=True):
 
 
 def extract_answer(text, debug=False):
-    if text:
+    if text and isinstance(text, str):
         # Remove the begin and end tokens
         text = re.sub(
             r".*?(assistant|\[/INST\]).+?\b",
@@ -117,6 +117,7 @@ def extract_answer(text, debug=False):
             print("--------\nstep 3:", text)
 
         text = text.split(".")[0].strip()
+        text = text.split("。")[0].strip()
         if debug:
             print("--------\nstep 4:", text)
 
@@ -129,7 +130,9 @@ def extract_answer(text, debug=False):
         if debug:
             print("--------\nstep 5:", text)
 
-    return text.strip()
+        return text.strip()
+
+    return ""
 
 
 def calc_metrics(references, predictions, debug=False):
@@ -137,16 +140,33 @@ def calc_metrics(references, predictions, debug=False):
         predictions
     ), f"lengths are difference: {len(references)} != {len(predictions)}"
 
+    labels = np.unique(references)
+    valid_classifications = [1 if p in labels else 0 for p in predictions]
+
     predictions = [extract_answer(text) for text in predictions]
 
-    correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
-    accuracy = sum(correct) / len(references)
+    accuracy = accuracy_score(references, predictions)
 
     results = {"accuracy": accuracy}
     if debug:
-        incorrect_ids = [i for i,
+        incorrect_ids = [i for i, p in enumerate(predictions) if p != references[i]]
         results["incorrect_ids"] = incorrect_ids
 
+    precision = precision_score(
+        references, predictions, average="weighted", labels=labels
+    )
+    results["precision"] = float(precision)
+
+    recall = recall_score(references, predictions, average="weighted", labels=labels)
+    results["recall"] = float(recall)
+
+    f1 = f1_score(references, predictions, average="weighted", labels=labels)
+    results["f1"] = float(f1)
+
+    results["ratio_valid_classifications"] = sum(valid_classifications) / len(
+        valid_classifications
+    )
+
     return results
 
 
@@ -240,7 +260,7 @@ def get_metrics(df):
     rouge_l = []
     all_metrics = []
     for col in df.columns[2:]:
-        metrics = calc_metrics(df["
+        metrics = calc_metrics(df["label"], df[col], debug=True)
         print(f"{col}: {metrics}")
 
         accuracy.append(metrics["accuracy"])
@@ -290,38 +310,37 @@ def load_alpaca_data(data_path, using_p1=True, use_english_datasets=False):
     return df_alpaca
 
 
-def plot_value_counts(df,
+def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=None):
     font_family = rcParams["font.family"]
     # Set the font to SimHei to support Chinese characters
     rcParams["font.family"] = "STHeiti"
     rcParams["axes.unicode_minus"] = (
         False  # This is to support the minus sign in Chinese.
     )
+    if preprocess_func:
+        df["backup"] = df[column_name]
+        df[column_name] = df[column_name].apply(preprocess_func)
 
     plt.figure(figsize=(12, 6))
-    df[
+    df[column_name].value_counts().plot(kind="bar")
     # add values on top of bars
-    for i, v in enumerate(df[
-        plt.text(i, v +
+    for i, v in enumerate(df[column_name].value_counts()):
+        plt.text(i, v + offset, str(v), ha="center")
 
-    plt.xlabel(title or
+    plt.xlabel(title or column_name)
 
     plt.show()
 
     rcParams["font.family"] = font_family
 
+    if preprocess_func:
+        df[column_name] = df["backup"]
+        df.drop(columns=["backup"], inplace=True)
+
 
-def calc_metrics_for_col(df, col):
-    y_true = df["label"]
-    y_pred = df[col]
-    labels = np.unique(y_true)
 
-    accuracy = accuracy_score(y_true, y_pred)
-    precision = precision_score(y_true, y_pred, average="weighted", labels=labels)
-    recall = recall_score(y_true, y_pred, average="weighted", labels=labels)
-    f1 = f1_score(y_true, y_pred, average="weighted", labels=labels)
-
-    return accuracy, float(precision), float(recall), float(f1)
+def calc_metrics_for_col(df, col):
+    metrics = calc_metrics(df["label"], df[col], debug=True)
+    return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
 
 
 def get_metrics_df(df):
@@ -329,15 +348,12 @@ def get_metrics_df(df):
         columns=["epoch", "model", "accuracy", "precision", "recall", "f1"]
     )
     for i, col in enumerate(df.columns[5:]):
-        accuracy, precision, recall, f1 = calc_metrics_for_col(df, col)
+        metrics = calc_metrics(df["label"], df[col], debug=False)
         new_model_metrics = {
             "epoch": i / 5,
             "model": col,
-            "accuracy": accuracy,
-            "precision": precision,
-            "recall": recall,
-            "f1": f1,
         }
+        new_model_metrics.update(metrics)
 
         # Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
         perf_df = pd.concat(
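To summarize the calc_metrics change for readers skimming the diff: precision, recall, and F1 are now computed with scikit-learn (weighted over the reference label set), and ratio_valid_classifications is computed on the raw outputs before extract_answer runs, so it measures format compliance separately from answer quality. That is why a base model can post 0.76 accuracy with a 0.0097 valid ratio: extract_answer still salvages answers from free-form text. Below is a self-contained sketch of the resulting logic, assuming scikit-learn and NumPy are available and omitting the extract_answer cleanup (which depends on each model's chat template):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def calc_metrics(references, predictions, debug=False):
    assert len(references) == len(predictions)

    labels = np.unique(references)
    # Share of raw outputs that are already an exact label (format compliance).
    valid = [1 if p in labels else 0 for p in predictions]

    results = {
        "accuracy": accuracy_score(references, predictions),
        "precision": float(
            precision_score(references, predictions, average="weighted", labels=labels)
        ),
        "recall": float(
            recall_score(references, predictions, average="weighted", labels=labels)
        ),
        "f1": float(
            f1_score(references, predictions, average="weighted", labels=labels)
        ),
        "ratio_valid_classifications": sum(valid) / len(valid),
    }
    if debug:
        results["incorrect_ids"] = [
            i for i, p in enumerate(predictions) if p != references[i]
        ]
    return results


# Toy check: two of three predictions correct, all three are valid labels.
print(calc_metrics(["A", "B", "A"], ["A", "B", "B"], debug=True))

Keeping the valid ratio out of accuracy is a useful design choice: it distinguishes "cannot follow the output format" failures from genuine reasoning errors, which the per-checkpoint CSVs above make visible at epoch 0.0.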
notebooks/00_Data Analysis.ipynb
CHANGED
The diff for this file is too large to render.

notebooks/01_internlm2_5-7b-chat-1m_analysis.ipynb
DELETED
The diff for this file is too large to render.

notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb
ADDED
The diff for this file is too large to render.

notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb
ADDED
The diff for this file is too large to render.

notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb
CHANGED
The diff for this file is too large to render.

notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb
CHANGED
The diff for this file is too large to render.

notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb
CHANGED
The diff for this file is too large to render.

notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb
ADDED
The diff for this file is too large to render.