File size: 1,272 Bytes
d0f5c68
 
 
 
ca77dd1
 
 
d0f5c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c119503
d0f5c68
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
import torch

# Available pretrained checkpoints, fetched from the hub repo. `list(model_map)`
# later supplies the model picker's choices (presumably a dict keyed by
# checkpoint name — verify against the hub entry point).
model_map = torch.hub.load('nateraw/image-generation:main', 'model_map')

# Report at startup whether CUDA is available to torch.
on_gpu = torch.cuda.is_available()
print("GPU enabled? - " + ('🟢' if on_gpu else '🔴'))

class InferenceWrapper:
    """Hold one StyleGAN3 pipeline and reload it only when a different model is requested.

    Exactly one pipeline is kept in memory at a time; asking for another
    checkpoint name replaces it.
    """

    # torch.hub repository that provides the 'styleganv3' entry point.
    _HUB_REPO = 'nateraw/image-generation:main'

    def __init__(self, model):
        """Load the pipeline for ``model`` immediately.

        Args:
            model: Checkpoint name, passed to the hub entry point as ``pretrained``.
        """
        self.model = model
        self.pipe = self._load_pipe(model)

    def _load_pipe(self, model):
        """Return the 'styleganv3' pipeline for checkpoint ``model`` from the hub repo."""
        return torch.hub.load(self._HUB_REPO, 'styleganv3', pretrained=model)

    def __call__(self, seed, model):
        """Run the pipeline for ``seed``, switching to ``model`` first if needed.

        Args:
            seed: Seed forwarded unchanged to the pipeline.
            model: Checkpoint name to use for this call.

        Returns:
            Whatever the loaded pipeline returns for ``seed``.

        Raises:
            Any exception from loading a new checkpoint. In that case the
            previously loaded model stays active and consistent.
        """
        if model != self.model:
            print(f"Loading model: {model}")
            # Load before committing: if loading raises, `self.model` must not
            # be left naming a checkpoint that `self.pipe` does not hold.
            # Otherwise later requests for that name would silently reuse the
            # old pipeline.
            pipe = self._load_pipe(model)
            self.pipe = pipe
            self.model = model
        else:
            print(f"Model '{model}' already loaded, reusing it.")
        return self.pipe(seed)

# Shared wrapper, created at import time with 'wikiart-1024' already loaded.
wrapper = InferenceWrapper('wikiart-1024')


def fn(seed, model):
    """Gradio callback: generate output for ``seed`` using checkpoint ``model``.

    Delegates to the module-level ``wrapper``, which reloads the pipeline only
    when ``model`` differs from the one currently held.
    """
    generated = wrapper(seed=seed, model=model)
    return generated

# Web UI: a seed slider and a checkpoint picker feed `fn`; its result is shown
# as an image. Requests are queued (`enable_queue=True`).
gr.Interface(
    fn,
    inputs=[
        gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),
        # Default to the checkpoint preloaded at startup ('wikiart-1024'), so
        # the first default request reuses it instead of triggering a second
        # model download.
        gr.inputs.Radio(list(model_map), type="value", default='wikiart-1024', label='Pretrained Model'),
    ],
    outputs='image',
    examples=[[343, 'wikiart-1024'], [456, 'landscapes-256'], [1234, 'stylegan3-r-ffhqu-256x256.pkl']],
    enable_queue=True,
).launch()