YiDuo1999 committed on
Commit
dc52a82
·
verified ·
1 Parent(s): d58e468

Update cases_collect.py

Browse files
Files changed (1) hide show
  1. cases_collect.py +2 -72
cases_collect.py CHANGED
@@ -21,12 +21,8 @@ def valid_results_collect(model_path,valid_data,task):
21
  torch.cuda.ipc_collect()
22
  # multiprocessing.set_start_method('spawn')
23
  trained_model=LLM(model=model_path,gpu_memory_utilization=0.95)
24
-
25
  start_t=time.time()
26
- if task=='sql':
27
- failed_cases,correct_cases=sql_evaluation(trained_model,valid_data)
28
- elif task=='nli':
29
- failed_cases,correct_cases=nli_evaluation(trained_model,valid_data)
30
  del trained_model
31
  end_t=time.time()
32
  print('time',start_t-end_t)
@@ -34,9 +30,6 @@ def valid_results_collect(model_path,valid_data,task):
34
  torch.cuda.empty_cache()
35
  torch.cuda.ipc_collect()
36
  torch.cuda.synchronize()
37
- #torch.cuda.synchronize()
38
- #torch.cuda.empty_cache()
39
- #torch.cuda.synchronize()
40
  time.sleep(10)
41
  return failed_cases,correct_cases
42
  def extract_answer_prediction_nli(predicted_output):
@@ -58,7 +51,6 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
58
  batch_prompts = [data['Input'] for data in data_batch]
59
  outputs = trained_model.generate(batch_prompts, sampling_params)
60
 
61
- results = []
62
  labels=['entailment','contradiction','neutral']
63
  for data, output in zip(data_batch, outputs):
64
  # pdb.set_trace()
@@ -70,9 +62,6 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
70
  # pdb.set_trace()
71
 
72
  predicted_res=predicted_output
73
- # print(label,predicted_output) # if 'contradiction #label_transform(data['Output'])
74
- # pdb.set_trace()
75
- # print(predicted_res,label,'\n')
76
  non_labels = [lbl for lbl in labels if lbl != label]
77
  if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
78
  failed_cases.append((data['Input'],predicted_res,label,data))
@@ -80,69 +69,10 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
80
  correct_cases.append((data['Input'],predicted_res,label,data))
81
  return failed_cases,correct_cases
82
  def nli_evaluation(trained_model,valid_data):
83
- id=0
84
  failed_cases=[]
85
  correct_cases=[]
86
  batch_size=500
87
  batched_data = [valid_data[i:i+batch_size] for i in range(0, len(valid_data), batch_size)]
88
  for batch in batched_data:
89
  failed_cases,correct_cases=process_batch(batch,trained_model,failed_cases,correct_cases)
90
-
91
- #for data in valid_data:
92
- # prompt=data['Input']
93
- # output=trained_model.generate(prompt, sampling_params)
94
- # predicted_output=output[0].outputs[0].text
95
- # predicted_res=extract_answer_prediction_nli(predicted_output) #$try:
96
- # # predicted_res=extract_answer(predicted_output.split('final')[-1].split('is')[1].split('.')[0])
97
- #except:
98
- # predicted_res=extract_answer(predicted_output.split('is')[-1])
99
- # label=extract_answer(data['Output'].split('is')[-1])
100
- # print(label,predicted_res)
101
- # if not predicted_res:
102
- # pdb.set_trace()
103
- # predicted_res=''
104
- # if 'contradiction #label_transform(data['Output'])
105
- # pdb.set_trace()
106
- # if label not in predicted_res:
107
- # failed_cases.append((id,prompt,predicted_res,label,data))
108
- # else:
109
- # correct_cases.append((id,prompt,predicted_res,label,data))
110
- # id+=1
111
- #id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res
112
- return failed_cases,correct_cases
113
- def sql_evaluation(trained_model,valid_data):
114
- id=0
115
- failed_cases=[]
116
- correct_cases=[]
117
- for triple in valid_data:
118
-
119
- db_id,prompt,ground_truth=triple
120
- prompt=prompt.replace('SELECT','')
121
- db_path='/dccstor/obsidian_llm/yiduo/AgentBench/DAMO-ConvAI/bird/data/train/train_databases/{0}/{0}.sqlite'.format(db_id)
122
- prompt+=' To generate the SQL query to' #print(db_path) #pdb.set_trace()
123
- conn = sqlite3.connect(db_path)
124
- output=trained_model.generate(prompt, sampling_params) #pdb.set_trace()
125
- predicted_sql = output[0].outputs[0].text
126
- #pdb.set_trace()
127
- prior_pred=predicted_sql.split('final SQL')[0]
128
- try:
129
- predicted_sql = predicted_sql.split('final SQL')[1].strip()
130
- except:
131
- predicted_sql = 'SELECT'+predicted_sql.split('SELECT')[1]
132
- predicted_sql=predicted_sql.split(';')[0]
133
- predicted_sql=predicted_sql[predicted_sql.find('SELECT'):] #[1:]
134
- cursor = conn.cursor()
135
- # pdb.set_trace()
136
- try:
137
- cursor.execute(predicted_sql)
138
- predicted_res = cursor.fetchall()
139
- cursor.execute(ground_truth)
140
- ground_truth_res = cursor.fetchall()
141
- #print('results',predicted_res,'truth',ground_truth_res,'\n')
142
- if set(predicted_res) != set(ground_truth_res):
143
- failed_cases.append((id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res))
144
- else:
145
- correct_cases.append((id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res))
146
- except Exception as e:
147
- failed_cases.append((id,prompt,predicted_sql,valid_data[id],ground_truth,str(Exception)+str(e)))
148
- return failed_cases,correct_cases
 
21
  torch.cuda.ipc_collect()
22
  # multiprocessing.set_start_method('spawn')
23
  trained_model=LLM(model=model_path,gpu_memory_utilization=0.95)
 
24
  start_t=time.time()
25
+ failed_cases,correct_cases=nli_evaluation(trained_model,valid_data)
 
 
 
26
  del trained_model
27
  end_t=time.time()
28
  print('time',start_t-end_t)
 
30
  torch.cuda.empty_cache()
31
  torch.cuda.ipc_collect()
32
  torch.cuda.synchronize()
 
 
 
33
  time.sleep(10)
34
  return failed_cases,correct_cases
35
  def extract_answer_prediction_nli(predicted_output):
 
51
  batch_prompts = [data['Input'] for data in data_batch]
52
  outputs = trained_model.generate(batch_prompts, sampling_params)
53
 
 
54
  labels=['entailment','contradiction','neutral']
55
  for data, output in zip(data_batch, outputs):
56
  # pdb.set_trace()
 
62
  # pdb.set_trace()
63
 
64
  predicted_res=predicted_output
 
 
 
65
  non_labels = [lbl for lbl in labels if lbl != label]
66
  if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
67
  failed_cases.append((data['Input'],predicted_res,label,data))
 
69
  correct_cases.append((data['Input'],predicted_res,label,data))
70
  return failed_cases,correct_cases
71
  def nli_evaluation(trained_model,valid_data):
 
72
  failed_cases=[]
73
  correct_cases=[]
74
  batch_size=500
75
  batched_data = [valid_data[i:i+batch_size] for i in range(0, len(valid_data), batch_size)]
76
  for batch in batched_data:
77
  failed_cases,correct_cases=process_batch(batch,trained_model,failed_cases,correct_cases)
78
+ return failed_cases,correct_cases