wahaha commited on
Commit
14baf72
1 Parent(s): f3cf058
Files changed (2) hide show
  1. app.py +6 -6
  2. test1.py +8 -7
app.py CHANGED
@@ -55,10 +55,10 @@ def run(
55
  ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
56
 
57
  im1 = shinkai.test('shinkai', image.name, True)
58
- im2 = hayao.test('hayao', image.name, True)
59
- im3 = paprika.test('paprika', image.name, True)
60
 
61
- return PIL.Image.open(im1),PIL.Image.open(im2),PIL.Image.open(im3)
62
 
63
 
64
  def main():
@@ -69,10 +69,10 @@ def main():
69
  curPath = os.path.abspath(os.path.dirname(__file__))
70
  #init
71
  shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
72
- hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
73
- paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
74
 
75
- func = functools.partial(run, shinkai=shinkai,hayao=hayao,paprika=paprika )
76
  func = functools.update_wrapper(func, run)
77
 
78
 
 
55
  ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
56
 
57
  im1 = shinkai.test('shinkai', image.name, True)
58
+ #im2 = hayao.test('hayao', image.name, True)
59
+ #im3 = paprika.test('paprika', image.name, True)
60
 
61
+ return PIL.Image.open(im1),PIL.Image.open(im1),PIL.Image.open(im1)
62
 
63
 
64
  def main():
 
69
  curPath = os.path.abspath(os.path.dirname(__file__))
70
  #init
71
  shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
72
+ #hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
73
+ #paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
74
 
75
+ func = functools.partial(run, shinkai=shinkai )
76
  func = functools.update_wrapper(func, run)
77
 
78
 
test1.py CHANGED
@@ -13,12 +13,17 @@ class ImportGraph:
13
  self.graph = tf.Graph()
14
  self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True)))
15
  with self.graph.as_default():
16
- saver = tf.train.Saver()
 
 
 
 
 
17
 
18
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
19
  if ckpt and ckpt.model_checkpoint_path:
20
  ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
21
- saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
22
  print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))
23
  else:
24
  print(" [*] Failed to find a checkpoint")
@@ -30,11 +35,7 @@ class ImportGraph:
30
  sample_image = np.asarray(load_test_data(sample_file, img_size))
31
  image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))
32
 
33
- test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
34
- with tf.variable_scope("generator", reuse=False):
35
- test_generated = generator.G_net(test_real).fake
36
-
37
- fake_img = self.sess.run(test_generated, feed_dict={test_real: sample_image})
38
  if if_adjust_brightness:
39
  save_images(fake_img, image_path, sample_file)
40
  else:
 
13
  self.graph = tf.Graph()
14
  self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True)))
15
  with self.graph.as_default():
16
+
17
+ test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
18
+ with tf.variable_scope("generator", reuse=False):
19
+ self.test_generated = generator.G_net(test_real).fake
20
+
21
+ self.saver = tf.train.Saver()
22
 
23
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
24
  if ckpt and ckpt.model_checkpoint_path:
25
  ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
26
+ self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
27
  print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))
28
  else:
29
  print(" [*] Failed to find a checkpoint")
 
35
  sample_image = np.asarray(load_test_data(sample_file, img_size))
36
  image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))
37
 
38
+ fake_img = self.sess.run(self.test_generated, feed_dict={test_real: sample_image})
 
 
 
 
39
  if if_adjust_brightness:
40
  save_images(fake_img, image_path, sample_file)
41
  else: