|
import gradio as gr |
|
import torch |
|
|
|
# Mapping of available pretrained-model names, published via torch.hub
# by the nateraw/image-generation repo (main branch).
model_map = torch.hub.load('nateraw/image-generation:main', 'model_map')

# Inference is far faster on CUDA; surface the status in the startup log.
on_gpu = torch.cuda.is_available()

# FIX: the status emoji were mojibake ('π΄'/'π’' — UTF-8 bytes of the
# circles mis-decoded as cp1253); restored to the intended 🔴/🟢 glyphs.
print(f"GPU enabled? - {'🔴' if not on_gpu else '🟢'}")
|
|
|
class InferenceWrapper:
    """Run StyleGAN3 inference, caching the most recently loaded model.

    A pipeline fetched from torch.hub is kept resident; calls that ask
    for the already-loaded model reuse it instead of reloading.
    """

    def __init__(self, model):
        """Load the pipeline for *model* and remember its name."""
        self.model = model
        self.pipe = self._load(model)

    def _load(self, name):
        # Hub loads are expensive (network fetch + weight download),
        # which is why __call__ only reloads on a model change.
        return torch.hub.load('nateraw/image-generation:main', 'styleganv3', pretrained=name)

    def __call__(self, seed, model):
        """Generate an image for *seed*, swapping in *model* first if needed."""
        if model == self.model:
            print(f"Model '{model}' already loaded, reusing it.")
        else:
            print(f"Loading model: {model}")
            self.model = model
            self.pipe = self._load(model)
        return self.pipe(seed)
|
|
|
# Load the default pretrained model once at import time (downloads weights
# via torch.hub), so the first UI request doesn't pay the load cost.
wrapper = InferenceWrapper('wikiart-1024')

def fn(seed, model):
    """Gradio callback: return an image generated from `seed` by `model`.

    Delegates to the module-level wrapper, which hot-swaps the loaded
    model only when `model` differs from the one currently resident.
    """
    return wrapper(seed, model)
|
|
|
# Build and launch the web demo.
# NOTE(review): this uses the legacy Gradio 2.x API (`gr.inputs.*`,
# `enable_queue=` on Interface) which was removed in Gradio 3+ —
# confirm the pinned gradio version before upgrading this file.
gr.Interface(

    fn,

    inputs=[

        # Integer seed for the generator; step=1 keeps the slider integral.
        gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),

        # Choices come from the hub-published model_map loaded at startup.
        gr.inputs.Radio(list(model_map), type="value", default='wikiart-1024', label='Pretrained Model')

    ],

    outputs='image',

    # NOTE(review): presumably all three example names are keys of model_map
    # (incl. 'stylegan3-r-ffhqu-256x256.pkl') — verify against the hub repo.
    examples=[[343, 'wikiart-1024'], [456, 'landscapes-256'], [1234, 'stylegan3-r-ffhqu-256x256.pkl']],

    # Queue requests so long-running generations don't hit request timeouts.
    enable_queue=True

).launch()