dh-mc committed on
Commit
468b88d
·
1 Parent(s): 6e932d8

completed eval/analysis

Browse files
data/Llama3.1-70B-Chinese-Chat_metrics.csv CHANGED
@@ -1,12 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
3
- 0.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
4
- 0.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
5
- 0.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
6
- 0.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
7
- 1.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
8
- 1.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
9
- 1.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
10
- 1.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
11
- 1.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
12
- 2.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
3
+ 0.2,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
4
+ 0.4,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
5
+ 0.6,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
6
+ 0.8,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
7
+ 1.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
8
+ 1.2,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
9
+ 1.4,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
10
+ 1.6,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
11
+ 1.8,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
12
+ 2.0,Llama3.1-70B-Chinese-Chat,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
data/Llama3.1-8B-Chinese-Chat_metrics.csv CHANGED
@@ -1,12 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
3
- 0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
4
- 0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
5
- 0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
6
- 0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
7
- 1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
8
- 1.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
9
- 1.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
10
- 1.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
11
- 1.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
12
- 2.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
3
+ 0.2,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
4
+ 0.4,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
5
+ 0.6,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
6
+ 0.8,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
7
+ 1.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
8
+ 1.2,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
9
+ 1.4,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
10
+ 1.6,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
11
+ 1.8,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
12
+ 2.0,Llama3.1-8B-Chinese-Chat,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv CHANGED
@@ -1,12 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
3
- 0.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
4
- 0.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
5
- 0.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
6
- 0.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
7
- 1.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
8
- 1.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
9
- 1.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
10
- 1.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
11
- 1.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
12
- 2.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
3
+ 0.2,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
4
+ 0.4,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
5
+ 0.6,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
6
+ 0.8,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
7
+ 1.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
8
+ 1.2,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
9
+ 1.4,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
10
+ 1.6,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
11
+ 1.8,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
12
+ 2.0,Mistral-7B-v0.3-Chinese-Chat,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
data/Qwen2-72B-Instruct_metrics.csv CHANGED
@@ -1,8 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7473333333333333,0.804122252986722,0.7473333333333333,0.7607828719113865,0.9773333333333334
3
- 0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
4
- 0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
5
- 0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
6
- 0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
7
- 1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
8
- 1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
 
 
 
 
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7516666666666667,0.7949378981748352,0.7516666666666667,0.7572499605227642,0.9773333333333334
3
+ 0.2,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
4
+ 0.4,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
5
+ 0.6,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
6
+ 0.8,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
7
+ 1.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
8
+ 1.2,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
9
+ 1.4,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-245_torch.bfloat16_4bit_lf,0.7656666666666667,0.8288272203240518,0.7656666666666667,0.790627109330698,1.0
10
+ 1.6,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-280_torch.bfloat16_4bit_lf,0.7693333333333333,0.8292798021666021,0.7693333333333333,0.7930169589012503,1.0
11
+ 1.8,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-315_torch.bfloat16_4bit_lf,0.784,0.8354349234761956,0.784,0.804194683154365,1.0
12
+ 2.0,Qwen2-72B-Instruct,Qwen/Qwen2-72B-Instruct/checkpoint-350_torch.bfloat16_4bit_lf,0.7736666666666666,0.8330147983140184,0.7736666666666666,0.7973657072550873,1.0
data/Qwen2-72B-Instruct_results.csv CHANGED
The diff for this file is too large to render. See raw diff
 
data/Qwen2-7B-Instruct_metrics.csv CHANGED
@@ -1,12 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
3
- 0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
4
- 0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
5
- 0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
6
- 0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
7
- 1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
8
- 1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
9
- 1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
10
- 1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
11
- 1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
12
- 2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
3
+ 0.2,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
4
+ 0.4,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
5
+ 0.6,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
6
+ 0.8,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
7
+ 1.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
8
+ 1.2,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
9
+ 1.4,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
10
+ 1.6,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
11
+ 1.8,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
12
+ 2.0,Qwen2-7B-Instruct,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
data/internlm2_5-7b-chat-1m_metrics.csv CHANGED
@@ -1,12 +1,12 @@
1
- epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
- 0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
3
- 0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
4
- 0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
5
- 0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
6
- 0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
7
- 1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
8
- 1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
9
- 1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
10
- 1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
11
- 1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
12
- 2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
 
1
+ epoch,model,run,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
3
+ 0.2,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
4
+ 0.4,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
5
+ 0.6,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
6
+ 0.8,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
7
+ 1.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
8
+ 1.2,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
9
+ 1.4,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
10
+ 1.6,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
11
+ 1.8,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
12
+ 2.0,internlm2_5-7b-chat-1m,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
data/openai_metrics.csv ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ shots,model,accuracy,precision,recall,f1,ratio_valid_classifications
2
+ 0,gpt-4o-mini,0.7176666666666667,0.785706730193659,0.7176666666666667,0.7296061848734905,0.9916666666666667
3
+ 5,gpt-4o-mini,0.7176666666666667,0.7767294185987051,0.7176666666666667,0.7181068311028772,0.9996666666666667
4
+ 10,gpt-4o-mini,0.6793333333333333,0.7728086050218999,0.6793333333333333,0.6916749681933937,0.9983333333333333
5
+ 20,gpt-4o-mini,0.6623333333333333,0.7686706009175459,0.6623333333333333,0.6798015109939115,0.998
6
+ 30,gpt-4o-mini,0.6873333333333334,0.7684209723431035,0.6873333333333334,0.6913018667081989,0.999
7
+ 40,gpt-4o-mini,0.6923333333333334,0.7639874967862498,0.6923333333333334,0.6924934068935911,0.9986666666666667
8
+ 50,gpt-4o-mini,0.717,0.7692638634416518,0.717,0.7105227254860433,0.9993333333333333
9
+ 0,gpt-4o,0.782,0.8204048322982596,0.782,0.7953019682198627,0.066
10
+ 5,gpt-4o,0.7873333333333333,0.8230974205170392,0.7873333333333333,0.8000290527498529,0.998
11
+ 10,gpt-4o,0.7916666666666666,0.8227707658360168,0.7916666666666666,0.803614688453356,0.9996666666666667
12
+ 20,gpt-4o,0.7816666666666666,0.8204541793856629,0.7816666666666666,0.7967017169880498,0.9993333333333333
13
+ 30,gpt-4o,0.7886666666666666,0.8260847852316618,0.7886666666666666,0.8030949295928699,0.999
14
+ 40,gpt-4o,0.784,0.8233509309291644,0.784,0.7993336791122846,0.9973333333333333
15
+ 50,gpt-4o,0.787,0.8234800466218334,0.787,0.8013530974301947,0.9993333333333333
data/openai_results.csv CHANGED
The diff for this file is too large to render. See raw diff
 
llm_toolkit/eval_openai.py CHANGED
@@ -57,7 +57,13 @@ def evaluate_model_with_num_shots(
57
  for num_shots in range_num_shots:
58
  print(f"*** Evaluating with num_shots: {num_shots}")
59
 
60
- predictions = eval_openai(eval_dataset, model=model_name, max_new_tokens=max_new_tokens)
 
 
 
 
 
 
61
  model_name_with_shorts = (
62
  result_column_name
63
  if result_column_name
 
57
  for num_shots in range_num_shots:
58
  print(f"*** Evaluating with num_shots: {num_shots}")
59
 
60
+ predictions = eval_openai(
61
+ eval_dataset,
62
+ model=model_name,
63
+ max_new_tokens=max_new_tokens,
64
+ num_shots=num_shots,
65
+ train_dataset=datasets["train"].to_pandas(),
66
+ )
67
  model_name_with_shorts = (
68
  result_column_name
69
  if result_column_name
llm_toolkit/logical_reasoning_utils.py CHANGED
@@ -3,14 +3,21 @@ import re
3
  from langchain_openai import ChatOpenAI
4
  from langchain_core.prompts import ChatPromptTemplate
5
  import pandas as pd
 
6
  import seaborn as sns
7
  import matplotlib.pyplot as plt
8
  from matplotlib import rcParams
9
  from matplotlib.ticker import MultipleLocator
10
  from datasets import load_dataset
11
  import numpy as np
12
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
13
- from tqdm import tqdm
 
 
 
 
 
 
14
 
15
  print(f"loading {__file__}")
16
 
@@ -61,17 +68,16 @@ P2 = """你是一个情景猜谜游戏的主持人。游戏规则如下:
61
 
62
  请严格按照这些规则回答参与者提出的问题。
63
 
64
- **谜面:** {}
65
-
66
- **谜底:** {}
67
-
68
- **参与者提出的问题:** {}
69
  """
70
 
71
  P2_en = """You are the host of a situational guessing game. The rules of the game are as follows:
72
 
73
  1. Participants will receive a riddle that describes a simple yet difficult to understand event.
74
- 2. The host knows the answer, which is the solution to the riddle.
75
  3. Participants can ask any closed-ended questions to uncover the truth of the event.
76
  4. For each question, the host will respond with one of the following five options based on the actual situation: Yes, No, Unimportant, Correct answer, or Incorrect questioning. The criteria for each response are as follows:
77
  - If the riddle and answer can provide an answer to the question, respond with: Yes or No
@@ -82,14 +88,35 @@ P2_en = """You are the host of a situational guessing game. The rules of the gam
82
 
83
  Please strictly follow these rules when answering the participant's questions.
84
 
85
- **Riddle:** {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- **Answer:** {}
88
 
89
- **Participant's question:** {}
 
 
 
 
 
90
  """
91
 
92
- system_prompt = "You are an expert in logical reasoning."
93
 
94
  def get_prompt_template(using_p1=True, chinese_prompt=True):
95
  if using_p1:
@@ -98,6 +125,40 @@ def get_prompt_template(using_p1=True, chinese_prompt=True):
98
  return P2 if chinese_prompt else P2_en
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  def extract_answer(text, debug=False):
102
  if text and isinstance(text, str):
103
  # Remove the begin and end tokens
@@ -121,7 +182,9 @@ def extract_answer(text, debug=False):
121
  print("--------\nstep 3:", text)
122
 
123
  text = text.split(".")[0].strip()
 
124
  text = text.split("。")[0].strip()
 
125
  if debug:
126
  print("--------\nstep 4:", text)
127
 
@@ -186,7 +249,9 @@ def save_results(model_name, results_path, dataset, predictions, debug=False):
186
  df = dataset
187
  else:
188
  df = dataset.to_pandas()
189
- df.drop(columns=["answer", "prompt", "train_text"], inplace=True, errors="ignore")
 
 
190
  else:
191
  df = pd.read_csv(results_path, on_bad_lines="warn")
192
 
@@ -329,7 +394,7 @@ def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=N
329
  df["backup"] = df[column_name]
330
  df[column_name] = df[column_name].apply(preprocess_func)
331
 
332
- plt.figure(figsize=(12, 6))
333
  df[column_name].value_counts().plot(kind="bar")
334
  # add values on top of bars
335
  for i, v in enumerate(df[column_name].value_counts()):
@@ -342,6 +407,7 @@ def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=N
342
  rcParams["font.family"] = font_family
343
 
344
  if preprocess_func:
 
345
  df[column_name] = df["backup"]
346
  df.drop(columns=["backup"], inplace=True)
347
 
@@ -351,16 +417,22 @@ def calc_metrics_for_col(df, col):
351
  return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
352
 
353
 
354
- def get_metrics_df(df):
355
  perf_df = pd.DataFrame(
356
- columns=["epoch", "model", "accuracy", "precision", "recall", "f1"]
357
  )
358
  for i, col in enumerate(df.columns[5:]):
359
  metrics = calc_metrics(df["label"], df[col], debug=False)
360
  new_model_metrics = {
361
- "epoch": i / 5,
362
- "model": col,
 
363
  }
 
 
 
 
 
364
  new_model_metrics.update(metrics)
365
 
366
  # Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
@@ -371,51 +443,61 @@ def get_metrics_df(df):
371
  return perf_df
372
 
373
 
374
- def plot_metrics(perf_df, model_name):
375
- fig, ax = plt.subplots(1, 1, figsize=(12, 6))
 
376
 
377
  # Ensure the lengths of perf_df["epoch"], perf_df["accuracy"], and perf_df["f1"] are the same
378
  min_length = min(
379
- len(perf_df["epoch"]), len(perf_df["accuracy"]), len(perf_df["f1"])
380
  )
381
  perf_df = perf_df.iloc[:min_length]
382
 
383
  # Plot accuracy and f1 on the same chart with different markers
384
- ax.plot(perf_df["epoch"], perf_df["accuracy"], marker="o", label="Accuracy")
385
  ax.plot(
386
- perf_df["epoch"], perf_df["f1"], marker="s", label="F1 Score"
 
 
 
387
  ) # Square marker for F1 Score
388
 
389
  # Add values on top of points
390
  for i in range(min_length):
 
391
  ax.annotate(
392
  f"{perf_df['accuracy'].iloc[i]*100:.2f}%",
393
- (perf_df["epoch"].iloc[i], perf_df["accuracy"].iloc[i]),
394
  ha="center",
395
  va="bottom", # Move accuracy numbers below the points
396
  xytext=(0, -15),
397
  textcoords="offset points",
398
  fontsize=10,
 
399
  )
400
  ax.annotate(
401
  f"{perf_df['f1'].iloc[i]*100:.2f}%",
402
- (perf_df["epoch"].iloc[i], perf_df["f1"].iloc[i]),
403
  ha="center",
404
  va="top", # Move F1 score numbers above the points
405
  xytext=(0, 15), # Offset by 15 points vertically
406
  textcoords="offset points",
407
  fontsize=10,
 
408
  )
409
 
410
  # Set y-axis limit
411
- # ax.set_ylim(0.49, 0.825)
 
412
 
413
  # Add title and labels
414
- ax.set_xlabel("Epoch (0: base model, 0.2 - 2: fine-tuned models)")
 
 
 
 
415
  ax.set_ylabel("Accuracy and F1 Score")
416
 
417
- # Set x-axis grid spacing to 0.2
418
- ax.xaxis.set_major_locator(MultipleLocator(0.2))
419
  ax.set_title(f"Performance Analysis Across Checkpoints for the {model_name} Model")
420
 
421
  # Rotate x labels
@@ -460,13 +542,66 @@ def reasoning_with_openai(
460
  return response.content
461
 
462
 
463
- def eval_openai(eval_dataset, model="gpt-4o-mini", max_new_tokens=300):
464
- user_prompt = get_prompt_template(using_p1=False, chinese_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
465
  total = len(eval_dataset)
466
  predictions = []
467
 
468
  for i in tqdm(range(total)):
469
- output = reasoning_with_openai(eval_dataset.iloc[i], user_prompt,model=model, max_tokens=max_new_tokens)
 
 
470
  predictions.append(output)
471
 
472
  return predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from langchain_openai import ChatOpenAI
4
  from langchain_core.prompts import ChatPromptTemplate
5
  import pandas as pd
6
+ from tqdm import tqdm
7
  import seaborn as sns
8
  import matplotlib.pyplot as plt
9
  from matplotlib import rcParams
10
  from matplotlib.ticker import MultipleLocator
11
  from datasets import load_dataset
12
  import numpy as np
13
+ from sklearn.metrics import (
14
+ accuracy_score,
15
+ precision_score,
16
+ recall_score,
17
+ f1_score,
18
+ confusion_matrix,
19
+ )
20
+
21
 
22
  print(f"loading {__file__}")
23
 
 
68
 
69
  请严格按照这些规则回答参与者提出的问题。
70
 
71
+ 谜面: {}
72
+ 谜底: {}
73
+ 参与者提出的问题: {}
74
+ 回答:
 
75
  """
76
 
77
  P2_en = """You are the host of a situational guessing game. The rules of the game are as follows:
78
 
79
  1. Participants will receive a riddle that describes a simple yet difficult to understand event.
80
+ 2. The host knows the truth, which is the solution to the riddle.
81
  3. Participants can ask any closed-ended questions to uncover the truth of the event.
82
  4. For each question, the host will respond with one of the following five options based on the actual situation: Yes, No, Unimportant, Correct answer, or Incorrect questioning. The criteria for each response are as follows:
83
  - If the riddle and answer can provide an answer to the question, respond with: Yes or No
 
88
 
89
  Please strictly follow these rules when answering the participant's questions.
90
 
91
+ Riddle: {}
92
+ Truth: {}
93
+ Participant's question: {}
94
+ """
95
+
96
+ system_prompt = "You are an expert in logical reasoning."
97
+
98
+ P2_few_shot = """你是一个情景猜谜游戏的主持人。游戏规则如下:
99
+
100
+ 1. 参与者会得到一个谜面,谜面会描述一个简单又难以理解的事件。
101
+ 2. 主持人知道谜底,谜底是谜面的答案。
102
+ 3. 参与者可以询问任何封闭式问题来找寻事件的真相。
103
+ 4. 对于每个问题,主持人将根据实际情况回答以下五个选项之一:是、不是、不重要、回答正确、问法错误。各回答的判断标准如下:
104
+ - 若谜面和谜底能找到问题的答案,回答:是或者不是
105
+ - 若谜面和谜底不能直接或者间接推断出问题的答案,回答:不重要
106
+ - 若参与者提问不是一个封闭式问题或者问题难以理解,回答:问法错误
107
+ - 若参与者提问基本还原了谜底真相,回答:回答正确
108
+ 5. 回答中不能添加任何其它信息,也不能省略选项中的任何一个字。例如,不可以把“不是”省略成“不”。
109
 
110
+ 请严格按照这些规则回答参与者提出的问题。
111
 
112
+ 示例输入和输出:
113
+ {examples}
114
+ 谜面: {}
115
+ 谜底: {}
116
+ 参与者提出的问题: {}
117
+ 回答:
118
  """
119
 
 
120
 
121
  def get_prompt_template(using_p1=True, chinese_prompt=True):
122
  if using_p1:
 
125
  return P2 if chinese_prompt else P2_en
126
 
127
 
128
+ def get_few_shot_prompt_template(num_shots, train_dataset, debug=False):
129
+ if num_shots == 0:
130
+ return get_prompt_template(using_p1=False, chinese_prompt=True)
131
+
132
+ labels = train_dataset["label"].unique()
133
+ if debug:
134
+ print("num_shots:", num_shots)
135
+ print("labels:", labels)
136
+
137
+ examples = ""
138
+ index = 0
139
+ while num_shots > 0:
140
+ for label in labels:
141
+ while train_dataset["label"][index] != label:
142
+ index += 1
143
+
144
+ row = train_dataset.iloc[index]
145
+ examples += f"""谜面: {row["puzzle"]}
146
+ 谜底: {row["truth"]}
147
+ 参与者提出的问题: {row["text"]}
148
+ 回答: {row["label"]}
149
+
150
+ """
151
+ num_shots -= 1
152
+ if num_shots == 0:
153
+ break
154
+
155
+ prompt = P2_few_shot.replace("{examples}", examples)
156
+ if debug:
157
+ print("P2_few_shot:", prompt)
158
+
159
+ return prompt
160
+
161
+
162
  def extract_answer(text, debug=False):
163
  if text and isinstance(text, str):
164
  # Remove the begin and end tokens
 
182
  print("--------\nstep 3:", text)
183
 
184
  text = text.split(".")[0].strip()
185
+ text = text.split("\n")[0].strip()
186
  text = text.split("。")[0].strip()
187
+ text = text.replace("回答: ", "").strip()
188
  if debug:
189
  print("--------\nstep 4:", text)
190
 
 
249
  df = dataset
250
  else:
251
  df = dataset.to_pandas()
252
+ df.drop(
253
+ columns=["answer", "prompt", "train_text"], inplace=True, errors="ignore"
254
+ )
255
  else:
256
  df = pd.read_csv(results_path, on_bad_lines="warn")
257
 
 
394
  df["backup"] = df[column_name]
395
  df[column_name] = df[column_name].apply(preprocess_func)
396
 
397
+ plt.figure(figsize=(8, 4))
398
  df[column_name].value_counts().plot(kind="bar")
399
  # add values on top of bars
400
  for i, v in enumerate(df[column_name].value_counts()):
 
407
  rcParams["font.family"] = font_family
408
 
409
  if preprocess_func:
410
+ plot_confusion_matrix(df["label"], df[column_name])
411
  df[column_name] = df["backup"]
412
  df.drop(columns=["backup"], inplace=True)
413
 
 
417
  return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
418
 
419
 
420
+ def get_metrics_df(df, variant="epoch"):
421
  perf_df = pd.DataFrame(
422
+ columns=[variant, "model", "run", "accuracy", "precision", "recall", "f1"]
423
  )
424
  for i, col in enumerate(df.columns[5:]):
425
  metrics = calc_metrics(df["label"], df[col], debug=False)
426
  new_model_metrics = {
427
+ variant: i / 5 if variant == "epoch" else i + 1,
428
+ "model": col if "/" not in col else col.split("/")[1].split("_torch")[0],
429
+ "run": col,
430
  }
431
+ if variant == "shots":
432
+ parts = col.split("/shots-")
433
+ new_model_metrics["shots"] = int(parts[1])
434
+ new_model_metrics["model"] = parts[0]
435
+
436
  new_model_metrics.update(metrics)
437
 
438
  # Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
 
443
  return perf_df
444
 
445
 
446
+ def plot_metrics(perf_df, model_name, variant="epoch", offset=0.01):
447
+ fig, ax = plt.subplots(1, 1, figsize=(8, 4))
448
+ perf_df = perf_df[perf_df["model"] == model_name]
449
 
450
  # Ensure the lengths of perf_df["epoch"], perf_df["accuracy"], and perf_df["f1"] are the same
451
  min_length = min(
452
+ len(perf_df[variant]), len(perf_df["accuracy"]), len(perf_df["f1"])
453
  )
454
  perf_df = perf_df.iloc[:min_length]
455
 
456
  # Plot accuracy and f1 on the same chart with different markers
 
457
  ax.plot(
458
+ perf_df[variant], perf_df["accuracy"], marker="o", label="Accuracy", color="r"
459
+ )
460
+ ax.plot(
461
+ perf_df[variant], perf_df["f1"], marker="s", label="F1 Score", color="b"
462
  ) # Square marker for F1 Score
463
 
464
  # Add values on top of points
465
  for i in range(min_length):
466
+ print(f"{perf_df[variant].iloc[i]}: {perf_df['run'].iloc[i]}")
467
  ax.annotate(
468
  f"{perf_df['accuracy'].iloc[i]*100:.2f}%",
469
+ (perf_df[variant].iloc[i], perf_df["accuracy"].iloc[i]),
470
  ha="center",
471
  va="bottom", # Move accuracy numbers below the points
472
  xytext=(0, -15),
473
  textcoords="offset points",
474
  fontsize=10,
475
+ color="r",
476
  )
477
  ax.annotate(
478
  f"{perf_df['f1'].iloc[i]*100:.2f}%",
479
+ (perf_df[variant].iloc[i], perf_df["f1"].iloc[i]),
480
  ha="center",
481
  va="top", # Move F1 score numbers above the points
482
  xytext=(0, 15), # Offset by 15 points vertically
483
  textcoords="offset points",
484
  fontsize=10,
485
+ color="b",
486
  )
487
 
488
  # Set y-axis limit
489
+ ylimits = ax.get_ylim()
490
+ ax.set_ylim(ylimits[0] - offset, ylimits[1] + offset)
491
 
492
  # Add title and labels
493
+ ax.set_xlabel(
494
+ "Epoch (0: base model, 0.2 - 2: fine-tuned models)"
495
+ if variant == "epoch"
496
+ else "Number of Shots"
497
+ )
498
  ax.set_ylabel("Accuracy and F1 Score")
499
 
500
+ ax.xaxis.set_major_locator(MultipleLocator(0.2 if variant == "epoch" else 5))
 
501
  ax.set_title(f"Performance Analysis Across Checkpoints for the {model_name} Model")
502
 
503
  # Rotate x labels
 
542
  return response.content
543
 
544
 
545
+ def eval_openai(
546
+ eval_dataset,
547
+ model="gpt-4o-mini",
548
+ max_new_tokens=300,
549
+ num_shots=0,
550
+ train_dataset=None,
551
+ ):
552
+ user_prompt = (
553
+ get_prompt_template(using_p1=False, chinese_prompt=True)
554
+ if num_shots == 0
555
+ else get_few_shot_prompt_template(num_shots, train_dataset)
556
+ )
557
+ print("user_prompt:", user_prompt)
558
  total = len(eval_dataset)
559
  predictions = []
560
 
561
  for i in tqdm(range(total)):
562
+ output = reasoning_with_openai(
563
+ eval_dataset.iloc[i], user_prompt, model=model, max_tokens=max_new_tokens
564
+ )
565
  predictions.append(output)
566
 
567
  return predictions
568
+
569
+
570
+ def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix"):
571
+ font_family = rcParams["font.family"]
572
+ # Set the font to SimHei to support Chinese characters
573
+ rcParams["font.family"] = "STHeiti"
574
+ rcParams["axes.unicode_minus"] = (
575
+ False # This is to support the minus sign in Chinese.
576
+ )
577
+
578
+ labels = np.unique(y_true)
579
+
580
+ y_pred = [extract_answer(text) for text in y_pred]
581
+
582
+ cm = confusion_matrix(y_true, y_pred)
583
+ cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
584
+
585
+ fig, ax = plt.subplots(figsize=(8, 8))
586
+ sns.heatmap(
587
+ cm,
588
+ annot=True,
589
+ fmt=".4f",
590
+ cmap="Blues",
591
+ xticklabels=labels,
592
+ yticklabels=labels,
593
+ )
594
+ ax.set_title(title)
595
+ ax.set_xlabel("Predicted labels")
596
+ ax.set_ylabel("True labels")
597
+ plt.show()
598
+
599
+ rcParams["font.family"] = font_family
600
+
601
+
602
+ def majority_vote(r1, r2, r3):
603
+ label = r2
604
+ if r1 == r3:
605
+ label = r1
606
+
607
+ return label
notebooks/00_Data Analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/04_Few-shot_Prompting_OpenAI.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/04b_OpenAI-Models_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff