File size: 5,128 Bytes
e5f98db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7baf792
 
dac160c
7baf792
f33e554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd7994
605b8ba
7baf792
dac160c
 
605b8ba
c25f048
7baf792
 
605b8ba
 
 
 
 
7baf792
 
ddbe533
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# NOTE: The triple-quoted string below is an inert example client script; it is
# parsed as a bare string expression and never executed. As written, it connects
# to a Gradio server at http://localhost:7860, sends every .png in `input_dir` to
# the "/predict" endpoint with model_selection="v2", re-saves each returned image
# file as PNG in `output_dir` (the script's own comment says the API returns a
# WEBP path), and copies any same-named .txt caption file next to it.
'''
import os
from gradio_client import Client, handle_file
from PIL import Image

# Initialize Gradio client
client = Client("http://localhost:7860")

# Path configuration
input_dir = "Aesthetics_X_Phone_720p_Images_Rec_Captioned_16_9"
output_dir = "processed_output_v2_single_thread"
os.makedirs(output_dir, exist_ok=True)

def process_single_image(png_path):
    """Process a single image and save results"""
    try:
        # Get base filename without extension
        base_name = os.path.splitext(os.path.basename(png_path))[0]
        
        # Corresponding text file path
        txt_path = os.path.join(input_dir, f"{base_name}.txt")
        
        print(f"Processing: {png_path}...", end=" ", flush=True)
        
        # Process image through API (returns WEBP path)
        webp_result = client.predict(
            img=handle_file(png_path),
            model_selection="v2",
            api_name="/predict"
        )
        
        # Output paths
        output_image_path = os.path.join(output_dir, f"{base_name}.png")
        output_text_path = os.path.join(output_dir, f"{base_name}.txt")
        
        # Convert WEBP to PNG
        with Image.open(webp_result) as img:
            img.save(output_image_path, "PNG")
        
        # Copy corresponding text file if exists
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as src, \
                 open(output_text_path, 'w', encoding='utf-8') as dst:
                dst.write(src.read())
        
        print("Done")
        return True
    
    except Exception as e:
        print(f"Failed: {str(e)}")
        return False

def main():
    # Get all PNG files in input directory
    png_files = sorted([
        os.path.join(input_dir, f) 
        for f in os.listdir(input_dir) 
        if f.lower().endswith('.png')
    ])
    
    print(f"Found {len(png_files)} PNG files to process")
    
    # Process files one by one
    success_count = 0
    for i, png_path in enumerate(png_files, 1):
        print(f"\n[{i}/{len(png_files)}] ", end="")
        if process_single_image(png_path):
            success_count += 1
    
    print(f"\nProcessing complete! Success: {success_count}/{len(png_files)}")

if __name__ == "__main__":
    main()
'''

from aura_sr import AuraSR
import gradio as gr
import spaces


class ZeroGPUAuraSR(AuraSR):
    """AuraSR subclass whose ``from_pretrained`` builds the model as ``cls(config)``.

    The model is constructed from the config alone (no explicit device argument),
    so placement is left to the base class's defaults.
    NOTE(review): presumably this override exists so loading works under ZeroGPU;
    confirm against the installed ``aura_sr`` version's ``from_pretrained``.
    """

    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        """Load an upsampler from a Hugging Face repo ID or a local weights file.

        Args:
            model_id: Either a Hugging Face model repo ID, fetched with
                ``snapshot_download``, or a path to a local ``.safetensors`` /
                ``.ckpt`` file. A local file needs a ``config.json`` in the same
                directory.
            use_safetensors: For repo IDs, load ``model.safetensors`` when True,
                otherwise ``model.ckpt``. For local files it is overridden by the
                file suffix.

        Returns:
            A new instance whose ``upsampler`` has the checkpoint loaded with
            ``strict=True``.

        Raises:
            ValueError: A local file has a suffix other than .safetensors/.ckpt.
            FileNotFoundError: A local file has no sibling ``config.json``.
            ImportError: safetensors loading is requested but the package is missing.
        """
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        local_file = Path(model_id)
        # Decide once whether model_id is a local weights file or a hub repo ID,
        # and derive both the config path and the checkpoint path from that.
        if local_file.is_file():
            # The file extension decides the format, overriding use_safetensors.
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

            # For local files, we need to provide the config separately
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )
            checkpoint_path = local_file
        else:
            hf_model_path = Path(snapshot_download(model_id))
            config_path = hf_model_path / "config.json"
            checkpoint_path = hf_model_path / ("model.safetensors" if use_safetensors else "model.ckpt")

        # JSON is UTF-8; don't depend on the platform's locale default encoding.
        config = json.loads(config_path.read_text(encoding="utf-8"))
        model = cls(config)

        if use_safetensors:
            # Guard only the import, so the message below cannot mask an
            # unrelated ImportError raised while loading the file.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                ) from err
            checkpoint = load_file(checkpoint_path)
        else:
            # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path)

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model



# Build both AuraSR generations once, at import time, so every request reuses
# them: `aura_sr` holds the v2 weights ("fal/AuraSR-v2") and `aura_sr_v1` the
# v1 weights ("fal-ai/AuraSR"). v2 is loaded first, then v1.
aura_sr, aura_sr_v1 = (
    ZeroGPUAuraSR.from_pretrained("fal/AuraSR-v2"),
    ZeroGPUAuraSR.from_pretrained("fal-ai/AuraSR"),
)


@spaces.GPU()
def predict(img, model_selection):
    """Upscale ``img`` 4x with the AuraSR model selected in the UI.

    Args:
        img: Value from the ``gr.Image()`` input; ``None`` when the user submits
            without providing an image.
        model_selection: ``'v1'`` (weights from "fal-ai/AuraSR") or ``'v2'``
            (weights from "fal/AuraSR-v2").

    Returns:
        The result of the chosen model's ``upscale_4x(img)``, shown by the
        ``gr.Image()`` output.

    Raises:
        gr.Error: If no image was provided or the model name is not recognised.
            Gradio shows the message to the user instead of a raw traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image to upscale.")

    models = {'v1': aura_sr_v1, 'v2': aura_sr}
    model = models.get(model_selection)
    if model is None:
        # .get() yields None for e.g. a cleared dropdown or a bad API argument;
        # fail with a clear message instead of an AttributeError on None.
        raise gr.Error(
            f"Unknown model {model_selection!r}; choose one of {sorted(models)}."
        )
    return model.upscale_4x(img)


# Single-page UI: an image plus a model-version dropdown (default 'v2') feed
# `predict`, and the upscaled result is shown in an output image component.
demo = gr.Interface(
    predict,
    inputs=[gr.Image(), gr.Dropdown(value='v2', choices=['v1', 'v2'])],
    outputs=gr.Image(),
)


if __name__ == "__main__":
    # Start the (blocking) server only when run as a script, so importing this
    # module does not launch it as a side effect.
    # NOTE(review): share=True requests a public gradio.live link when run
    # locally; confirm that exposing the app publicly is intended.
    demo.launch(share=True)