Update worker_runpod.py
Browse files- worker_runpod.py +4 -1
worker_runpod.py
CHANGED
@@ -23,7 +23,7 @@ from huggingface_hub import HfApi, HfFolder
|
|
23 |
tokenxf = os.getenv("HF_API_TOKEN")
|
24 |
# Low GPU memory mode
|
25 |
low_gpu_memory_mode = False
|
26 |
-
|
27 |
def download_image(url, download_dir="asset"):
|
28 |
# Ensure the download directory exists
|
29 |
if not os.path.exists(download_dir):
|
@@ -121,6 +121,9 @@ def generate(input):
|
|
121 |
validation_image_end = values.get("validation_image_end", None)
|
122 |
|
123 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
|
|
|
|
|
|
124 |
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
125 |
start_img = Image.open(downloaded_image_path)
|
126 |
original_width, original_height = start_img.size
|
|
|
23 |
tokenxf = os.getenv("HF_API_TOKEN")
|
24 |
# Low GPU memory mode
|
25 |
low_gpu_memory_mode = False
|
26 |
+
lora_path = "/content/shirtlift.safetensors"
|
27 |
def download_image(url, download_dir="asset"):
|
28 |
# Ensure the download directory exists
|
29 |
if not os.path.exists(download_dir):
|
|
|
121 |
validation_image_end = values.get("validation_image_end", None)
|
122 |
|
123 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
124 |
+
if lora_path is not None:
|
125 |
+
pipeline = merge_lora(pipeline, lora_path, lora_weight)
|
126 |
+
|
127 |
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
128 |
start_img = Image.open(downloaded_image_path)
|
129 |
original_width, original_height = start_img.size
|