Adeboye Akinlolu committed on
Commit
2e3528d
·
1 Parent(s): 317bfc7
Files changed (8) hide show
  1. .gitignore +1 -0
  2. Dockerfile +15 -0
  3. README.md +4 -4
  4. __init__.py +0 -0
  5. app.py +74 -0
  6. image_enhancer.py +123 -0
  7. requirements.txt +13 -0
  8. video_enhancer.py +69 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ /gg
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9

WORKDIR /app

# Install dependencies first so this layer stays cached until
# requirements.txt itself changes (source edits no longer trigger a reinstall).
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the current directory contents into the container at /app
COPY . /app

# Make port 7860 available to the world outside this container
EXPOSE 7860

# Serve the FastAPI application object `app` from app.py with uvicorn
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Glam Ai
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: purple
6
  sdk: docker
7
  pinned: false
8
  license: mit
 
1
  ---
2
+ title: Media Enhancer
3
+ emoji: 🐢
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  license: mit
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from image_enhancer import EnhancementMethod, Enhancer
4
+ from video_enhancer import VideoEnhancer
5
+ from pydantic import BaseModel
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+ import magic
10
+ from typing import List
11
class EnhancementRequest(BaseModel):
    """Options that control how a single image is enhanced."""

    # Which model family performs the enhancement; GFPGAN unless overridden.
    method: EnhancementMethod = EnhancementMethod.gfpgan
    # Whether the area around detected faces is upsampled as well.
    background_enhancement: bool = True
    # Integer upscaling factor forwarded to the enhancer.
    upscale: int = 2
15
+
16
class _EnhanceBase(BaseModel):
    """Container for one or more base64-encoded images."""

    # Each entry is the base64 text of one image file.
    encoded_base_img: List[str]


# ASGI application object; the route decorators below register on it.
app = FastAPI()
22
+
23
@app.get("/")
def greet_json():
    """Return a small JSON object confirming the service is up.

    The original returned ``{"..."}``, which is a *set* literal and is
    serialized as a JSON array; a dict yields the JSON object the endpoint
    name promises.
    """
    return {"message": "Initializing GlamApp Enhancer"}
26
+
27
@app.post("/enhance/image/")
async def enhance_image(
    file: UploadFile = File(...),
    request: EnhancementRequest = EnhancementRequest()
):
    """Enhance an uploaded image and stream the result back as a PNG.

    Args:
        file: The uploaded image; its declared content type must start with
            ``image/``.
        request: Enhancement options (method, background enhancement, upscale).

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: 400 when the upload is not a decodable image,
            500 when the enhancement itself fails.
    """
    # Local import keeps this block self-contained: the module's import
    # section does not bring numpy into scope.
    import numpy as np

    # content_type can be None when the client omits it, so guard before
    # calling startswith().
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type")

    try:
        contents = await file.read()

        # Enhancer.enhance() feeds its argument straight into cv2.cvtColor
        # with COLOR_RGB2BGR, so it needs an RGB ndarray rather than a
        # base64 payload wrapped in a model.
        try:
            rgb_image = Image.open(BytesIO(contents)).convert("RGB")
        except OSError as exc:
            raise HTTPException(status_code=400, detail="Could not decode image") from exc

        enhancer = Enhancer(request.method, request.background_enhancement, request.upscale)

        enhanced_img, original_resolution, enhanced_resolution = await enhancer.enhance(
            np.array(rgb_image)
        )

        enhanced_image = Image.fromarray(enhanced_img)
        img_byte_arr = BytesIO()
        enhanced_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
        print(original_resolution, enhanced_resolution)

        return StreamingResponse(img_byte_arr, media_type="image/png")

    except HTTPException:
        # Preserve deliberate client-error responses instead of turning
        # them into a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
54
+
55
@app.post("/enhance/video/")
async def enhance_video(file: UploadFile = File(...)):
    """Validate an uploaded video by sniffing its content, then stream the enhanced frames.

    Args:
        file: The uploaded video file.

    Returns:
        Whatever ``VideoEnhancer.stream_enhanced_video`` returns for the file.

    Raises:
        HTTPException: 400 when the sniffed MIME type is not an accepted
            video type.
    """
    accepted_mime_types = (
        'video/mp4',
        'video/mpeg',
        'video/x-msvideo',
        'video/quicktime',
        'video/x-matroska',
        'video/webm',
    )

    # Sniff the type from the first bytes rather than trusting the client's
    # declared content type, then rewind so the enhancer reads the whole file.
    file_header = await file.read(1024)
    await file.seek(0)
    mime = magic.Magic(mime=True)
    file_mime_type = mime.from_buffer(file_header)

    if file_mime_type not in accepted_mime_types:
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video file.")

    # Build the enhancer only after validation: its constructor loads the
    # model, which rejected uploads should not have to pay for.
    enhancer = VideoEnhancer()
    return await enhancer.stream_enhanced_video(file)
image_enhancer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from gfpgan import GFPGANer
4
+ from tqdm import tqdm
5
+ import cv2
6
+ from realesrgan import RealESRGANer
7
+ from basicsr.archs.rrdbnet_arch import RRDBNet
8
+ import warnings
9
+ from enum import Enum
10
+
11
class EnhancementMethod(str, Enum):
    """Names of the supported enhancement back-ends.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    # Back-ends that Enhancer drives through GFPGANer.
    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"

    # Back-end that Enhancer drives through a stand-alone RealESRGANer.
    realesrgan = "realesrgan"
16
+
17
+
18
class Enhancer:
    """Enhance a single image with a GFPGANer-based face restorer or RealESRGAN.

    For ``EnhancementMethod.realesrgan`` the instance wraps a stand-alone
    ``RealESRGANer``. For every other method it wraps a ``GFPGANer``,
    optionally combined with a RealESRGAN background upsampler.
    """

    def __init__(self, method: EnhancementMethod, background_enhancement=True, upscale=2):
        """Create the model wrappers required by ``method``.

        Args:
            method: Which enhancement back-end to use.
            background_enhancement: Whether to attach a RealESRGAN background
                upsampler to the face enhancer (ignored for ``realesrgan``).
            upscale: Upscaling factor handed to the underlying models and
                used to pick the RealESRGAN weights file.

        Raises:
            ValueError: If ``method`` is ``realesrgan`` and CUDA is not
                available, or if ``method`` has no face-model configuration.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
            return

        # The background upsampler must exist *before* the face enhancer is
        # built: GFPGANer receives it as a constructor argument, so creating
        # it afterwards would silently leave background enhancement disabled.
        if self.background_enhancement:
            self.setup_background_enhancer()
        self.setup_face_enhancer()

    def _build_realesrgan_upsampler(self):
        """Build a RealESRGANer for the configured upscale factor."""
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=self.upscale)
        # NOTE(review): the file name is derived from the upscale factor;
        # confirm which factors are actually published at this location.
        weights_url = f'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x{self.upscale}plus.pth'
        return RealESRGANer(
            scale=self.upscale,
            model_path=weights_url,
            model=network,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            # Half precision; both callers only get here when CUDA is available.
            half=True)

    def setup_background_enhancer(self):
        """Create the background upsampler, or skip it with a warning on CPU."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return

        self.bg_upsampler = self._build_realesrgan_upsampler()

    def setup_realesrgan_enhancer(self):
        """Create the stand-alone RealESRGAN enhancer.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')

        self.realesrgan_enhancer = self._build_realesrgan_upsampler()

    def setup_face_enhancer(self):
        """Create the GFPGANer configured for ``self.method``.

        Weights are looked up in ``gfpgan/weights``, then ``checkpoints``;
        if neither file exists the configured download URL is passed instead.

        Raises:
            ValueError: If ``self.method`` has no model configuration.
        """
        # NOTE(review): gfpgan==1.3.8's GFPGANer appears to handle only the
        # 'clean', 'bilinear', 'original' and 'RestoreFormer' archs -- confirm
        # that the 'CodeFormer' entry actually loads.
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
            },
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
            }
        }

        config = model_configs.get(self.method)
        if not config:
            raise ValueError(f'Wrong model version {self.method}')

        model_path = os.path.join('gfpgan/weights', config['model_name'] + '.pth')
        if not os.path.isfile(model_path):
            model_path = os.path.join('checkpoints', config['model_name'] + '.pth')
        if not os.path.isfile(model_path):
            model_path = config['url']

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` taken from the first two entries of ``image.shape``.

        Only the first two dimensions are read, so 2-D arrays work as well
        as arrays with a trailing channel axis.
        """
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` without blocking the event loop.

        Args:
            image: RGB image array accepted by ``cv2.cvtColor`` with
                ``COLOR_RGB2BGR``.

        Returns:
            A tuple ``(enhanced_rgb_image, (width, height),
            (enhanced_width, enhanced_height))``.
        """
        # Local import: the module's import section does not bring asyncio
        # into scope, so the calls below would otherwise raise NameError.
        import asyncio

        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        # Model inference is blocking, so it runs in a worker thread.
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)

        return enhanced_img, (width, height), (enhanced_width, enhanced_height)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
fastapi
uvicorn[standard]
# Required by FastAPI to parse multipart uploads (UploadFile / File).
python-multipart
# Provides the `magic` module imported by app.py for MIME sniffing.
python-magic
gfpgan==1.3.8
realesrgan==0.3.0
pillow==10.3.0
pydantic==2.7.1
pydantic-settings==2.0.3
pydantic_core==2.18.2
requests==2.31.0
basicsr
huggingface-hub==0.25.1
numpy==1.26.4
facexlib==0.3.0
video_enhancer.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import io
5
+ import asyncio
6
+ from fastapi.responses import StreamingResponse
7
+ from basicsr.archs.rrdbnet_arch import RRDBNet
8
+ from realesrgan import RealESRGANer
9
+ from huggingface_hub import hf_hub_download
10
+ from concurrent.futures import ThreadPoolExecutor
11
+
12
class VideoEnhancer:
    """Upscale video frames with RealESRGAN and expose them as a file or a stream."""

    def __init__(self, model_name="RealESRGAN_x4plus"):
        """Load the model and create the worker pool used for inference.

        Args:
            model_name: Name of the model to load; only
                ``"RealESRGAN_x4plus"`` is supported.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.load_model(model_name)
        self.executor = ThreadPoolExecutor(max_workers=4)

    def load_model(self, model_name):
        """Download the weights for ``model_name`` and wrap them in a RealESRGANer.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        if model_name != "RealESRGAN_x4plus":
            raise ValueError(f"Unsupported model: {model_name}")

        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        model_path = hf_hub_download("schwgHao/RealESRGAN_x4plus", "RealESRGAN_x4plus.pth")
        return RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            # Half precision only when CUDA is present; __init__ falls back
            # to CPU otherwise, where fp16 inference is not usable.
            half=torch.cuda.is_available())

    async def enhance_frame(self, frame):
        """Upscale one BGR frame in the worker pool and return the BGR result.

        The frame is passed to the model as-is: it is already in OpenCV's BGR
        order, the same order the image ``Enhancer`` feeds to RealESRGANer, so
        no channel swap is applied before or after.
        """
        loop = asyncio.get_running_loop()
        enhanced, _ = await loop.run_in_executor(self.executor, self.model.enhance, frame)
        return enhanced

    async def process_video(self, input_bytes, output_bytes):
        """Enhance every frame of one video file and write the result to another.

        Args:
            input_bytes: Despite the name, this is handed directly to
                ``cv2.VideoCapture`` and therefore must be a file path.
            output_bytes: Likewise handed directly to ``cv2.VideoWriter``;
                must be an output file path.

        Raises:
            ValueError: If the input video cannot be opened.
        """
        cap = cv2.VideoCapture(input_bytes)
        out = None
        try:
            if not cap.isOpened():
                raise ValueError("Could not open input video")

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            # Size the output by the loaded model's scale, defaulting to the
            # x4 factor this class loads.
            scale = getattr(self.model, "scale", 4)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_bytes, fourcc, fps, (width * scale, height * scale))

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(await self.enhance_frame(frame))
        finally:
            # Release native handles even if a frame fails to process.
            cap.release()
            if out is not None:
                out.release()

    async def stream_enhanced_video(self, video_file):
        """Stream the enhanced frames of an uploaded video as multipart JPEG parts.

        Args:
            video_file: Upload object exposing an awaitable ``read()``.

        Returns:
            A ``StreamingResponse`` with media type
            ``multipart/x-mixed-replace; boundary=frame``.
        """
        # Local imports: the module's import section does not provide these.
        import os
        import tempfile

        video_bytes = await video_file.read()

        # cv2.VideoCapture opens a path, not an in-memory buffer, so the
        # upload is spooled to a temporary file that is removed once
        # streaming finishes.
        suffix = os.path.splitext(getattr(video_file, "filename", None) or "")[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as spool:
            spool.write(video_bytes)
            temp_path = spool.name

        cap = cv2.VideoCapture(temp_path)

        async def generate():
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    enhanced_frame = await self.enhance_frame(frame)
                    encoded_ok, buffer = cv2.imencode('.jpg', enhanced_frame)
                    if not encoded_ok:
                        # Skip frames that fail to encode instead of emitting
                        # a corrupt part.
                        continue
                    yield (b'--frame\r\n'
                           b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
            finally:
                cap.release()
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

        return StreamingResponse(generate(), media_type="multipart/x-mixed-replace; boundary=frame")