Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
|
|
12 |
import numpy as np
|
13 |
from small_256_model import UNet as small_UNet
|
14 |
from big_1024_model import UNet as big_UNet
|
|
|
15 |
|
16 |
# Device configuration
|
17 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -28,8 +29,11 @@ model_repo_id = "K00B404/pix2pix_flux"
|
|
28 |
# Global model variable
|
29 |
global_model = None
|
30 |
|
|
|
|
|
|
|
31 |
def load_model():
|
32 |
-
"""Load the
|
33 |
global global_model
|
34 |
weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
35 |
try:
|
@@ -47,6 +51,8 @@ def load_model():
|
|
47 |
global_model = model
|
48 |
return model
|
49 |
|
|
|
|
|
50 |
# Dataset class remains the same
|
51 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
52 |
def __init__(self, ds, transform):
|
|
|
12 |
import numpy as np
|
13 |
from small_256_model import UNet as small_UNet
|
14 |
from big_1024_model import UNet as big_UNet
|
15 |
+
from CLIP import load as load_clip
|
16 |
|
17 |
# Device configuration
|
18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
29 |
# Global model variable
|
30 |
global_model = None
|
31 |
|
32 |
+
# clip
|
33 |
+
clip_model,clip_tokenizer = load_clip()
|
34 |
+
|
35 |
def load_model():
|
36 |
+
"""Load the models at startup"""
|
37 |
global global_model
|
38 |
weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
39 |
try:
|
|
|
51 |
global_model = model
|
52 |
return model
|
53 |
|
54 |
+
|
55 |
+
|
56 |
# Dataset class remains the same
|
57 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
58 |
def __init__(self, ds, transform):
|