acmyu commited on
Commit
6035dc8
·
1 Parent(s): e0fbca6
Files changed (2) hide show
  1. evaluate.py +17 -17
  2. main.py +1 -1
evaluate.py CHANGED
@@ -131,7 +131,7 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
131
  gt_frames.append(img)
132
  else:
133
  gt_frames = extract_frames(video_path, fps)
134
- gt_frames = gt_frames[:200]
135
  for f in gt_frames:
136
  f.thumbnail((512,512))
137
 
@@ -168,7 +168,7 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
168
  psnr2.append(float(compute_psnr(gt, base)))
169
  lpips2.append(float(compute_lpips(gt, base)))
170
 
171
- if c<50:
172
  print(c)
173
  fid.append(float(compute_fid(gt, result)))
174
  fid2.append(float(compute_fid(gt, base)))
@@ -190,7 +190,7 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
190
  print("FID:", sum(fid2)/len(fid2))
191
  print("FVD:", fvd2)
192
 
193
- metrics[item] = {'ft': {}, 'base': {}, 'frames': len(gt_frames), 'complexity': len(images)}
194
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
195
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
196
  metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
@@ -238,20 +238,20 @@ def run_evaluate():
238
  continue
239
  print(item)
240
 
241
- #try:
242
- files = get_files('test/'+item)
243
- images = list(filter(lambda x: not x.endswith('.mp4'), files))
244
- images = ['test/'+item+'/'+img for img in images]
245
- videos = [x for x in files if x.endswith('.mp4')]
246
- print(images, videos)
247
-
248
- if len(videos) == 1:
249
- get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
250
- #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
251
- else:
252
- print('Error: mp4 not found')
253
- #except Exception as e:
254
- # print("Error", item, e)
255
 
256
 
257
  ssim = []
 
131
  gt_frames.append(img)
132
  else:
133
  gt_frames = extract_frames(video_path, fps)
134
+ gt_frames = gt_frames[:500]
135
  for f in gt_frames:
136
  f.thumbnail((512,512))
137
 
 
168
  psnr2.append(float(compute_psnr(gt, base)))
169
  lpips2.append(float(compute_lpips(gt, base)))
170
 
171
+ if c<100:
172
  print(c)
173
  fid.append(float(compute_fid(gt, result)))
174
  fid2.append(float(compute_fid(gt, base)))
 
190
  print("FID:", sum(fid2)/len(fid2))
191
  print("FVD:", fvd2)
192
 
193
+ metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
194
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
195
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
196
  metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
 
238
  continue
239
  print(item)
240
 
241
+ try:
242
+ files = get_files('test/'+item)
243
+ images = list(filter(lambda x: not x.endswith('.mp4'), files))
244
+ images = ['test/'+item+'/'+img for img in images]
245
+ videos = [x for x in files if x.endswith('.mp4')]
246
+ print(images, videos)
247
+
248
+ if len(videos) == 1:
249
+ get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
250
+ #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
251
+ else:
252
+ print('Error: mp4 not found')
253
+ except Exception as e:
254
+ print("Error", item, e)
255
 
256
 
257
  ssim = []
main.py CHANGED
@@ -85,7 +85,7 @@ debug = False
85
  save_model = True
86
  should_gen_vid = False
87
  max_batch_size = 8
88
- max_frame_count = 200
89
 
90
  def save_temp_imgs(imgs):
91
  os.makedirs('temp', exist_ok=True)
 
85
  save_model = True
86
  should_gen_vid = False
87
  max_batch_size = 8
88
+ max_frame_count = 500
89
 
90
  def save_temp_imgs(imgs):
91
  os.makedirs('temp', exist_ok=True)