K00B404 commited on
Commit
d5d8a26
·
verified ·
1 Parent(s): d902dc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
12
  import numpy as np
13
  from small_256_model import UNet as small_UNet
14
  from big_1024_model import UNet as big_UNet
 
15
 
16
  # Device configuration
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -28,8 +29,11 @@ model_repo_id = "K00B404/pix2pix_flux"
28
  # Global model variable
29
  global_model = None
30
 
 
 
 
31
  def load_model():
32
- """Load the model at startup"""
33
  global global_model
34
  weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
35
  try:
@@ -47,6 +51,8 @@ def load_model():
47
  global_model = model
48
  return model
49
 
 
 
50
  # Dataset class remains the same
51
  class Pix2PixDataset(torch.utils.data.Dataset):
52
  def __init__(self, ds, transform):
 
12
  import numpy as np
13
  from small_256_model import UNet as small_UNet
14
  from big_1024_model import UNet as big_UNet
15
+ from CLIP import load as load_clip
16
 
17
  # Device configuration
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
29
  # Global model variable
30
  global_model = None
31
 
32
+ # clip
33
+ clip_model,clip_tokenizer = load_clip()
34
+
35
  def load_model():
36
+ """Load the models at startup"""
37
  global global_model
38
  weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
39
  try:
 
51
  global_model = model
52
  return model
53
 
54
+
55
+
56
  # Dataset class remains the same
57
  class Pix2PixDataset(torch.utils.data.Dataset):
58
  def __init__(self, ds, transform):