abhishekrs4 commited on
Commit
a769b4d
1 Parent(s): fc91a53

added fastapi app and config scripts

Browse files
Files changed (2) hide show
  1. app.py +147 -0
  2. config.py +11 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import torch
4
+ import base64
5
+ import logging
6
+ import numpy as np
7
+ from fastapi import FastAPI, File, UploadFile, Form
8
+
9
+ from config import settings
10
+
11
+ from image_colourization_cgan.image_utils import *
12
+ from image_colourization_cgan.model import ImageToImageConditionalGAN
13
+
14
+
15
+ def activate_dropout(m):
16
+ if type(m) == torch.nn.Dropout:
17
+ m.train()
18
+ return
19
+
20
+
21
+ app = FastAPI()
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+ device = settings.device
25
+ image_size = settings.image_size
26
+
27
+ file_model_local = f"./artifacts/colorizer_cgan_90.pt"
28
+ file_model_cont = f"/data/models/colorizer_cgan_90.pt"
29
+
30
+ colour_model_cgan = ImageToImageConditionalGAN(device)
31
+ colour_model_cgan.eval()
32
+
33
+ try:
34
+ logging.info(f"loading model from {file_model_local}")
35
+ colour_model_cgan.load_state_dict(torch.load(file_model_local, map_location=device))
36
+ except:
37
+ logging.info(f"loading model from {file_model_cont}")
38
+ colour_model_cgan.load_state_dict(torch.load(file_model_cont, map_location=device))
39
+ colour_model_cgan.to(device)
40
+ colour_model_cgan.net_gen.apply(activate_dropout)
41
+
42
+
43
+ def get_prediction(img_arr: np.ndarray) -> np.ndarray:
44
+ """
45
+ ---------
46
+ Arguments
47
+ ---------
48
+ img_arr: ndarray
49
+ a numpy array of the image
50
+
51
+ -------
52
+ Returns
53
+ -------
54
+ img_gen_rgb : ndarray
55
+ a numpy representing the generated colourized image
56
+ """
57
+
58
+ img_gray_resized = resize_image(img_arr, (image_size, image_size))
59
+ # resized grayscale is in [0, 1]
60
+
61
+ img_l = rescale_grayscale_image_l_channel(img_gray_resized)
62
+ # L channel is in [0, 100]
63
+
64
+ # apply pre-processing on L channel image
65
+ img_l_preprocessed = apply_image_l_pre_processing(img_l)
66
+
67
+ # repeat L channel 3 times because ResNet needs a 3 channel input
68
+ img_l_preprocessed = np.repeat(
69
+ np.expand_dims(img_l_preprocessed, axis=-1), 3, axis=-1
70
+ )
71
+ img_l_preprocessed = np.expand_dims(img_l_preprocessed, axis=0)
72
+
73
+ # NCHW format
74
+ img_l_preprocessed = np.transpose(img_l_preprocessed, (0, 3, 1, 2))
75
+
76
+ img_l_tensor = torch.tensor(img_l_preprocessed).float()
77
+ img_l_tensor = img_l_tensor.to(device, dtype=torch.float)
78
+
79
+ gen_img_ab_tensor = colour_model_cgan.net_gen(img_l_tensor)
80
+ gen_img_ab = gen_img_ab_tensor.detach().cpu().numpy()
81
+ gen_img_ab = np.squeeze(gen_img_ab)
82
+ gen_img_ab = np.transpose(gen_img_ab, [1, 2, 0])
83
+ gen_img_ab_postprocessed = apply_image_ab_post_processing(gen_img_ab)
84
+
85
+ # concat L and Generator network generated ab channels
86
+ gen_img_lab = np.concatenate(
87
+ (np.expand_dims(img_l, axis=-1), gen_img_ab_postprocessed), axis=-1
88
+ )
89
+ # convert Lab to RGB
90
+ img_gen_rgb = convert_lab2rgb(gen_img_lab)
91
+ img_gen_rgb = img_gen_rgb * 255
92
+ img_gen_rgb = img_gen_rgb.astype(np.uint8)
93
+
94
+ return img_gen_rgb
95
+
96
+
97
+ @app.get("/info")
98
+ def get_app_info() -> dict:
99
+ """
100
+ -------
101
+ Returns
102
+ -------
103
+ dict_info : dict
104
+ a dictionary with info to be sent as a response to get request
105
+ """
106
+ dict_info = {"app_name": settings.app_name, "version": settings.version}
107
+ return dict_info
108
+
109
+
110
+ @app.post("/predict")
111
+ def _file_upload(image_file: UploadFile = File(...)) -> dict:
112
+ """
113
+ ---------
114
+ Arguments
115
+ ---------
116
+ image_file: object
117
+ an object of type UploadFile
118
+
119
+ -------
120
+ Returns
121
+ -------
122
+ response_json : dict
123
+ a dict as a response json for the post request
124
+ """
125
+ try:
126
+ # if the file is sent via post request with open()
127
+ img_str = image_file.file.read()
128
+ img_decoded = cv2.imdecode(np.frombuffer(img_str, np.uint8), 0)
129
+ except:
130
+ # if the file is sent via post request from streamlit
131
+ img_decoded = cv2.imdecode(np.frombuffer(image_file.getvalue(), np.uint8), 0)
132
+
133
+ logging.info(image_file)
134
+
135
+ img_gen_rgb = get_prediction(img_decoded)
136
+ image_sum = np.sum(img_gen_rgb)
137
+ logging.info(f"image_sum: {image_sum}")
138
+ _, img_encoded = cv2.imencode(".PNG", img_gen_rgb)
139
+ img_encoded = base64.b64encode(img_encoded)
140
+
141
+ response_json = {
142
+ "name": image_file.filename,
143
+ "image_sum": str(image_sum),
144
+ "encoded_image": img_encoded,
145
+ }
146
+ # logging.info(response_json)
147
+ return response_json
config.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+
3
+
4
+ class Settings(BaseSettings):
5
+ app_name: str = "CGAN Image Colourization API"
6
+ version: str = "2024.04.15"
7
+ image_size: int = 320
8
+ device: str = "cpu"
9
+
10
+
11
+ settings = Settings()