wahaha commited on
Commit
d67e05e
1 Parent(s): c9dfbd7
Files changed (2) hide show
  1. app.py +9 -20
  2. test1.py +28 -26
app.py CHANGED
@@ -45,20 +45,14 @@ def parse_args() -> argparse.Namespace:
45
  return parser.parse_args()
46
 
47
 
48
-
49
-
50
  def run(
51
  image,
52
- shinkai: ImportGraph,
53
- hayao: ImportGraph,
54
- paprika: ImportGraph,
55
- ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
56
 
57
- im1 = shinkai.test('shinkai', image.name, True)
58
- #im2 = hayao.test('hayao', image.name, True)
59
- #im3 = paprika.test('paprika', image.name, True)
60
 
61
- return PIL.Image.open(im1),PIL.Image.open(im1),PIL.Image.open(im1)
62
 
63
 
64
  def main():
@@ -66,13 +60,13 @@ def main():
66
 
67
  args = parse_args()
68
 
69
- curPath = os.path.abspath(os.path.dirname(__file__))
70
  #init
71
- shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
72
  #hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
73
  #paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
74
 
75
- func = functools.partial(run, shinkai=shinkai ,hayao=None,paprika=None )
76
  func = functools.update_wrapper(func, run)
77
 
78
 
@@ -84,13 +78,8 @@ def main():
84
  [
85
  gr.outputs.Image(
86
  type='pil',
87
- label='Shinkai Result'),
88
- gr.outputs.Image(
89
- type='pil',
90
- label='Hayao Result'),
91
- gr.outputs.Image(
92
- type='pil',
93
- label='Paprika Result'),
94
  ],
95
  #examples=examples,
96
  theme=args.theme,
 
45
  return parser.parse_args()
46
 
47
 
 
 
48
  def run(
49
  image,
50
+ ) -> tuple[PIL.Image.Image]:
 
 
 
51
 
52
+ out = test.test(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'),
53
+ style_name='Shinkai', test_file=image.name, if_adjust_brightness=True)
 
54
 
55
+ return PIL.Image.open(out)
56
 
57
 
58
  def main():
 
60
 
61
  args = parse_args()
62
 
63
+ #curPath = os.path.abspath(os.path.dirname(__file__))
64
  #init
65
+ #shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
66
  #hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
67
  #paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
68
 
69
+ func = functools.partial(run)
70
  func = functools.update_wrapper(func, run)
71
 
72
 
 
78
  [
79
  gr.outputs.Image(
80
  type='pil',
81
+ label='Result'),
82
+
 
 
 
 
 
83
  ],
84
  #examples=examples,
85
  theme=args.theme,
test1.py CHANGED
@@ -53,23 +53,27 @@ def stats_graph(graph):
53
  # params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
54
  print('FLOPs: {}'.format(flops.total_float_ops))
55
 
56
- def test(checkpoint_dir, style_name, test_dir, if_adjust_brightness, img_size=[256,256]):
 
 
 
 
 
 
57
  # tf.reset_default_graph()
58
  result_dir = 'results/'+style_name
59
  check_folder(result_dir)
60
- test_files = [test_dir]
61
 
62
- test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
 
63
 
64
- with tf.variable_scope("generator", reuse=False):
65
- test_generated = generator.G_net(test_real).fake
66
- saver = tf.train.Saver()
67
 
68
- out_paths = []
 
69
 
70
- gpu_options = tf.GPUOptions(allow_growth=True)
71
- with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess:
72
- # tf.global_variables_initializer().run()
73
  # load model
74
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
75
  if ckpt and ckpt.model_checkpoint_path:
@@ -81,22 +85,20 @@ def test(checkpoint_dir, style_name, test_dir, if_adjust_brightness, img_size=[2
81
  return
82
  # stats_graph(tf.get_default_graph())
83
 
84
- begin = time.time()
85
- for sample_file in tqdm(test_files) :
86
- # print('Processing image: ' + sample_file)
87
- sample_image = np.asarray(load_test_data(sample_file, img_size))
88
- image_path = os.path.join(result_dir,'{0}'.format(os.path.basename(sample_file)))
89
- fake_img = sess.run(test_generated, feed_dict = {test_real : sample_image})
90
- if if_adjust_brightness:
91
- save_images(fake_img, image_path, sample_file)
92
- else:
93
- save_images(fake_img, image_path, None)
94
-
95
- out_paths.append(image_path)
96
- end = time.time()
97
- print(f'test-time: {end-begin} s')
98
-
99
- return out_paths
100
 
101
  if __name__ == '__main__':
102
  arg = parse_args()
 
53
  # params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
54
  print('FLOPs: {}'.format(flops.total_float_ops))
55
 
56
+ g_sess = None
57
+ test_generated = None
58
+
59
+ def test(checkpoint_dir, style_name, test_file, if_adjust_brightness, img_size=[256,256]):
60
+ global g_sess
61
+ global test_generated
62
+
63
  # tf.reset_default_graph()
64
  result_dir = 'results/'+style_name
65
  check_folder(result_dir)
 
66
 
67
+ if g_sess is None:
68
+ test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
69
 
70
+ with tf.variable_scope("generator", reuse=False):
71
+ test_generated = generator.G_net(test_real).fake
72
+ saver = tf.train.Saver()
73
 
74
+ gpu_options = tf.GPUOptions(allow_growth=True)
75
+ g_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
76
 
 
 
 
77
  # load model
78
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
79
  if ckpt and ckpt.model_checkpoint_path:
 
85
  return
86
  # stats_graph(tf.get_default_graph())
87
 
88
+ begin = time.time()
89
+ # print('Processing image: ' + sample_file)
90
+ sample_image = np.asarray(load_test_data(test_file, img_size))
91
+ image_path = os.path.join(result_dir,'{0}'.format(os.path.basename(test_file)))
92
+ fake_img = g_sess.run(test_generated, feed_dict = {test_real : sample_image})
93
+ if if_adjust_brightness:
94
+ save_images(fake_img, image_path, test_file)
95
+ else:
96
+ save_images(fake_img, image_path, None)
97
+
98
+ end = time.time()
99
+ print(f'test-time: {end-begin} s')
100
+
101
+ return image_path
 
 
102
 
103
  if __name__ == '__main__':
104
  arg = parse_args()