dh-mc committed
Commit 5a8f8d2 · 1 Parent(s): 5dc41da

open source LLM results almost done

data/Llama3.1-70B-Chinese-Chat_metrics.csv ADDED
@@ -0,0 +1,12 @@
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat_torch.bfloat16_4bit_lf,0.7636666666666667,0.7806653325131986,0.7636666666666667,0.7525813484548423,0.009666666666666667
+0.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-35_torch.bfloat16_4bit_lf,0.778,0.8148707737020212,0.778,0.7910805488003003,0.9996666666666667
+0.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-70_torch.bfloat16_4bit_lf,0.7306666666666667,0.8145782271710159,0.7306666666666667,0.7624724104697406,1.0
+0.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-105_torch.bfloat16_4bit_lf,0.7193333333333334,0.8213567226911125,0.7193333333333334,0.7560702640626931,1.0
+0.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-140_torch.bfloat16_4bit_lf,0.7563333333333333,0.826789897753756,0.7563333333333333,0.7815164366677209,1.0
+1.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-175_torch.bfloat16_4bit_lf,0.7963333333333333,0.8248972880055918,0.7963333333333333,0.8076868978089201,1.0
+1.2,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-210_torch.bfloat16_4bit_lf,0.7326666666666667,0.8265345821998035,0.7326666666666667,0.7644418492070342,1.0
+1.4,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-245_torch.bfloat16_4bit_lf,0.7556666666666667,0.8258994609525315,0.7556666666666667,0.7820405339757727,1.0
+1.6,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-280_torch.bfloat16_4bit_lf,0.757,0.8264461657684251,0.757,0.7834496144681513,1.0
+1.8,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-315_torch.bfloat16_4bit_lf,0.7546666666666667,0.8277723752096544,0.7546666666666667,0.7823584779069335,1.0
+2.0,shenzhi-wang/Llama3.1-70B-Chinese-Chat/checkpoint-350_torch.bfloat16_4bit_lf,0.7496666666666667,0.8282310230333227,0.7496666666666667,0.7791947625361637,1.0
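
Each metrics CSV in this commit shares the same schema: one row per evaluated checkpoint, where epoch is the fine-tuning epoch (0.0 is the untuned base model), accuracy/precision/recall/f1 are weighted classification scores, and the new ratio_valid_classifications column is the fraction of raw model outputs that already match a valid label before extract_answer cleans them up — under 1% for the untuned 70B base above, versus at or near 1.0 for every fine-tuned checkpoint. A minimal sketch for reading and plotting one of these files, assuming pandas and matplotlib are installed (the path is taken from this commit):

# Minimal sketch: plot accuracy and F1 against fine-tuning epoch for one model.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("data/Llama3.1-70B-Chinese-Chat_metrics.csv")

plt.figure(figsize=(8, 4))
plt.plot(df["epoch"], df["accuracy"], marker="o", label="accuracy")
plt.plot(df["epoch"], df["f1"], marker="o", label="f1")
plt.xlabel("epoch")
plt.ylabel("score")
plt.legend()
plt.show()
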
data/Llama3.1-8B-Chinese-Chat_metrics.csv CHANGED
@@ -1,7 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.23666666666666666,0.7457179631400438,0.23666666666666666,0.33962354850065374
-0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.6256666666666667,0.827414387212707,0.6256666666666667,0.6935695138877099
-0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.762,0.7899461556934093,0.762,0.7667008346960339
-0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6803333333333333,0.79802978899557,0.6803333333333333,0.7212437740051865
-0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7523333333333333,0.8074258170836324,0.7523333333333333,0.7736442997308933
-1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.737,0.8090588922502886,0.737,0.7637837184140026
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat_torch.float16_lf,0.707,0.7631091217915184,0.707,0.7243940517731183,0.3923333333333333
+0.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-35_torch.float16_lf,0.709,0.7987219597893886,0.709,0.7427961200958145,1.0
+0.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-70_torch.float16_lf,0.7163333333333334,0.8058657875960304,0.7163333333333334,0.7487811196109319,0.9993333333333333
+0.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6996666666666667,0.802722482275839,0.6996666666666667,0.7370938556711591,1.0
+0.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7716666666666666,0.8092193821623755,0.7716666666666666,0.7864287269398251,1.0
+1.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-175_torch.float16_lf,0.78,0.810582723471486,0.78,0.7924651054056209,1.0
+1.2,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7313333333333333,0.8157783263996798,0.7313333333333333,0.7628807622782868,1.0
+1.4,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-245_torch.float16_lf,0.751,0.8125856808988221,0.751,0.7745416635653988,1.0
+1.6,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-280_torch.float16_lf,0.739,0.8097375095673094,0.739,0.7662329023371559,1.0
+1.8,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-315_torch.float16_lf,0.7236666666666667,0.8145530585912838,0.7236666666666667,0.7580428816095297,1.0
+2.0,shenzhi-wang/Llama3.1-8B-Chinese-Chat/checkpoint-350_torch.float16_lf,0.7293333333333333,0.8151184301713545,0.7293333333333333,0.7616699266814145,1.0
data/Mistral-7B-v0.3-Chinese-Chat_metrics.csv ADDED
@@ -0,0 +1,12 @@
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torch.float16_lf,0.7113333333333334,0.70220546362905,0.7113333333333334,0.6894974942637364,0.004
+0.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-35_torch.float16_lf,0.702,0.7932731014186957,0.702,0.7342714734731689,1.0
+0.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-70_torch.float16_lf,0.742,0.78982949223512,0.742,0.7536681109811127,1.0
+0.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-105_torch.float16_lf,0.6596666666666666,0.7923396753604393,0.6596666666666666,0.7067542301676931,1.0
+0.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-140_torch.float16_lf,0.7146666666666667,0.7861341885687435,0.7146666666666667,0.7404677278137267,1.0
+1.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-175_torch.float16_lf,0.7326666666666667,0.7876867721932461,0.7326666666666667,0.7471869515031995,1.0
+1.2,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-210_torch.float16_lf,0.7016666666666667,0.7903119228393193,0.7016666666666667,0.7348708822385348,1.0
+1.4,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-245_torch.float16_lf,0.75,0.7885868317699068,0.75,0.7648234347578796,1.0
+1.6,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-280_torch.float16_lf,0.7156666666666667,0.7846106674095725,0.7156666666666667,0.7410042005708856,1.0
+1.8,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-315_torch.float16_lf,0.6916666666666667,0.7864256994491394,0.6916666666666667,0.7257499426487266,1.0
+2.0,shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/checkpoint-350_torch.float16_lf,0.6976666666666667,0.7889443494370009,0.6976666666666667,0.7307996137659796,1.0
data/Qwen2-72B-Instruct_metrics.csv CHANGED
@@ -1,8 +1,8 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7473333333333333,0.804122252986722,0.7473333333333333,0.7607828719113865
-0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442
-0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021
-0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628
-0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173
-1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548
-1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,Qwen/Qwen2-72B-Instruct_torch.bfloat16_4bit_lf,0.7473333333333333,0.804122252986722,0.7473333333333333,0.7607828719113865,0.9773333333333334
+0.2,Qwen/Qwen2-72B-Instruct/checkpoint-35_torch.bfloat16_4bit_lf,0.7583333333333333,0.8199928526815756,0.7583333333333333,0.782751089787442,1.0
+0.4,Qwen/Qwen2-72B-Instruct/checkpoint-70_torch.bfloat16_4bit_lf,0.7366666666666667,0.8224865755517643,0.7366666666666667,0.7700627366337021,1.0
+0.6,Qwen/Qwen2-72B-Instruct/checkpoint-105_torch.bfloat16_4bit_lf,0.757,0.8253824826209251,0.757,0.784000409833628,1.0
+0.8,Qwen/Qwen2-72B-Instruct/checkpoint-140_torch.bfloat16_4bit_lf,0.7893333333333333,0.8229104753645825,0.7893333333333333,0.8033124955993173,1.0
+1.0,Qwen/Qwen2-72B-Instruct/checkpoint-175_torch.bfloat16_4bit_lf,0.7376666666666667,0.8243654864769323,0.7376666666666667,0.7699617360961548,1.0
+1.2,Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.bfloat16_4bit_lf,0.763,0.8318882808702871,0.763,0.7901075708186186,1.0
data/Qwen2-7B-Instruct_metrics.csv CHANGED
@@ -1,12 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6193333333333333,0.7555701755118281,0.6193333333333333,0.6726302447185493
-0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058
-0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183
-0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848
-0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298
-1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772
-1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508
-1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717
-1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867
-1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346
-2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,Qwen/Qwen2-7B-Instruct_torch.float16_lf,0.6203333333333333,0.7554720257311661,0.6203333333333333,0.6731632664545455,0.9973333333333333
+0.2,Qwen/Qwen2-7B-Instruct/checkpoint-35_torch.float16_lf,0.725,0.7840171468707405,0.725,0.748994536667058,0.9996666666666667
+0.4,Qwen/Qwen2-7B-Instruct/checkpoint-70_torch.float16_lf,0.759,0.8005303465799652,0.759,0.7748745026535183,1.0
+0.6,Qwen/Qwen2-7B-Instruct/checkpoint-105_torch.float16_lf,0.6926666666666667,0.8039176975550218,0.6926666666666667,0.7332481528585848,1.0
+0.8,Qwen/Qwen2-7B-Instruct/checkpoint-140_torch.float16_lf,0.725,0.7952719247171957,0.725,0.7476238017654298,1.0
+1.0,Qwen/Qwen2-7B-Instruct/checkpoint-175_torch.float16_lf,0.6756666666666666,0.7810148934939715,0.6756666666666666,0.708653993277772,1.0
+1.2,Qwen/Qwen2-7B-Instruct/checkpoint-210_torch.float16_lf,0.7013333333333334,0.7969562600853992,0.7013333333333334,0.7362679665494508,1.0
+1.4,Qwen/Qwen2-7B-Instruct/checkpoint-245_torch.float16_lf,0.7326666666666667,0.7922538479314682,0.7326666666666667,0.755402136631717,0.9996666666666667
+1.6,Qwen/Qwen2-7B-Instruct/checkpoint-280_torch.float16_lf,0.6983333333333334,0.785127298428753,0.6983333333333334,0.7292251109166867,1.0
+1.8,Qwen/Qwen2-7B-Instruct/checkpoint-315_torch.float16_lf,0.6783333333333333,0.785390767631834,0.6783333333333333,0.7164131321837346,1.0
+2.0,Qwen/Qwen2-7B-Instruct/checkpoint-350_torch.float16_lf,0.689,0.7929715746898984,0.689,0.7259993126510194,1.0
data/internlm2_5-7b-chat-1m_metrics.csv CHANGED
@@ -1,12 +1,12 @@
-epoch,model,accuracy,precision,recall,f1
-0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308
-0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659
-0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081
-0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912
-0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301
-1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813
-1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454
-1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925
-1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297
-1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903
-2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184
+epoch,model,accuracy,precision,recall,f1,ratio_valid_classifications
+0.0,internlm/internlm2_5-7b-chat-1m_torch.bfloat16_lf,0.5106666666666667,0.743213901498142,0.5106666666666667,0.5357333853323308,1.0
+0.2,internlm/internlm2_5-7b-chat-1m/checkpoint-35_torch.bfloat16_lf,0.7843333333333333,0.7977648302848388,0.7843333333333333,0.7864944570659659,1.0
+0.4,internlm/internlm2_5-7b-chat-1m/checkpoint-70_torch.bfloat16_lf,0.7836666666666666,0.7996977262947886,0.7836666666666666,0.7886881726841081,1.0
+0.6,internlm/internlm2_5-7b-chat-1m/checkpoint-105_torch.bfloat16_lf,0.7243333333333334,0.8171172705912051,0.7243333333333334,0.7565804830382912,1.0
+0.8,internlm/internlm2_5-7b-chat-1m/checkpoint-140_torch.bfloat16_lf,0.803,0.8031411888150441,0.803,0.8028064320197301,1.0
+1.0,internlm/internlm2_5-7b-chat-1m/checkpoint-175_torch.bfloat16_lf,0.7676666666666667,0.8108441731715863,0.7676666666666667,0.7843187816704813,1.0
+1.2,internlm/internlm2_5-7b-chat-1m/checkpoint-210_torch.bfloat16_lf,0.7736666666666666,0.8091671780923799,0.7736666666666666,0.7876874850235454,1.0
+1.4,internlm/internlm2_5-7b-chat-1m/checkpoint-245_torch.bfloat16_lf,0.7623333333333333,0.8062291602218205,0.7623333333333333,0.777669094563925,1.0
+1.6,internlm/internlm2_5-7b-chat-1m/checkpoint-280_torch.bfloat16_lf,0.7553333333333333,0.8086197936829652,0.7553333333333333,0.7755588811428297,1.0
+1.8,internlm/internlm2_5-7b-chat-1m/checkpoint-315_torch.bfloat16_lf,0.748,0.8171996792797457,0.748,0.773990849396903,1.0
+2.0,internlm/internlm2_5-7b-chat-1m/checkpoint-350_torch.bfloat16_lf,0.756,0.8126875394266148,0.756,0.7777812522863184,1.0
llm_toolkit/logical_reasoning_utils.py CHANGED
@@ -95,7 +95,7 @@ def get_prompt_template(using_p1=True, chinese_prompt=True):
 
 
 def extract_answer(text, debug=False):
-    if text:
+    if text and isinstance(text, str):
         # Remove the begin and end tokens
         text = re.sub(
             r".*?(assistant|\[/INST\]).+?\b",
@@ -117,6 +117,7 @@ def extract_answer(text, debug=False):
             print("--------\nstep 3:", text)
 
         text = text.split(".")[0].strip()
+        text = text.split("。")[0].strip()
         if debug:
             print("--------\nstep 4:", text)
 
@@ -129,7 +130,9 @@
         if debug:
             print("--------\nstep 5:", text)
 
-    return text
+        return text.strip()
+
+    return ""
 
 
 def calc_metrics(references, predictions, debug=False):
@@ -137,16 +140,33 @@
         predictions
     ), f"lengths are difference: {len(references)} != {len(predictions)}"
 
+    labels = np.unique(references)
+    valid_classifications = [1 if p in labels else 0 for p in predictions]
+
     predictions = [extract_answer(text) for text in predictions]
 
-    correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
-    accuracy = sum(correct) / len(references)
+    accuracy = accuracy_score(references, predictions)
 
     results = {"accuracy": accuracy}
     if debug:
-        incorrect_ids = [i for i, c in enumerate(correct) if c == 0]
+        incorrect_ids = [i for i, p in enumerate(predictions) if p != references[i]]
        results["incorrect_ids"] = incorrect_ids
 
+    precision = precision_score(
+        references, predictions, average="weighted", labels=labels
+    )
+    results["precision"] = float(precision)
+
+    recall = recall_score(references, predictions, average="weighted", labels=labels)
+    results["recall"] = float(recall)
+
+    f1 = f1_score(references, predictions, average="weighted", labels=labels)
+    results["f1"] = float(f1)
+
+    results["ratio_valid_classifications"] = sum(valid_classifications) / len(
+        valid_classifications
+    )
+
     return results
 
 
@@ -240,7 +260,7 @@ def get_metrics(df):
     rouge_l = []
     all_metrics = []
     for col in df.columns[2:]:
-        metrics = calc_metrics(df["english"], df[col], debug=True)
+        metrics = calc_metrics(df["label"], df[col], debug=True)
         print(f"{col}: {metrics}")
 
         accuracy.append(metrics["accuracy"])
@@ -290,38 +310,37 @@ def load_alpaca_data(data_path, using_p1=True, use_english_datasets=False):
     return df_alpaca
 
 
-def plot_value_counts(df, column, title=None):
+def plot_value_counts(df, column_name, offset=0.1, title=None, preprocess_func=None):
     font_family = rcParams["font.family"]
     # Set the font to SimHei to support Chinese characters
     rcParams["font.family"] = "STHeiti"
     rcParams["axes.unicode_minus"] = (
         False  # This is to support the minus sign in Chinese.
     )
+    if preprocess_func:
+        df["backup"] = df[column_name]
+        df[column_name] = df[column_name].apply(preprocess_func)
 
     plt.figure(figsize=(12, 6))
-    df[column].value_counts().plot(kind="bar")
+    df[column_name].value_counts().plot(kind="bar")
     # add values on top of bars
-    for i, v in enumerate(df[column].value_counts()):
-        plt.text(i, v + 0.1, str(v), ha="center")
+    for i, v in enumerate(df[column_name].value_counts()):
+        plt.text(i, v + offset, str(v), ha="center")
 
-    plt.xlabel(title or column)
+    plt.xlabel(title or column_name)
 
     plt.show()
 
     rcParams["font.family"] = font_family
 
+    if preprocess_func:
+        df[column_name] = df["backup"]
+        df.drop(columns=["backup"], inplace=True)
 
-def calc_metrics_for_col(df, col):
-    y_true = df["label"]
-    y_pred = df[col]
-    labels = np.unique(y_true)
-
-    accuracy = accuracy_score(y_true, y_pred)
-    precision = precision_score(y_true, y_pred, average="weighted", labels=labels)
-    recall = recall_score(y_true, y_pred, average="weighted", labels=labels)
-    f1 = f1_score(y_true, y_pred, average="weighted", labels=labels)
-
-    return accuracy, float(precision), float(recall), float(f1)
+
+def calc_metrics_for_col(df, col):
+    metrics = calc_metrics(df["label"], df[col], debug=True)
+    return metrics["accuracy"], metrics["precision"], metrics["recall"], metrics["f1"]
 
 
 def get_metrics_df(df):
@@ -329,15 +348,12 @@ def get_metrics_df(df):
         columns=["epoch", "model", "accuracy", "precision", "recall", "f1"]
     )
     for i, col in enumerate(df.columns[5:]):
-        accuracy, precision, recall, f1 = calc_metrics_for_col(df, col)
+        metrics = calc_metrics(df["label"], df[col], debug=False)
         new_model_metrics = {
            "epoch": i / 5,
            "model": col,
-            "accuracy": accuracy,
-            "precision": precision,
-            "recall": recall,
-            "f1": f1,
        }
+        new_model_metrics.update(metrics)
 
         # Convert the dictionary to a DataFrame and concatenate it with the existing DataFrame
         perf_df = pd.concat(
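
With these changes, calc_metrics subsumes the old calc_metrics_for_col: it returns accuracy, weighted precision/recall/f1, and the new ratio_valid_classifications, the latter measured on the raw predictions before extract_answer normalizes them. A toy sketch of the updated behavior (the labels below are illustrative, not taken from the actual dataset; it assumes this repo's llm_toolkit package is importable):

# Illustrative only: made-up reference labels and model outputs.
from llm_toolkit.logical_reasoning_utils import calc_metrics

references = ["是", "否", "不重要", "是"]
predictions = ["是", "否", "我不知道", "是"]  # one raw output is not a valid label

results = calc_metrics(references, predictions, debug=True)
# ratio_valid_classifications is 0.75 here: 3 of the 4 raw predictions appear
# in np.unique(references). debug=True also records the indices of mismatched
# (post-extraction) predictions under results["incorrect_ids"].
print(results)

The preprocess_func hook added to plot_value_counts supports the same analysis from the notebooks: it can bar-chart the value counts of extract_answer-normalized outputs, with the temporary "backup" column restoring the original data after plotting.
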
notebooks/00_Data Analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/01_internlm2_5-7b-chat-1m_analysis.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/01a_internlm2_5-7b-chat-1m_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/01b_Mistral-7B-v0.3-Chinese-Chat_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/02a_Qwen2-7B-Instruct_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/02b_Qwen2-72B-Instruct_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/03a_Llama3.1-8B-Chinese-Chat_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/03b_Llama3.1-70B-Chinese-Chat_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff