meepmoo committed
Commit 08d9799
1 Parent(s): 4e7576b

Update worker_runpod.py

Files changed (1)
  1. worker_runpod.py +100 -161
worker_runpod.py CHANGED
@@ -1,170 +1,109 @@
- import json
- import os
- import runpod
- import numpy as np
  import torch
- import requests
- import uuid
- from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
-                        DPMSolverMultistepScheduler,
-                        EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
-                        PNDMScheduler)
- from transformers import T5EncoderModel, T5Tokenizer
- from omegaconf import OmegaConf
- from PIL import Image
- from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
- from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
- from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
- from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
  from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
- from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
- from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
- from huggingface_hub import HfApi, HfFolder
-
- tokenxf = os.getenv("HF_API_TOKEN")
- # Low GPU memory mode
- low_gpu_memory_mode = False
- lora_path = "/content/shirtlift.safetensors"
-
- def to_pil(image):
-     if isinstance(image, Image.Image):
-         return image
-     if isinstance(image, torch.Tensor):
-         return tensor2pil(image)
-     if isinstance(image, np.ndarray):
-         return numpy2pil(image)
-     raise ValueError(f"Cannot convert {type(image)} to PIL.Image")
-
-
- def download_image(url, download_dir="asset"):
-     # Ensure the download directory exists
-     if not os.path.exists(download_dir):
-         os.makedirs(download_dir, exist_ok=True)
-
-     # Send the request and check for a successful response
-     response = requests.get(url, stream=True)
-     if response.status_code == 200:
-         # Determine the file extension based on content type
-         content_type = response.headers.get("Content-Type")
-         if content_type == "image/png":
-             ext = "png"
-         elif content_type == "image/jpeg":
-             ext = "jpg"
-         else:
-             ext = "jpg"  # default to .jpg if content type is unrecognized
-
-         # Generate a random filename with the correct extension
-         filename = f"{uuid.uuid4().hex}.{ext}"
-         file_path = os.path.join(download_dir, filename)
-
-         # Save the image
-         with open(file_path, "wb") as f:
-             for chunk in response.iter_content(1024):
-                 f.write(chunk)
-
-         print(f"Image downloaded to {file_path}")
-         return file_path
-     else:
-         raise Exception(f"Failed to download image from {url}, status code: {response.status_code}")
-
- # Usage
- # validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
- # downloaded_image_path = download_image(validation_image_start)
- model_id = "/content/model"
- transformer = CogVideoXTransformer3DModel.from_pretrained_2d(model_id, subfolder="transformer").to(torch.bfloat16)
-
- vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae").to(torch.bfloat16)
-
- text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
-
- sampler_dict = {
-     "Euler": EulerDiscreteScheduler,
-     "Euler A": EulerAncestralDiscreteScheduler,
-     "DPM++": DPMSolverMultistepScheduler,
-     "PNDM": PNDMScheduler,
-     "DDIM_Cog": CogVideoXDDIMScheduler,
-     "DDIM_Origin": DDIMScheduler,
- }
- scheduler = sampler_dict["DPM++"].from_pretrained(model_id, subfolder="scheduler")
-
- # Pipeline setup
- if transformer.config.in_channels != vae.config.latent_channels:
-     pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
-         model_id, vae=vae, text_encoder=text_encoder,
-         transformer=transformer, scheduler=scheduler,
-         torch_dtype=torch.bfloat16
-     )
- else:
-     pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
-         model_id, vae=vae, text_encoder=text_encoder,
-         transformer=transformer, scheduler=scheduler,
-         torch_dtype=torch.bfloat16
-     )
 
- if low_gpu_memory_mode:
-     pipeline.enable_sequential_cpu_offload()
- else:
-     pipeline.enable_model_cpu_offload()
 
  @torch.inference_mode()
  def generate(input):
      values = input["input"]
-     prompt = values["prompt"]
-     negative_prompt = values.get("negative_prompt", "blurry, blurred, blurry face")
-     guidance_scale = values.get("guidance_scale", 6.0)
-     seed = values.get("seed", 42)
-     num_inference_steps = values.get("num_inference_steps", 18)
-     base_resolution = values.get("base_resolution", 512)
-
-     video_length = values.get("video_length", 53)
-     fps = values.get("fps", 10)
-     lora_weight = values.get("lora_weight", 1.00)
-     save_path = "samples"
-     partial_video_length = values.get("partial_video_length", None)
-     overlap_video_length = values.get("overlap_video_length", 4)
-     validation_image_start = values.get("validation_image_start", "asset/1.png")
-     downloaded_image_path = download_image(validation_image_start)
-     validation_image_end = values.get("validation_image_end", None)
-
-     generator = torch.Generator(device="cuda").manual_seed(seed)
-     if lora_path is not None:
-         pipeline = merge_lora(pipeline, lora_path, lora_weight)
-
-     aspect_ratio_sample_size = {key: [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
-     start_img = Image.open(downloaded_image_path)
-     original_width, original_height = start_img.size
-     closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
-     height, width = [int(x / 16) * 16 for x in closest_size]
-     sample_size = [height, width]
-     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
-     input_video, input_video_mask, clip_image = get_image_to_video_latent(downloaded_image_path, validation_image_end, video_length=video_length, sample_size=sample_size)
-
-     with torch.no_grad():
-         sample = pipeline(prompt=prompt, num_frames=video_length, negative_prompt=negative_prompt, height=sample_size[0], width=sample_size[1], generator=generator, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, video=input_video, mask_video=input_video_mask).videos
-
-     if not os.path.exists(save_path):
-         os.makedirs(save_path, exist_ok=True)
-
-     index = len([path for path in os.listdir(save_path)]) + 1
-     prefix = str(index).zfill(8)
-     video_path = os.path.join(save_path, f"{prefix}.mp4")
-     save_videos_grid(sample, video_path, fps=fps)
-
-
-     hf_api = HfApi()
-     repo_id = "meepmoo/h4h4jejdf"  # Set your HF repo
-     hf_api.upload_file(
-         path_or_fileobj=video_path,
-         path_in_repo=f"{prefix}.mp4",
-         repo_id=repo_id,
-         token=tokenxf,
-         repo_type="model"
-     )
-
-
-     result_url = f"https://huggingface.co/{repo_id}/blob/main/{prefix}.mp4"
-     result_url = ""
-     job_id = values.get("job_id", "default-job-id")  # For RunPod job tracking
-     return {"jobId": job_id, "result": result_url, "status": "DONE"}
 
  runpod.serverless.start({"handler": generate})

+ import os, json, requests, random, runpod
+
  import torch
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
  from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+ from diffusers.utils import export_to_video, load_image
+ from transformers import T5EncoderModel, T5Tokenizer
 
+ with torch.inference_mode():
+     model_id = "/content/model"
+     transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
+     text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16)
+     vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
+     tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
+     pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16).to("cuda")
+     # pipe.enable_model_cpu_offload()
+
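+ # The block above runs once, at worker start-up: cold starts pay the model-load
+ # cost, and warm invocations reuse the fp16 pipeline already resident on the GPU.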
+ def download_file(url, save_dir, file_name):
+     os.makedirs(save_dir, exist_ok=True)
+     original_file_name = url.split('/')[-1]
+     _, original_file_extension = os.path.splitext(original_file_name)
+     file_path = os.path.join(save_dir, file_name + original_file_extension)
+     response = requests.get(url)
+     response.raise_for_status()
+     with open(file_path, 'wb') as file:
+         file.write(response.content)
+     return file_path
 
  @torch.inference_mode()
  def generate(input):
      values = input["input"]
+     global pipe  # merge_lora below rebinds the module-level pipeline; without this, `pipe` is an unbound local
+     lora_path = "/content/shirtlift.safetensors"
+     lora_weight = 1.0
+     pipe = merge_lora(pipe, lora_path, lora_weight)
+     input_image = values['input_image_check']
+     input_image = download_file(url=input_image, save_dir='/content/input', file_name='input_image_tost')
+     prompt = values['prompt']
+     # guidance_scale = values['guidance_scale']
+     # use_dynamic_cfg = values['use_dynamic_cfg']
+     # num_inference_steps = values['num_inference_steps']
+     # fps = values['fps']
+     guidance_scale = 6
+     use_dynamic_cfg = True
+     num_inference_steps = 50
+     fps = 8
+
+     image = load_image(input_image)
+     video = pipe(image=image, prompt=prompt, guidance_scale=guidance_scale, use_dynamic_cfg=use_dynamic_cfg, num_inference_steps=num_inference_steps).frames[0]
+     export_to_video(video, "/content/cogvideox_5b_i2v_tost.mp4", fps=fps)
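+     # merge_lora mutates the long-lived pipeline, so back-to-back jobs on a warm
+     # worker would stack the LoRA weights; unmerge once the frames are rendered
+     # (assumes unmerge_lora mirrors merge_lora's signature, as imported above).
+     pipe = unmerge_lora(pipe, lora_path, lora_weight)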
+
+     result = "/content/cogvideox_5b_i2v_tost.mp4"
+     try:
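+         # Sentinel values ("notify_uri", "discord_id", ...) mark fields the request did
+         # not supply; the handler then falls back to the com_camenduru_* environment variables.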
+         notify_uri = values['notify_uri']
+         del values['notify_uri']
+         notify_token = values['notify_token']
+         del values['notify_token']
+         discord_id = values['discord_id']
+         del values['discord_id']
+         if(discord_id == "discord_id"):
+             discord_id = os.getenv('com_camenduru_discord_id')
+         discord_channel = values['discord_channel']
+         del values['discord_channel']
+         if(discord_channel == "discord_channel"):
+             discord_channel = os.getenv('com_camenduru_discord_channel')
+         discord_token = values['discord_token']
+         del values['discord_token']
+         if(discord_token == "discord_token"):
+             discord_token = os.getenv('com_camenduru_discord_token')
+         job_id = values['job_id']
+         del values['job_id']
+         default_filename = os.path.basename(result)
+         with open(result, "rb") as file:
+             files = {default_filename: file.read()}
+         payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
+         response = requests.post(
+             f"https://discord.com/api/v9/channels/{discord_channel}/messages",
+             data=payload,
+             headers={"Authorization": f"Bot {discord_token}"},
+             files=files
+         )
+         response.raise_for_status()
+         result_url = response.json()['attachments'][0]['url']
+         notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
+         web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
+         web_notify_token = os.getenv('com_camenduru_web_notify_token')
+         if(notify_uri == "notify_uri"):
+             requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
+         else:
+             requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
+             requests.post(notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
+         return {"jobId": job_id, "result": result_url, "status": "DONE"}
+     except Exception as e:
+         error_payload = {"jobId": job_id, "status": "FAILED"}
+         try:
+             if(notify_uri == "notify_uri"):
+                 requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
+             else:
+                 requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
+                 requests.post(notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
+         except:
+             pass
+         return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
+     finally:
+         if os.path.exists(result):
+             os.remove(result)
+         if os.path.exists(input_image):
+             os.remove(input_image)
 
  runpod.serverless.start({"handler": generate})
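
Once deployed, the new handler can be smoke-tested end to end through RunPod's /runsync queue API. A minimal sketch (the endpoint id, image URL, and prompt are placeholders, and the sentinel strings exercise the env-var fallback the handler implements above):

import os, requests

# Hypothetical endpoint id; the payload mirrors every key generate() reads.
url = "https://api.runpod.ai/v2/<endpoint_id>/runsync"
payload = {
    "input": {
        "input_image_check": "https://example.com/start_frame.png",
        "prompt": "a cinematic close-up, natural lighting",
        "notify_uri": "notify_uri",          # sentinel: worker falls back to its env vars
        "notify_token": "notify_token",
        "discord_id": "discord_id",
        "discord_channel": "discord_channel",
        "discord_token": "discord_token",
        "job_id": "smoke-test-001",
    }
}
response = requests.post(url, headers={"Authorization": f"Bearer {os.environ['RUNPOD_API_KEY']}"}, json=payload)
# For /runsync, the handler's {"jobId", "result", "status"} dict comes back under "output".
print(response.json())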