Sutirtha commited on
Commit
365f787
1 Parent(s): 6eacbe7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from color_matcher import ColorMatcher
3
+ from color_matcher.normalizer import Normalizer
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+
8
+ # Function to apply color correction
9
+ def color_match(source_img, reference_img):
10
+ # Convert PIL images to OpenCV format (numpy arrays)
11
+ img_src = np.array(source_img)
12
+ img_ref = np.array(reference_img)
13
+
14
+ # Ensure images are in RGB format (3 channels)
15
+ if img_src.shape[2] == 4:
16
+ img_src = cv2.cvtColor(img_src, cv2.COLOR_RGBA2RGB)
17
+ if img_ref.shape[2] == 4:
18
+ img_ref = cv2.cvtColor(img_ref, cv2.COLOR_RGBA2RGB)
19
+
20
+ # Apply color matching
21
+ cm = ColorMatcher()
22
+ img_res = cm.transfer(src=img_src, ref=img_ref, method='mkl')
23
+
24
+ # Normalize the result
25
+ img_res = Normalizer(img_res).uint8_norm()
26
+
27
+ # Convert back to PIL for displaying in Gradio
28
+ img_res_pil = Image.fromarray(img_res)
29
+
30
+ return img_res_pil
31
+
32
+ # Gradio Interface
33
+ def gradio_interface():
34
+ # Define input and output components
35
+ inputs = [
36
+ gr.Image(type="pil", label="Source Image"),
37
+ gr.Image(type="pil", label="Reference Image")
38
+ ]
39
+ outputs = gr.Image(type="pil", label="Resulting Image")
40
+
41
+ # Launch Gradio app
42
+ gr.Interface(fn=color_match, inputs=inputs, outputs=outputs, title="Color Matching Tool").launch()
43
+
44
+ # Run the Gradio Interface
45
+ if __name__ == "__main__":
46
+ gradio_interface()