# Spaces: Running on Zero
from spandrel import ModelLoader | |
import torch | |
from pathlib import Path | |
import gradio as App | |
import logging | |
import spaces | |
import time | |
import cv2 | |
import os | |
from gradio import themes | |
from rich.console import Console | |
from rich.logging import RichHandler | |
from Scripts.SAD import GetDifferenceRectangles | |
from Scripts.ORB import DetectMotionWithOrb | |
# ============================== #
#         Core Settings          #
# ============================== #
# Gradio theme: Citrus base, blue primary/secondary hues, extra-large corner
# radius, and blue link text.
Theme = themes.Citrus(
    radius_size=themes.sizes.radius_xxl,
    primary_hue='blue',
    secondary_hue='blue',
).set(link_text_color='blue')
# Directory scanned for model weight files.
ModelDir = Path('./Models')
# Scratch directory; also exported as GRADIO_TEMP_DIR for Gradio's temporary files.
TempDir = Path('./Temp')
os.environ.update(GRADIO_TEMP_DIR=str(TempDir))
# Only files with this suffix are listed and loaded as models.
ModelFileType = '.pth'
# ============================== #
#            Logging             #
# ============================== #
# Send all records through a Rich console handler (markup enabled, rich
# tracebacks, no source path column, timestamp shown on every line).
_ConsoleHandler = RichHandler(
    console=Console(),
    markup=True,
    rich_tracebacks=True,
    show_path=False,
    omit_repeated_times=False,
)
logging.basicConfig(
    handlers=[_ConsoleHandler],
    level=logging.INFO,
    datefmt='[%X]',
    format='%(message)s',
)
Logger = logging.getLogger('Zero2x')
# Hide httpx's INFO-level chatter; only warnings and above are shown.
logging.getLogger('httpx').setLevel(logging.WARNING)
# ============================== #
#      Device Configuration      #
# ============================== #
def GetDeviceName():
    """Select the torch device to run on and log the choice.

    Returns:
        ``torch.device('cuda')`` when PyTorch reports CUDA as available,
        otherwise ``torch.device('cpu')``.
    """
    DeviceType = 'cpu'
    if torch.cuda.is_available():
        DeviceType = 'cuda'
    Selected = torch.device(DeviceType)
    Logger.info(f'π§ͺ Using device: {str(Selected).upper()}')
    return Selected


Device = GetDeviceName()
# ============================== #
#       Utility Functions        #
# ============================== #
def HumanizeSeconds(Seconds):
    """Format a duration given in seconds as a compact string.

    Leading zero units are dropped: ``'1h 2m 5s'``, ``'2m 5s'`` or ``'5s'``.
    Every unit is truncated to a whole number.
    """
    WholeHours, Remainder = divmod(Seconds, 3600)
    Hours = int(WholeHours)
    Minutes = int(Remainder // 60)
    Secs = int(Seconds % 60)
    if Hours > 0:
        Parts = (f'{Hours}h', f'{Minutes}m', f'{Secs}s')
    elif Minutes > 0:
        Parts = (f'{Minutes}m', f'{Secs}s')
    else:
        Parts = (f'{Secs}s',)
    return ' '.join(Parts)
def HumanizedBytes(Size):
    """Format a byte count with 1024-based units and two decimals, e.g. ``'1.50 KB'``.

    Units stop at ``TB``; larger values are expressed as a multiple of TB.
    """
    Suffixes = ('B', 'KB', 'MB', 'GB', 'TB')
    for Unit in Suffixes[:-1]:
        if not Size >= 1024:
            return f'{Size:.2f} {Unit}'
        Size /= 1024.0
    return f'{Size:.2f} {Suffixes[-1]}'
# ============================== #
#     Main Processing Logic      #
# ============================== #
class Upscaler:
    """Upscale a video frame by frame with a spandrel super-resolution model.

    Each frame is either upscaled in full or, when region mode is enabled and the
    frame differs only locally from the previous one, only the changed grid cells
    are upscaled and pasted onto the previous upscaled frame.
    """

    def __init__(self):
        pass

    def ListModels(self):
        """Return the sorted file names (suffix included) of the model files in ``ModelDir``."""
        Models = sorted(
            File.name for File in ModelDir.glob('*' + ModelFileType) if File.is_file()
        )
        Logger.info(f'π Found {len(Models)} Models In Directory')
        return Models

    def LoadModel(self, ModelName):
        """Load ``ModelDir / (ModelName + ModelFileType)`` onto ``Device`` in eval mode."""
        torch.cuda.empty_cache()
        Model = (
            ModelLoader()
            .load_from_file(ModelDir / (ModelName + ModelFileType))
            .to(Device)
            .eval()
        )
        Logger.info(f'π€ Loaded Model {ModelName} Onto {str(Device).upper()}')
        return Model

    def UnloadModel(self):
        """Release cached CUDA memory (callers must drop their model reference first)."""
        if Device.type == 'cuda':
            torch.cuda.empty_cache()
        Logger.info('π€ Model Unloaded Successfully')

    def CleanUp(self):
        """Post-run cleanup hook; currently only unloads the model."""
        self.UnloadModel()
        Logger.info('π§Ή Temporary Files Cleaned Up')

    def _RunModel(self, Model, Image):
        """Run ``Model`` on one BGR uint8 image and return the BGR uint8 result.

        Inference runs under ``torch.inference_mode()`` so no autograd graph is
        recorded. The output is clamped to [0, 1] and rounded before the uint8
        conversion, so out-of-range values cannot wrap around.
        """
        ImageRgb = cv2.cvtColor(Image, cv2.COLOR_BGR2RGB)
        with torch.inference_mode():
            # HWC uint8 -> 1xCxHxW float in [0, 1]. Converting to float after the
            # transfer keeps the host->device copy at one byte per channel.
            Tensor = torch.from_numpy(ImageRgb).permute(2, 0, 1).unsqueeze(0)
            Tensor = Tensor.to(Device).float().div_(255.0)
            Output = Model(Tensor)[0]
            Output = Output.clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)
            OutputRgb = Output.permute(1, 2, 0).contiguous().cpu().numpy()
        return cv2.cvtColor(OutputRgb, cv2.COLOR_RGB2BGR)

    def UpscaleFullFrame(self, Model, Frame):
        """Upscale a whole BGR uint8 frame and return the upscaled BGR uint8 frame."""
        return self._RunModel(Model, Frame)

    def UpscaleRegions(self, Model, Frame, PrevFrame, UpscaledPrevFrame, InputThreshold, InputMinPercentage, InputMaxRectangles, InputPadding, InputSegmentRows, InputSegmentColumns, DiffResult=None):
        """Upscale only the changed grid cells of ``Frame`` when that is worthwhile.

        Args:
            DiffResult: optional result of ``GetDifferenceRectangles`` for the same
                frames and settings. Pass it to avoid computing the diff twice;
                when ``None`` it is computed here.

        Returns:
            ``(OutputFrame, SimilarityPercentage, Rectangles, RegionLog, UseRegions)``.
            Regions are used only when the similarity exceeds ``InputMinPercentage``
            and fewer than ``InputMaxRectangles`` rectangles changed; otherwise the
            whole frame is upscaled and ``UseRegions`` is ``False``.
        """
        if DiffResult is None:
            DiffResult = GetDifferenceRectangles(
                PrevFrame,
                Frame,
                Threshold=InputThreshold,
                Rows=InputSegmentRows,
                Columns=InputSegmentColumns,
                Padding=InputPadding,
            )
        SimilarityPercentage = DiffResult['SimilarPercentage']
        Rectangles = DiffResult['Rectangles']
        Cols = DiffResult['Columns']
        Rows = DiffResult['Rows']
        if not (SimilarityPercentage > InputMinPercentage and len(Rectangles) < InputMaxRectangles):
            return self.UpscaleFullFrame(Model, Frame), SimilarityPercentage, Rectangles, 'π₯', False

        FrameHeight, FrameWidth = Frame.shape[:2]
        # Rectangles are given in grid-cell units (X, Y, W, H). The last row/column
        # absorbs the pixels left over by the integer division below.
        SegmentWidth = FrameWidth // Cols
        SegmentHeight = FrameHeight // Rows
        UpscaleFactorY = UpscaledPrevFrame.shape[0] // FrameHeight
        UpscaleFactorX = UpscaledPrevFrame.shape[1] // FrameWidth
        # Paste onto a copy so the caller's previous upscaled frame stays untouched.
        OutputFrame = UpscaledPrevFrame.copy()
        for X, Y, W, H in Rectangles:
            X1 = X * SegmentWidth
            Y1 = Y * SegmentHeight
            X2 = FrameWidth if X + W == Cols else X1 + W * SegmentWidth
            Y2 = FrameHeight if Y + H == Rows else Y1 + H * SegmentHeight
            Region = Frame[Y1:Y2, X1:X2]
            UpscaledRegion = self._RunModel(Model, Region)
            RegionHeight, RegionWidth = Region.shape[:2]
            TargetWidth = RegionWidth * UpscaleFactorX
            TargetHeight = RegionHeight * UpscaleFactorY
            # Only resample when the model's scale differs from the frame's scale.
            if UpscaledRegion.shape[:2] != (TargetHeight, TargetWidth):
                UpscaledRegion = cv2.resize(
                    UpscaledRegion, (TargetWidth, TargetHeight), interpolation=cv2.INTER_CUBIC
                )
            UX1 = X1 * UpscaleFactorX
            UY1 = Y1 * UpscaleFactorY
            OutputFrame[UY1:UY1 + TargetHeight, UX1:UX1 + TargetWidth] = UpscaledRegion
        return OutputFrame, SimilarityPercentage, Rectangles, 'π©', True

    # NOTE(review): `spaces` is imported at the top of the file but this entry point
    # is not decorated with `@spaces.GPU`; on ZeroGPU hardware GPU work is normally
    # expected to run inside such a decorated function - confirm whether it is needed.
    def Process(self, InputVideo, InputModel, InputUseRegions, InputThreshold, InputMinPercentage, InputMaxRectangles, InputPadding, InputSegmentRows, InputSegmentColumns, InputFullFrameInterval, InputMotionThreshold, Progress=App.Progress()):
        """Upscale ``InputVideo`` with ``InputModel`` and write the result to ``TempDir``.

        Full-frame upscaling is forced on the first frame, when regions are disabled,
        after ``InputFullFrameInterval`` region-only upscales, or when ORB motion
        detection fires. Frames whose SAD similarity is exactly 100% reuse the
        previous upscaled frame. The output is written with ``cv2.VideoWriter``,
        which carries no audio track.

        Returns:
            ``(path, path)`` to the upscaled video, for the preview and download
            components, or ``(None, None)`` when there is nothing to process.

        Raises:
            App.Error: if the output video file cannot be created.
        """
        if not InputVideo:
            Logger.warning('β No Video Provided')
            App.Warning('β No Video Provided')
            return None, None
        if not InputModel:
            Logger.warning('No Model Selected')
            App.Warning('No Model Selected')
            return None, None
        # These settings feed integer arithmetic and slicing; slider values may
        # arrive as floats, so normalize them once here.
        InputMaxRectangles = int(InputMaxRectangles)
        InputPadding = int(InputPadding)
        InputSegmentRows = int(InputSegmentRows)
        InputSegmentColumns = int(InputSegmentColumns)
        InputFullFrameInterval = int(InputFullFrameInterval)

        Progress(0, desc='βοΈ Loading Model')
        Model = self.LoadModel(InputModel)
        Logger.info(f'πΌ Processing Video: {Path(InputVideo).name}')
        Progress(0, desc='πΌ Processing Video')
        Video = cv2.VideoCapture(InputVideo)
        Writer = None
        try:
            if not Video.isOpened():
                Logger.warning('Could Not Open Video')
                App.Warning('Could Not Open Video')
                return None, None
            FrameRate = Video.get(cv2.CAP_PROP_FPS)
            FrameCount = int(Video.get(cv2.CAP_PROP_FRAME_COUNT))
            Width = int(Video.get(cv2.CAP_PROP_FRAME_WIDTH))
            Height = int(Video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            Logger.info(f'π Video Properties: {FrameCount} Frames, {FrameRate} FPS, {Width}x{Height}')
            # CAP_PROP_FRAME_COUNT may be 0 (or only an estimate) for some containers:
            # avoid dividing by zero and clamp the reported progress below.
            PerFrameProgress = 1 / FrameCount if FrameCount > 0 else 0.0
            # VideoWriter needs a positive frame rate; 30 fps is an arbitrary fallback.
            OutputFps = FrameRate if FrameRate > 0 else 30.0
            TempDir.mkdir(parents=True, exist_ok=True)
            # Unique name so repeated jobs do not overwrite each other's output.
            OutputPath = TempDir / f'{Path(InputVideo).stem}_Upscaled_{time.time_ns()}.mp4'

            FrameProgress = 0.0
            StartTime = time.time()
            Times = []
            CurrentFrameIndex = 0
            PrevFrame = None
            UpscaledPrevFrame = None
            PartialUpscaleCount = 0
            while True:
                Ret, Frame = Video.read()
                if not Ret:
                    break
                CurrentFrameIndex += 1
                ForceFull = (
                    CurrentFrameIndex == 1
                    or not InputUseRegions
                    or PartialUpscaleCount >= InputFullFrameInterval
                )
                # Motion detection can only change the decision when a full-frame
                # upscale is not already forced, so skip the ORB pass otherwise.
                if not ForceFull and PrevFrame is not None:
                    IsMotion, _TotalMagnitude, _DirectionAngle = DetectMotionWithOrb(PrevFrame, Frame, InputMotionThreshold)
                    if IsMotion:
                        ForceFull = True
                        Logger.info(f'π¨ Frame {CurrentFrameIndex}: Motion Detected - Upscaling Full Frame')

                if ForceFull:
                    PartialUpscaleCount = 0
                    OutputFrame = self.UpscaleFullFrame(Model, Frame)
                    Logger.info(f'π₯ Frame {CurrentFrameIndex}: Upscaling Full Frame')
                else:
                    DiffResult = GetDifferenceRectangles(
                        PrevFrame,
                        Frame,
                        Threshold=InputThreshold,
                        Rows=InputSegmentRows,
                        Columns=InputSegmentColumns,
                        Padding=InputPadding,
                    )
                    if DiffResult['SimilarPercentage'] == 100:
                        # Nothing changed: reuse the previous upscaled frame. Copied
                        # frames neither advance nor reset PartialUpscaleCount.
                        OutputFrame = UpscaledPrevFrame
                        Logger.info(f'π¦ Frame {CurrentFrameIndex}: 100% Similar - Copied Previous Upscaled Frame')
                    else:
                        OutputFrame, SimilarityPercentage, Rectangles, RegionLog, UseRegions = self.UpscaleRegions(
                            Model, Frame, PrevFrame, UpscaledPrevFrame, InputThreshold, InputMinPercentage,
                            InputMaxRectangles, InputPadding, InputSegmentRows, InputSegmentColumns,
                            DiffResult=DiffResult,
                        )
                        if UseRegions:
                            PartialUpscaleCount += 1
                            Logger.info(f'{RegionLog} Frame {CurrentFrameIndex}: {SimilarityPercentage:.2f}% Similar, {len(Rectangles)} Regions To Upscale')
                        else:
                            PartialUpscaleCount = 0
                            Logger.info(f'{RegionLog} Frame {CurrentFrameIndex}: Upscaling Full Frame')

                if Writer is None:
                    # The output size is only known once the first frame is upscaled.
                    OutHeight, OutWidth = OutputFrame.shape[:2]
                    Writer = cv2.VideoWriter(
                        str(OutputPath), cv2.VideoWriter_fourcc(*'mp4v'), OutputFps, (OutWidth, OutHeight)
                    )
                    if not Writer.isOpened():
                        raise App.Error('Could Not Create Output Video')
                Writer.write(OutputFrame)
                # No later code writes into these arrays in place (UpscaleRegions
                # copies UpscaledPrevFrame before pasting), so keep references.
                PrevFrame = Frame
                UpscaledPrevFrame = OutputFrame

                Times.append(time.time() - StartTime)
                StartTime = time.time()
                FrameProgress = min(FrameProgress + PerFrameProgress, 1.0)
                AverageTime = sum(Times) / len(Times)
                Eta = HumanizeSeconds(max(FrameCount - CurrentFrameIndex, 0) * AverageTime)
                Progress(FrameProgress, desc=f'π¦ Processed Frame {CurrentFrameIndex}/{FrameCount} - {Eta}')

            if Writer is None:
                Logger.warning('No Frames Could Be Read From The Video')
                App.Warning('No Frames Could Be Read From The Video')
                return None, None
            Progress(1, desc='π¦ Cleaning Up')
            # The writer is released in `finally`, which runs before the caller
            # receives the path, so the file is complete when it is served.
            return str(OutputPath), str(OutputPath)
        finally:
            Video.release()
            if Writer is not None:
                Writer.release()
            # Drop the last reference so the empty_cache() in UnloadModel can
            # actually return the model's GPU memory.
            del Model
            self.CleanUp()
# ============================== #
#         Streamlined UI         #
# ============================== #
with App.Blocks(
    title='Zero2x Video Upscaler', theme=Theme, delete_cache=(-1, 1800)
) as Interface:
    App.Markdown('# ποΈ Zero2x Video Upscaler')
    with App.Accordion(label='βοΈ About Zero2x', open=False):
        # Blank lines matter here: without them Markdown treats a text line
        # followed by `---` as a heading and merges paragraphs.
        App.Markdown('''
**Zero2x** is a work-in-progress video upscaling tool that uses deep learning models to enhance your videos frame by frame.
This app leverages region-based difference detection to speed up processing and reduce unnecessary computation.

---

## β¨ Features

- **Multiple Upscaling Models:** Choose from a selection of pre-trained models for different styles and quality.
- **Region-Based Upscaling:** Only upscale parts of the frame that have changed, making processing faster and more memory-efficient.
- **Full Frame Upscaling:** Optionally upscale every frame in its entirety for maximum quality.
- **Customizable Settings:** Fine-tune thresholds, padding, and region detection for your specific needs.
- **Progress Tracking:** See estimated time remaining and per-frame progress.
- **Downloadable Results:** Download your upscaled video when processing is complete.

---

## π§βπ¬ Technique

This app uses Segmented Absolute Differences (SAD), an algorithm I created, to compare each frame with the previous one.
If only small regions have changed, only those regions are upscaled using the selected model.
If the whole frame is different, the entire frame is upscaled.
This hybrid approach balances speed and quality.

---

## π§ Work In Progress

- More models and settings will be added soon.
- Some features may be experimental or incomplete.
- Feedback and suggestions are welcome!
- The quality of the upscaled video may vary depending on the model and settings used.

---

**Tip:** If you encounter CUDA out-of-memory errors, try increasing the segment grid size or lowering the region count.

**Note:** I named this project Zero2x because I was inspired by Video2x but wanted my own version with a different approach.
It runs on Hugging Face's ZeroGPU hardware, which is where the name comes from.
''')
    with App.Row():
        with App.Column():
            with App.Group():
                InputVideo = App.Video(
                    label='Input Video', sources=['upload'], height=300
                )
                ModelList = Upscaler().ListModels()
                ModelNames = [Path(Model).stem for Model in ModelList]
                InputModel = App.Dropdown(
                    choices=ModelNames,
                    label='Select Model',
                    # The Models directory may be missing or empty when the UI is
                    # built (it is only created under __main__), so avoid IndexError.
                    value=ModelNames[0] if ModelNames else None,
                )
            with App.Accordion(label='βοΈ Advanced Settings', open=False):
                with App.Accordion(label='π Settings Explained', open=False):
                    App.Markdown('''
- **Use Regions:** When enabled, only changed areas between frames are upscaled. This is faster but may miss subtle changes.
- **Threshold:** Controls how sensitive the difference detection is. I found that high values can introduce mismatched regions, so be careful.
- **Padding:** Adds extra pixels around detected regions so that neighboring pixels just outside them are included.
- **Min Percentage:** If the similarity between frames is above this value, only regions are upscaled; otherwise, the full frame is upscaled.
- **Max Rectangles:** If a frame has this many changed regions or more, the full frame is upscaled instead of the individual regions.
- **Segment Rows/Columns:** Controls the grid size for region detection. More segments allow finer detection but may increase processing time; the smaller regions also use less VRAM.
- **Full Frame Interval:** Forces a full-frame upscale after N region-only upscales since the last full-frame upscale. This prevents regions from glitching out. To always upscale the full frame, disable Use Regions.
- **Motion Threshold:** Controls how sensitive the motion detection is. Frames with detected motion are upscaled in full, because upscaling regions on moving frames produces faulty regions. Lower = more strict.
''')
                with App.Group():
                    InputUseRegions = App.Checkbox(
                        label='Use Regions',
                        value=False,
                        info='Use regions to upscale only the different parts of the video (β‘οΈ Experimental, Faster)',
                        interactive=True
                    )
                    InputThreshold = App.Slider(
                        label='Threshold',
                        value=2,
                        minimum=0,
                        maximum=10,
                        step=0.5,
                        info='Threshold for the SAD algorithm to detect different regions',
                        interactive=False
                    )
                    InputPadding = App.Slider(
                        label='Padding',
                        value=1,
                        minimum=0,
                        maximum=5,
                        step=1,
                        info='Extra padding to include neighboring pixels in the SAD algorithm',
                        interactive=False
                    )
                    InputMinPercentage = App.Slider(
                        label='Min Percentage',
                        value=50,
                        minimum=0,
                        maximum=100,
                        step=1,
                        info='Similarity (%) a frame must exceed to upscale only the changed regions; otherwise the full frame is upscaled',
                        interactive=False
                    )
                    InputMaxRectangles = App.Slider(
                        label='Max Rectangles',
                        value=10,
                        minimum=1,
                        maximum=16,
                        step=1,
                        info='If a frame has this many changed regions or more, the full frame is upscaled instead',
                        interactive=False
                    )
                    with App.Row():
                        InputSegmentRows = App.Slider(
                            label='Segment Rows',
                            value=32,
                            minimum=1,
                            maximum=64,
                            step=1,
                            info='Number of rows to segment the video into for processing',
                            interactive=False
                        )
                        InputSegmentColumns = App.Slider(
                            label='Segment Columns',
                            value=48,
                            minimum=1,
                            maximum=64,
                            step=1,
                            info='Number of columns to segment the video into for processing',
                            interactive=False
                        )
                    InputFullFrameInterval = App.Slider(
                        label='Full Frame Interval',
                        value=5,
                        minimum=1,
                        maximum=100,
                        step=1,
                        info='Force a full-frame upscale after N region-only upscales (disable Use Regions to always upscale the full frame)',
                        interactive=False
                    )
                    InputMotionThreshold = App.Slider(
                        label='Motion Threshold',
                        value=1,
                        minimum=0,
                        maximum=10,
                        step=0.5,
                        info='Threshold for the motion detection algorithm to consider a frame as different',
                        interactive=False
                    )
            SubmitButton = App.Button('π Upscale Video')
        with App.Column(show_progress=True):
            with App.Group():
                OutputVideo = App.Video(
                    label='Output Video', height=300, interactive=False, format=None
                )
                OutputDownload = App.DownloadButton(
                    label='πΎ Download Video', interactive=False
                )

    # Settings that only apply in region mode; they are enabled/disabled together.
    RegionInputs = [
        InputThreshold,
        InputMinPercentage,
        InputMaxRectangles,
        InputPadding,
        InputSegmentRows,
        InputSegmentColumns,
        InputFullFrameInterval,
        InputMotionThreshold,
    ]

    def ToggleRegionInputs(UseRegions):
        """Make every region setting interactive only while 'Use Regions' is checked."""
        return tuple(App.update(interactive=UseRegions) for _ in RegionInputs)

    InputUseRegions.change(
        fn=ToggleRegionInputs,
        inputs=[InputUseRegions],
        outputs=RegionInputs,
    )
    SubmitButton.click(
        fn=Upscaler().Process,
        inputs=[
            InputVideo,
            InputModel,
            InputUseRegions,
            InputThreshold,
            InputMinPercentage,
            InputMaxRectangles,
            InputPadding,
            InputSegmentRows,
            InputSegmentColumns,
            InputFullFrameInterval,
            InputMotionThreshold
        ],
        outputs=[OutputVideo, OutputDownload],
    )
if __name__ == '__main__':
    # Ensure the model and scratch directories exist, then serve the UI.
    for RequiredDir in (ModelDir, TempDir):
        os.makedirs(RequiredDir, exist_ok=True)
    Logger.info('π Starting Video Upscaler')
    Interface.launch(pwa=True)