# Source: lsatone/backend/compact_ai_models.py
# Commit 83e35a7 by 3v324v23: "Update Comic123 with local comic folder files"
"""
Compact AI Models for <1GB VRAM Usage
SwinIR Lightweight & Compact Real-ESRGAN
Optimized for RTX 3050 Laptop GPU
"""
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
import math
import requests
from tqdm import tqdm
# Compact SwinIR Implementation
class PatchEmbed(nn.Module):
    """Lift an RGB image into ``embed_dim`` feature channels (compact version).

    Despite the name, the projection is a 3x3 convolution with stride 1 and
    padding 1, so the spatial size is preserved; ``patch_size`` only feeds the
    bookkeeping attribute ``num_patches``.
    """

    def __init__(self, img_size=64, patch_size=1, embed_dim=60):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Project ``x`` (assumed (B, 3, H, W)) to (B, embed_dim, H, W)."""
        projected = self.proj(x)
        return projected
class WindowAttention(nn.Module):
    """Multi-head self-attention over the tokens of one window - compact version.

    This module attends over *all* N tokens of each row it receives; limiting
    attention to a local window is done by the caller (SwinTransformerBlock
    partitions the feature map into ``window_size x window_size`` windows and
    passes one window per row). No relative position bias is used.
    """

    def __init__(self, dim, window_size, num_heads=6):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error
        # inside forward().
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Self-attention within each of the B_ groups of ``x``.

        Args:
            x: tensor of shape (B_, N, C) with C == dim, e.g. B_ windows of
               N tokens each.

        Returns:
            Tensor of shape (B_, N, C).
        """
        B_, N, C = x.shape
        # (B_, N, 3*C) -> (3, B_, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)


class SwinTransformerBlock(nn.Module):
    """Swin Transformer block without shifted windows - compact version.

    Pre-norm residual structure applied to a (B, C, H, W) feature map:
        x = x + WindowAttention(LayerNorm(x))   # attention inside each window
        x = x + MLP(LayerNorm(x))               # per-token MLP

    Attention is restricted to non-overlapping ``window_size x window_size``
    windows, so attention memory grows linearly with H*W rather than
    quadratically. Maps whose sides are not multiples of ``window_size`` are
    zero-padded (after the norm) for the attention step and cropped back.
    """

    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=2.):
        super().__init__()
        self.window_size = window_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )

    def _window_attention(self, x):
        """Run ``self.attn`` independently inside each window of a (B, H, W, C) map."""
        B, H, W, C = x.shape
        ws = self.window_size
        pad_h = (-H) % ws
        pad_w = (-W) % ws
        if pad_h or pad_w:
            # F.pad takes (before, after) pairs starting from the last dim: C, W, H.
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = H + pad_h, W + pad_w
        nh, nw = Hp // ws, Wp // ws
        # (B, Hp, Wp, C) -> (B*nh*nw, ws*ws, C): one row per window.
        windows = (
            x.reshape(B, nh, ws, nw, ws, C)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(B * nh * nw, ws * ws, C)
        )
        windows = self.attn(windows)
        # Inverse of the partition above, then drop the padding.
        x = (
            windows.reshape(B, nh, nw, ws, ws, C)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(B, Hp, Wp, C)
        )
        return x[:, :H, :W, :]

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W), C == dim; shape is preserved."""
        # Channels-last so LayerNorm / Linear act on the channel dimension.
        tokens = x.permute(0, 2, 3, 1)
        tokens = tokens + self._window_attention(self.norm1(tokens))
        tokens = tokens + self.mlp(self.norm2(tokens))
        return tokens.permute(0, 3, 1, 2).contiguous()
class CompactSwinIR(nn.Module):
    """Compact SwinIR-style super-resolution network for <1GB VRAM.

    Pipeline: conv_first -> ``depths[0]`` SwinTransformerBlocks ->
    conv_after_body (+ long skip from conv_first) -> conv_before_upsample ->
    conv to ``3 * upscale**2`` channels -> PixelShuffle(upscale), i.e. the
    output has 3 channels and is ``upscale`` times larger per spatial side.
    Only the first entry of ``depths`` / ``num_heads`` is used (one stage).

    Args:
        upscale: integer upscaling factor.
        img_size: stored on the instance; not used by forward().
        window_size: window size passed to every SwinTransformerBlock.
        embed_dim: number of feature channels inside the network.
        depths: sequence; ``depths[0]`` is the number of transformer blocks.
        num_heads: sequence; ``num_heads[0]`` is the attention head count.
    """

    def __init__(self, upscale=4, img_size=64, window_size=8,
                 embed_dim=60, depths=(4,), num_heads=(6,)):
        # Tuple defaults instead of lists: a mutable default object is shared
        # across calls. Both are only indexed, so list arguments still work.
        super().__init__()
        self.upscale = upscale
        self.img_size = img_size
        self.window_size = window_size
        # Shallow feature extraction
        self.conv_first = nn.Conv2d(3, embed_dim, 3, 1, 1)
        # Transformer blocks (reduced depth, single stage)
        self.layers = nn.ModuleList(
            SwinTransformerBlock(embed_dim, num_heads[0], window_size)
            for _ in range(depths[0])
        )
        # Reconstruction
        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        # Upsampling: expand to 3*upscale^2 channels, then rearrange into space.
        self.conv_before_upsample = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.upsample = nn.Sequential(
            nn.Conv2d(embed_dim, 3 * upscale * upscale, 3, 1, 1),
            nn.PixelShuffle(upscale),
        )

    def forward(self, x):
        """Upscale ``x`` (assumed (B, 3, H, W) RGB) by ``self.upscale``."""
        # Shallow feature extraction
        x = self.conv_first(x)
        res = x
        # Transformer blocks
        for layer in self.layers:
            x = layer(x)
        # Reconstruction with long skip connection
        x = self.conv_after_body(x) + res
        # Upsampling
        x = self.conv_before_upsample(x)
        return self.upsample(x)
class CompactRRDBNet(nn.Module):
    """Compact RRDB-style x4 generator for Real-ESRGAN-like upscaling (<1GB VRAM).

    Layout: conv_first -> ``nb`` trunk blocks -> trunk_conv (+ skip from
    conv_first) -> two stages of (nearest x2 upsample, conv, LeakyReLU) ->
    HRconv -> LeakyReLU -> conv_last. Output is 4x the input per spatial side.
    Note: each trunk block is a plain conv-LeakyReLU-conv stack; there are no
    residual or dense connections inside the trunk blocks themselves.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=32, nb=6, gc=16):
        super().__init__()
        # Entry convolution: image channels -> nf feature maps.
        self.conv_first = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True)
        # Trunk of nb lightweight blocks (reduced from 23 in full Real-ESRGAN).
        trunk_blocks = [self.make_rrdb_block(nf, gc) for _ in range(nb)]
        self.RRDB_trunk = nn.Sequential(*trunk_blocks)
        self.trunk_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        # Two x2 upsampling convolutions, then high-resolution refinement.
        self.upconv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, kernel_size=3, stride=1, padding=1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def make_rrdb_block(self, nf, gc):
        """Return one compact trunk block: conv(nf->gc), LeakyReLU(0.2), conv(gc->nf)."""
        squeeze = nn.Conv2d(nf, gc, kernel_size=3, stride=1, padding=1)
        expand = nn.Conv2d(gc, nf, kernel_size=3, stride=1, padding=1)
        return nn.Sequential(squeeze, nn.LeakyReLU(negative_slope=0.2, inplace=True), expand)

    def forward(self, x):
        """Upscale ``x`` (assumed (B, in_nc, H, W)) to (B, out_nc, 4H, 4W)."""
        shallow = self.conv_first(x)
        features = shallow + self.trunk_conv(self.RRDB_trunk(shallow))
        for upconv in (self.upconv1, self.upconv2):
            upsampled = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.lrelu(upconv(upsampled))
        return self.conv_last(self.lrelu(self.HRconv(features)))
class CompactAIEnhancer:
    """Compact AI Enhancer using SwinIR & Real-ESRGAN for <1GB VRAM.

    Wraps one of the compact x4 models defined in this module and upscales
    image files tile by tile. On CUDA the model runs in FP16 and the
    full-resolution result is accumulated on the CPU, so only one tile is on
    the GPU at a time. If the model cannot be built, enhance_image() falls
    back to OpenCV resizing (fallback_upscale).
    """

    # NOTE(review): MODEL_URLS is not referenced by this class. The entries
    # look like full-size checkpoints (SwinIR-M; "RealESRGAN_x4plus_netD",
    # whose "_netD" suffix presumably marks a discriminator rather than a
    # generator), so they likely will not load into CompactSwinIR /
    # CompactRRDBNet - verify before wiring up a downloader.
    MODEL_URLS = {
        'swinir_lightweight': 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth',
        'realesrgan_compact': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus_netD.pth',
    }

    # Spatial upscaling factor of both compact models: load_model builds
    # CompactSwinIR with upscale=4, and CompactRRDBNet upsamples twice by 2.
    _SCALE = 4

    def __init__(self, model_type='swinir', device=None):
        """Pick a device and tile size, then build/load the model.

        Args:
            model_type: 'swinir' or 'realesrgan'.
            device: optional torch.device or device string ('cuda', 'cpu', ...).
                When None, CUDA is used if available. In that case this
                process's CUDA memory is capped at 50% via
                torch.cuda.set_per_process_memory_fraction, and cuDNN
                benchmarking is turned off.
        """
        self.model_type = model_type
        # Device configuration
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device('cuda')
                # Aggressive memory management for <1GB usage
                torch.cuda.set_per_process_memory_fraction(0.5)  # Use max 50% of VRAM
                torch.backends.cudnn.benchmark = False  # Save memory
                torch.backends.cudnn.deterministic = True
                print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
                # Get actual VRAM
                props = torch.cuda.get_device_properties(0)
                vram_gb = props.total_memory / (1024**3)
                print(f"📊 Total VRAM: {vram_gb:.1f} GB")
                # Adjust tile size based on available VRAM
                if vram_gb < 4:
                    self.tile_size = 128  # Very small tiles for <4GB
                    self.tile_pad = 8
                else:
                    self.tile_size = 192
                    self.tile_pad = 16
            else:
                self.device = torch.device('cpu')
                self.tile_size = 256
                self.tile_pad = 16
                print("💻 Using CPU")
        else:
            # Normalise strings such as 'cuda' to torch.device: the rest of
            # the class relies on self.device.type.
            self.device = torch.device(device)
            self.tile_size = 128
            self.tile_pad = 8
        # Model directory
        self.model_dir = 'models_compact'
        os.makedirs(self.model_dir, exist_ok=True)
        # Initialize model
        self.model = None
        self.load_model()

    def load_model(self):
        """Build the compact model, move it to self.device and load local weights.

        Weights are read from ``<model_dir>/<model_type>_compact.pth`` when that
        file exists; otherwise the randomly initialised model is kept. On any
        exception ``self.model`` is set to None, which makes enhance_image()
        use fallback_upscale().
        """
        try:
            print(f"🔄 Loading compact {self.model_type} model...")
            if self.model_type == 'swinir':
                # Compact SwinIR configuration
                self.model = CompactSwinIR(
                    upscale=4,
                    img_size=64,
                    window_size=8,
                    embed_dim=60,  # Reduced from 180
                    depths=[4],  # Reduced from [6,6,6,6]
                    num_heads=[6],  # Reduced from [6,6,6,6]
                )
                model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2)
                print(f"📦 SwinIR Lightweight model size: {model_size:.1f} MB")
            elif self.model_type == 'realesrgan':
                # Compact Real-ESRGAN
                self.model = CompactRRDBNet(
                    in_nc=3,
                    out_nc=3,
                    nf=32,  # Reduced from 64
                    nb=6,  # Reduced from 23
                    gc=16,  # Reduced from 32
                )
                model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2)
                print(f"📦 Real-ESRGAN Compact model size: {model_size:.1f} MB")
            else:
                raise ValueError(f"Unknown model type: {self.model_type}")
            # Move to device
            self.model = self.model.to(self.device)
            self.model.eval()
            # Use half precision on GPU to save memory
            if self.device.type == 'cuda':
                self.model = self.model.half()
                print("✅ Using FP16 for memory efficiency")
            # Try to load pretrained weights if available
            model_path = os.path.join(self.model_dir, f'{self.model_type}_compact.pth')
            if os.path.exists(model_path):
                # NOTE(review): torch.load unpickles the file - only load trusted
                # checkpoints (consider weights_only=True where torch supports it).
                state_dict = torch.load(model_path, map_location=self.device)
                result = self.model.load_state_dict(state_dict, strict=False)
                # strict=False silently skips mismatched keys; report them so an
                # incompatible checkpoint is not announced as a successful load.
                if result.missing_keys or result.unexpected_keys:
                    print(f"⚠️ Checkpoint mismatch for {self.model_type}: "
                          f"{len(result.missing_keys)} missing, "
                          f"{len(result.unexpected_keys)} unexpected keys")
                else:
                    print(f"✅ Loaded pretrained {self.model_type} weights")
            else:
                print("⚠️ No pretrained weights found, using random initialization")
                print(" Model will still work but quality may be lower")
            print("✅ Model ready! Estimated VRAM usage: <500MB")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            self.model = None

    @staticmethod
    def _default_output_path(image_path):
        """Return ``<stem>_enhanced<ext>`` for ``image_path``.

        Only the final extension is touched, so dots in directory names or in
        the file stem are preserved. A path without an extension gets '.png'
        so cv2.imwrite can choose an encoder and the input is never overwritten.
        """
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext or '.png'}"

    def enhance_image(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Enhance ``image_path`` with the compact model and write the result.

        Args:
            image_path: path of the image to read with cv2.imread.
            output_path: destination path; defaults to ``<stem>_enhanced<ext>``.

        Returns:
            ``output_path`` on success, or ``image_path`` when reading,
            processing or writing fails. On CUDA out-of-memory the model is
            moved to the CPU (FP32) and the call is retried once there.
        """
        if output_path is None:
            output_path = self._default_output_path(image_path)
        print(f"🎨 Enhancing {os.path.basename(image_path)} with {self.model_type}...")
        try:
            # Load image
            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ Failed to load image: {image_path}")
                return image_path
            h, w = img.shape[:2]
            print(f" Input size: {w}x{h}")
            # Clear cache before processing
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            # Enhance
            if self.model is not None:
                enhanced = self.process_with_tiling(img)
            else:
                print(" ⚠️ Using fallback upscaling")
                enhanced = self.fallback_upscale(img)
            # cv2.imwrite signals some failures (e.g. missing directory) only
            # through its return value.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                print(f" ❌ Failed to write output: {output_path}")
                return image_path
            new_h, new_w = enhanced.shape[:2]
            print(f" ✅ Output size: {new_w}x{new_h}")
            # Clear memory after processing
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            return output_path
        except torch.cuda.OutOfMemoryError:
            print(" ❌ CUDA OOM! Falling back to CPU")
            self.device = torch.device('cpu')
            if self.model is not None:
                self.model = self.model.cpu().float()
            return self.enhance_image(image_path, output_path)
        except Exception as e:
            print(f" ❌ Enhancement failed: {e}")
            return image_path

    def process_with_tiling(self, img):
        """Upscale a BGR uint8 image by ``_SCALE`` with the model, tile by tile."""
        print(f" Processing with {self.tile_size}x{self.tile_size} tiles...")
        img_tensor = self.img_to_tensor(img)
        output = self._upscale_tensor_tiled(img_tensor)
        return self.tensor_to_img(output)

    def _upscale_tensor_tiled(self, img_tensor):
        """Run ``self.model`` over overlapping tiles of a (1, 3, H, W) float tensor.

        Each tile is a core of ``tile_size - 2*tile_pad`` input pixels plus up
        to ``tile_pad`` pixels of context on every side; only the upscaled core
        is copied into the output, which hides seams between tiles. Assumes
        the model maps an (h, w) tile to (h*_SCALE, w*_SCALE).

        Returns:
            float32 CPU tensor of shape (1, 3, H*_SCALE, W*_SCALE).

        Raises:
            ValueError: if ``tile_size <= 2 * tile_pad`` (no room for a core).
        """
        scale = self._SCALE
        _, _, h, w = img_tensor.shape
        out_h, out_w = h * scale, w * scale
        # Prepare output tensor on CPU to save VRAM
        output = torch.zeros((1, 3, out_h, out_w), dtype=torch.float32, device='cpu')
        pad = self.tile_pad
        step = self.tile_size - pad * 2
        if step <= 0:
            raise ValueError(
                f"tile_size ({self.tile_size}) must exceed 2 * tile_pad ({pad})"
            )
        # Feed tiles in the model's parameter dtype (FP16 after .half() on
        # CUDA, FP32 after the CPU fallback); parameter-free models keep the
        # input dtype.
        param = next(self.model.parameters(), None)
        dtype = param.dtype if param is not None else img_tensor.dtype
        with torch.no_grad():
            for y in range(0, h, step):
                for x in range(0, w, step):
                    # Input window = core [y, y+step) plus context of `pad`.
                    y_start = max(0, y - pad)
                    x_start = max(0, x - pad)
                    y_end = min(h, y + step + pad)
                    x_end = min(w, x + step + pad)
                    tile = img_tensor[:, :, y_start:y_end, x_start:x_end]
                    tile = tile.to(self.device, dtype=dtype)
                    # Move result back to CPU immediately
                    enhanced_tile = self.model(tile).cpu().float()
                    # Output region covered by this tile's core.
                    out_y_start, out_x_start = y * scale, x * scale
                    out_y_end = min(out_h, (y + step) * scale)
                    out_x_end = min(out_w, (x + step) * scale)
                    # The core starts after the leading context inside the tile.
                    tile_y_start = (y - y_start) * scale
                    tile_x_start = (x - x_start) * scale
                    output[:, :, out_y_start:out_y_end, out_x_start:out_x_end] = enhanced_tile[
                        :, :,
                        tile_y_start:tile_y_start + (out_y_end - out_y_start),
                        tile_x_start:tile_x_start + (out_x_end - out_x_start),
                    ]
                    # Clear tile from GPU memory immediately
                    del tile, enhanced_tile
                    if self.device.type == 'cuda':
                        torch.cuda.empty_cache()
        return output

    def img_to_tensor(self, img):
        """Convert a BGR uint8 image to a (1, 3, H, W) RGB float32 tensor in [0, 1]."""
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

    def tensor_to_img(self, tensor):
        """Convert a (1, 3, H, W) RGB CPU float tensor in [0, 1] to a BGR uint8 image."""
        return cv2.cvtColor(self._tensor_to_rgb_uint8(tensor), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _tensor_to_rgb_uint8(tensor):
        """(1, 3, H, W) CPU float tensor in [0, 1] -> (H, W, 3) uint8 array.

        Values are rounded to the nearest level and clipped to [0, 255].
        Truncation would turn a value that lands a hair below an integer
        (e.g. after the x/255*255 float round trip) into the level below.
        """
        img = tensor.squeeze(0).permute(1, 2, 0).numpy()
        return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)

    def fallback_upscale(self, img):
        """Non-AI fallback: bicubic resize, 3x3 sharpen, then bilateral denoise.

        The scale factor is at most 2 and is capped so the result fits within
        2048x1080; images already larger than that are shrunk.
        """
        h, w = img.shape[:2]
        scale = min(2, 2048 / w, 1080 / h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        # Sharpen: centre weight 9 minus the 8 neighbours (kernel sums to 1).
        kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.float64)
        upscaled = cv2.filter2D(upscaled, -1, kernel)
        # Denoise
        return cv2.bilateralFilter(upscaled, 5, 50, 50)

    def get_memory_usage(self):
        """Return a human-readable CUDA memory summary, or "Using CPU"."""
        if self.device.type == 'cuda':
            allocated = torch.cuda.memory_allocated() / (1024**2)
            reserved = torch.cuda.memory_reserved() / (1024**2)
            return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB"
        return "Using CPU"
# Easy-to-use functions
def create_compact_enhancer(model_type='swinir'):
    """Build a CompactAIEnhancer for ``model_type`` ('swinir' or 'realesrgan').

    A new enhancer (and therefore a new model) is created on every call.
    """
    enhancer = CompactAIEnhancer(model_type=model_type)
    return enhancer
def enhance_with_swinir(image_path, output_path=None):
    """One-shot helper: enhance ``image_path`` with a freshly built compact SwinIR.

    Returns whatever CompactAIEnhancer.enhance_image returns for the call.
    """
    return CompactAIEnhancer(model_type='swinir').enhance_image(image_path, output_path)
def enhance_with_compact_realesrgan(image_path, output_path=None):
    """One-shot helper: enhance ``image_path`` with a freshly built compact Real-ESRGAN.

    Returns whatever CompactAIEnhancer.enhance_image returns for the call.
    """
    return CompactAIEnhancer(model_type='realesrgan').enhance_image(image_path, output_path)
if __name__ == "__main__":
    banner = "🚀 Compact AI Models for <1GB VRAM"
    print(banner)
    print("=" * 50)
    # Smoke test: load both compact models back to back. Both objects stay
    # referenced, so the second memory report includes the first model too.
    swinir_demo = CompactAIEnhancer(model_type='swinir')
    print(f"\nMemory usage: {swinir_demo.get_memory_usage()}")
    esrgan_demo = CompactAIEnhancer(model_type='realesrgan')
    print(f"Memory usage: {esrgan_demo.get_memory_usage()}")