soiz1 commited on
Commit
fb02422
·
verified ·
1 Parent(s): 8873cff

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +61 -43
main.py CHANGED
@@ -1,58 +1,76 @@
1
- import argparse
2
-
3
- from PIL import Image
4
  import cv2
5
  import numpy as np
6
- from preprocess_image import preprocess_image
7
- import tensorflow as tf
8
  import neuralgym as ng
9
 
 
10
  from inpaint_model import InpaintCAModel
11
 
12
- parser = argparse.ArgumentParser()
13
- parser.add_argument('--image', default='', type=str,
14
- help='The filename of image to be completed.')
15
- parser.add_argument('--output', default='output.png', type=str,
16
- help='Where to write output.')
17
- parser.add_argument('--watermark_type', default='istock', type=str,
18
- help='The watermark type')
19
- parser.add_argument('--checkpoint_dir', default='model/', type=str,
20
- help='The directory of tensorflow checkpoint.')
21
 
22
- #checkpoint_dir = 'model/'
 
 
 
23
 
 
 
 
 
24
 
25
- if __name__ == "__main__":
26
  FLAGS = ng.Config('inpaint.yml')
27
- # ng.get_gpus(1)
28
- args, unknown = parser.parse_known_args()
29
 
30
- model = InpaintCAModel()
31
- image = Image.open(args.image)
32
- input_image = preprocess_image(image, args.watermark_type)
33
  tf.reset_default_graph()
 
34
 
 
35
  sess_config = tf.ConfigProto()
36
  sess_config.gpu_options.allow_growth = True
37
- if (input_image.shape != (0,)):
38
- with tf.Session(config=sess_config) as sess:
39
- input_image = tf.constant(input_image, dtype=tf.float32)
40
- output = model.build_server_graph(FLAGS, input_image)
41
- output = (output + 1.) * 127.5
42
- output = tf.reverse(output, [-1])
43
- output = tf.saturate_cast(output, tf.uint8)
44
- # load pretrained model
45
- vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
46
- assign_ops = []
47
- for var in vars_list:
48
- vname = var.name
49
- from_name = vname
50
- var_value = tf.contrib.framework.load_variable(
51
- args.checkpoint_dir, from_name)
52
- assign_ops.append(tf.assign(var, var_value))
53
- sess.run(assign_ops)
54
- print('Model loaded.')
55
- result = sess.run(output)
56
- cv2.imwrite(args.output, cv2.cvtColor(
57
- result[0][:, :, ::-1], cv2.COLOR_BGR2RGB))
58
- print('image saved to {}'.format(args.output))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
 
3
  import cv2
4
  import numpy as np
5
+ from PIL import Image
 
6
  import neuralgym as ng
7
 
8
+ from preprocess_image import preprocess_image
9
  from inpaint_model import InpaintCAModel
10
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # ===== Inpainting function ===== / ===== 画像修復処理関数 =====
13
+ def inpaint_image(input_image, watermark_type, checkpoint_dir):
14
+ # Convert image from Gradio (PIL format) / Gradioから受け取る画像をPIL形式で処理
15
+ image = input_image.convert("RGB")
16
 
17
+ # Preprocessing / 前処理
18
+ input_image = preprocess_image(image, watermark_type)
19
+ if input_image.shape == (0,):
20
+ return None
21
 
22
+ # Load configuration file / 設定ファイルの読み込み
23
  FLAGS = ng.Config('inpaint.yml')
 
 
24
 
25
+ # Reset TensorFlow graph / TensorFlowグラフをリセット
 
 
26
  tf.reset_default_graph()
27
+ model = InpaintCAModel()
28
 
29
+ # GPU configuration / GPU設定
30
  sess_config = tf.ConfigProto()
31
  sess_config.gpu_options.allow_growth = True
32
+
33
+ # Start TensorFlow session / TensorFlowセッション開始
34
+ with tf.Session(config=sess_config) as sess:
35
+ # Create tensor from image / 画像をテンソルに変換
36
+ input_image_tensor = tf.constant(input_image, dtype=tf.float32)
37
+
38
+ # Build the model graph / モデルグラフを構築
39
+ output = model.build_server_graph(FLAGS, input_image_tensor)
40
+ output = (output + 1.) * 127.5
41
+ output = tf.reverse(output, [-1])
42
+ output = tf.saturate_cast(output, tf.uint8)
43
+
44
+ # Load model variables / モデル変数を読み込み
45
+ vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
46
+ assign_ops = []
47
+ for var in vars_list:
48
+ from_name = var.name
49
+ var_value = tf.contrib.framework.load_variable(checkpoint_dir, from_name)
50
+ assign_ops.append(tf.assign(var, var_value))
51
+ sess.run(assign_ops)
52
+
53
+ print('Model loaded.') # モデルの読み込み完了
54
+ result = sess.run(output)
55
+ result_img = result[0][:, :, ::-1] # Convert BGR to RGB / BGRからRGBに変換
56
+
57
+ # Convert numpy array to PIL image / numpy配列をPIL画像に変換
58
+ return Image.fromarray(result_img)
59
+
60
+
61
+ # ===== Gradio User Interface ===== / ===== Gradioユーザーインターフェース =====
62
+ iface = gr.Interface(
63
+ fn=inpaint_image,
64
+ inputs=[
65
+ gr.Image(label="Input Image / 入力画像", type="pil"),
66
+ gr.Radio(["istock", "other"], label="Watermark Type / ウォーターマークタイプ", value="istock"),
67
+ gr.Textbox(label="Checkpoint Directory / チェックポイントディレクトリ", value="model/")
68
+ ],
69
+ outputs=gr.Image(label="Inpainted Image / 修復済み画像"),
70
+ title="Watermark Inpainting Model / ウォーターマーク除去モデル",
71
+ description="Upload an image to remove the watermark using a TensorFlow model. / TensorFlowモデルを使用してウォーターマークを除去します。",
72
+ )
73
+
74
+ # Run the app / アプリを起動
75
+ if __name__ == "__main__":
76
+ iface.launch()