K00B404 commited on
Commit
b612bcb
·
verified ·
1 Parent(s): 3783d54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -3
app.py CHANGED
@@ -66,10 +66,99 @@ class Pix2PixDataset(torch.utils.data.Dataset):
66
  target = target_img.convert('RGB')
67
  return self.transform(original), self.transform(target)
68
 
69
- # UNetWrapper class remains the same
70
  class UNetWrapper:
71
- # ... [Previous UNetWrapper implementation remains unchanged]
72
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  def prepare_input(image, device='cpu'):
75
  """Prepare image for inference"""
 
66
  target = target_img.convert('RGB')
67
  return self.transform(original), self.transform(target)
68
 
 
69
  class UNetWrapper:
70
+ def __init__(self, unet_model, repo_id):
71
+ self.model = unet_model
72
+ self.repo_id = repo_id
73
+ self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set
74
+ self.api = HfApi(token=os.getenv('NEW_TOKEN'))
75
+
76
+ def push_to_hub(self):
77
+ try:
78
+ # Save model state and configuration
79
+ save_dict = {
80
+ 'model_state_dict': self.model.state_dict(),
81
+ 'model_config': {
82
+ 'big': isinstance(self.model, big_UNet),
83
+ 'img_size': 1024 if isinstance(self.model, big_UNet) else 256
84
+ },
85
+ 'model_architecture': str(self.model)
86
+ }
87
+
88
+ # Save model locally
89
+ pth_name = 'model_weights.pth'
90
+ torch.save(save_dict, pth_name)
91
+
92
+ # Create repo if it doesn't exist
93
+ try:
94
+ create_repo(
95
+ repo_id=self.repo_id,
96
+ token=self.token,
97
+ exist_ok=True
98
+ )
99
+ except Exception as e:
100
+ print(f"Repository creation note: {e}")
101
+
102
+ # Upload the model file
103
+ self.api.upload_file(
104
+ path_or_fileobj=pth_name,
105
+ path_in_repo=pth_name,
106
+ repo_id=self.repo_id,
107
+ token=self.token,
108
+ repo_type="model"
109
+ )
110
+
111
+ # Create and upload model card
112
+ model_card = f"""---
113
+ tags:
114
+ - unet
115
+ - pix2pix
116
+ library_name: pytorch
117
+ ---
118
+
119
+ # Pix2Pix UNet Model
120
+
121
+ ## Model Description
122
+ Custom UNet model for Pix2Pix image translation.
123
+ - Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
124
+ - Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}
125
+
126
+ ## Usage
127
+
128
+ ```python
129
+ import torch
130
+ from small_256_model import UNet as small_UNet
131
+ from big_1024_model import UNet as big_UNet
132
+
133
+ # Load the model
134
+ checkpoint = torch.load('model_weights.pth')
135
+ model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
136
+ model.load_state_dict(checkpoint['model_state_dict'])
137
+ model.eval()
138
+ Model Architecture
139
+ {str(self.model)}
140
+ """
141
+ # Save and upload README
142
+ with open("README.md", "w") as f:
143
+ f.write(model_card)
144
+
145
+ self.api.upload_file(
146
+ path_or_fileobj="README.md",
147
+ path_in_repo="README.md",
148
+ repo_id=self.repo_id,
149
+ token=self.token,
150
+ repo_type="model"
151
+ )
152
+
153
+ # Clean up local files
154
+ os.remove(pth_name)
155
+ os.remove("README.md")
156
+
157
+ print(f"Model successfully uploaded to {self.repo_id}")
158
+
159
+ except Exception as e:
160
+ print(f"Error uploading model: {e}")
161
+
162
 
163
  def prepare_input(image, device='cpu'):
164
  """Prepare image for inference"""