Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -66,10 +66,99 @@ class Pix2PixDataset(torch.utils.data.Dataset):
|
|
66 |
target = target_img.convert('RGB')
|
67 |
return self.transform(original), self.transform(target)
|
68 |
|
69 |
-
# UNetWrapper class remains the same
|
70 |
class UNetWrapper:
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
def prepare_input(image, device='cpu'):
|
75 |
"""Prepare image for inference"""
|
|
|
66 |
target = target_img.convert('RGB')
|
67 |
return self.transform(original), self.transform(target)
|
68 |
|
|
|
69 |
class UNetWrapper:
|
70 |
+
def __init__(self, unet_model, repo_id):
|
71 |
+
self.model = unet_model
|
72 |
+
self.repo_id = repo_id
|
73 |
+
self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set
|
74 |
+
self.api = HfApi(token=os.getenv('NEW_TOKEN'))
|
75 |
+
|
76 |
+
def push_to_hub(self):
|
77 |
+
try:
|
78 |
+
# Save model state and configuration
|
79 |
+
save_dict = {
|
80 |
+
'model_state_dict': self.model.state_dict(),
|
81 |
+
'model_config': {
|
82 |
+
'big': isinstance(self.model, big_UNet),
|
83 |
+
'img_size': 1024 if isinstance(self.model, big_UNet) else 256
|
84 |
+
},
|
85 |
+
'model_architecture': str(self.model)
|
86 |
+
}
|
87 |
+
|
88 |
+
# Save model locally
|
89 |
+
pth_name = 'model_weights.pth'
|
90 |
+
torch.save(save_dict, pth_name)
|
91 |
+
|
92 |
+
# Create repo if it doesn't exist
|
93 |
+
try:
|
94 |
+
create_repo(
|
95 |
+
repo_id=self.repo_id,
|
96 |
+
token=self.token,
|
97 |
+
exist_ok=True
|
98 |
+
)
|
99 |
+
except Exception as e:
|
100 |
+
print(f"Repository creation note: {e}")
|
101 |
+
|
102 |
+
# Upload the model file
|
103 |
+
self.api.upload_file(
|
104 |
+
path_or_fileobj=pth_name,
|
105 |
+
path_in_repo=pth_name,
|
106 |
+
repo_id=self.repo_id,
|
107 |
+
token=self.token,
|
108 |
+
repo_type="model"
|
109 |
+
)
|
110 |
+
|
111 |
+
# Create and upload model card
|
112 |
+
model_card = f"""---
|
113 |
+
tags:
|
114 |
+
- unet
|
115 |
+
- pix2pix
|
116 |
+
library_name: pytorch
|
117 |
+
---
|
118 |
+
|
119 |
+
# Pix2Pix UNet Model
|
120 |
+
|
121 |
+
## Model Description
|
122 |
+
Custom UNet model for Pix2Pix image translation.
|
123 |
+
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
|
124 |
+
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}
|
125 |
+
|
126 |
+
## Usage
|
127 |
+
|
128 |
+
```python
|
129 |
+
import torch
|
130 |
+
from small_256_model import UNet as small_UNet
|
131 |
+
from big_1024_model import UNet as big_UNet
|
132 |
+
|
133 |
+
# Load the model
|
134 |
+
checkpoint = torch.load('model_weights.pth')
|
135 |
+
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
136 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
137 |
+
model.eval()
|
138 |
+
Model Architecture
|
139 |
+
{str(self.model)}
|
140 |
+
"""
|
141 |
+
# Save and upload README
|
142 |
+
with open("README.md", "w") as f:
|
143 |
+
f.write(model_card)
|
144 |
+
|
145 |
+
self.api.upload_file(
|
146 |
+
path_or_fileobj="README.md",
|
147 |
+
path_in_repo="README.md",
|
148 |
+
repo_id=self.repo_id,
|
149 |
+
token=self.token,
|
150 |
+
repo_type="model"
|
151 |
+
)
|
152 |
+
|
153 |
+
# Clean up local files
|
154 |
+
os.remove(pth_name)
|
155 |
+
os.remove("README.md")
|
156 |
+
|
157 |
+
print(f"Model successfully uploaded to {self.repo_id}")
|
158 |
+
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Error uploading model: {e}")
|
161 |
+
|
162 |
|
163 |
def prepare_input(image, device='cpu'):
|
164 |
"""Prepare image for inference"""
|