import os import cv2 import torch import spaces import imageio import numpy as np import gradio as gr torch.jit.script = lambda f: f import argparse from utils.batch_inference import ( BSRInferenceLoop, BIDInferenceLoop ) # import subprocess # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) device = 'cuda' if torch.cuda.is_available() else 'cpu' def get_example(task): case = { "dn": [ ['examples/bus.mp4',], ['examples/koala.mp4',], ['examples/flamingo.mp4',], ['examples/rhino.mp4',], ['examples/elephant.mp4',], ['examples/sheep.mp4',], ['examples/dog-agility.mp4',], # ['examples/dog-gooses.mp4',], ], "sr": [ ['examples/bus_sr.mp4',], ['examples/koala_sr.mp4',], ['examples/flamingo_sr.mp4',], ['examples/rhino_sr.mp4',], ['examples/elephant_sr.mp4',], ['examples/sheep_sr.mp4',], ['examples/dog-agility_sr.mp4',], # ['examples/dog-gooses_sr.mp4',], ] } return case[task] def update_prompt(input_video): video_name = input_video.split('/')[-1] return set_default_prompt(video_name) # Map videos to corresponding images video_to_image = { 'bus.mp4': ['examples_frames/bus'], 'koala.mp4': ['examples_frames/koala'], 'dog-gooses.mp4': ['examples_frames/dog-gooses'], 'flamingo.mp4': ['examples_frames/flamingo'], 'rhino.mp4': ['examples_frames/rhino'], 'elephant.mp4': ['examples_frames/elephant'], 'sheep.mp4': ['examples_frames/sheep'], 'dog-agility.mp4': ['examples_frames/dog-agility'], 'bus_sr.mp4': ['examples_frames/bus_sr'], 'koala_sr.mp4': ['examples_frames/koala_sr'], 'dog-gooses_sr.mp4': ['examples_frames/dog_gooses_sr'], 'flamingo_sr.mp4': ['examples_frames/flamingo_sr'], 'rhino_sr.mp4': ['examples_frames/rhino_sr'], 'elephant_sr.mp4': ['examples_frames/elephant_sr'], 'sheep_sr.mp4': ['examples_frames/sheep_sr'], 'dog-agility_sr.mp4': ['examples_frames/dog-agility_sr'], } def images_to_video(image_list, output_path, fps=10): # Convert PIL Images to numpy arrays frames = [np.array(img).astype(np.uint8) for img in image_list] frames = frames[:20] # Create video writer writer = imageio.get_writer(output_path, fps=fps, codec='libx264') for frame in frames: writer.append_data(frame) writer.close() def video2frames(video_path): # Open the video file video = cv2.VideoCapture(video_path) img_path = video_path[:-4] # Initialize frame counter frame_count = 0 os.makedirs(img_path, exist_ok=True) while True: # Read a frame from the video ret, frame = video.read() # If the frame was not successfully read, then we have reached the end of the video if not ret: break # Write the frame to a JPG file frame_file = f"{img_path}/{frame_count:05}.jpg" cv2.imwrite(frame_file, frame) # Increment the frame counter frame_count += 1 # Release the video file video.release() return img_path @spaces.GPU(duration=120) def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task): video_name = input_video.split('/')[-1] if video_name in video_to_image: frames_path = video_to_image[video_name][0] else: frames_path = video2frames(input_video) print(f"[INFO] input_video: {input_video}") print(f"[INFO] Frames path: {frames_path}") args = argparse.Namespace() # args.task = True, choices=["sr", "dn", "fr", "fr_bg"] args.task = task args.upscale = sr_ratio ### sampling parameters args.steps = n_steps args.better_start = True args.tiled = False args.tile_size = 512 args.tile_stride = 256 args.pos_prompt = prompt args.neg_prompt = n_prompt args.cfg_scale = guidance_scale ### input parameters args.input = frames_path args.n_samples = 1 args.batch_size = 10 args.final_size = (480, 854) args.config = "configs/inference/my_cldm.yaml" ### guidance parameters args.guidance = False args.g_loss = "w_mse" args.g_scale = 0.0 args.g_start = 1001 args.g_stop = -1 args.g_space = "latent" args.g_repeat = 1 ### output parameters args.output = " " ### common parameters args.seed = seed args.device = "cuda" args.n_frames = n_frames ### latent control parameters args.warp_period = [0, 0.1] args.merge_period = [0, 0] args.ToMe_period = [0, 1] args.merge_ratio = [0.6, 0] if args.task == "sr": restored_vid_path = BSRInferenceLoop(args).run() elif args.task == "dn": restored_vid_path = BIDInferenceLoop(args).run() torch.cuda.empty_cache() return restored_vid_path ######## # demo # ######## intro = """