jayparmr commited on
Commit
fd5252e
·
1 Parent(s): 049a85c

Upload folder using huggingface_hub

Browse files
handler.py CHANGED
@@ -4,8 +4,8 @@ from pathlib import Path
4
  from typing import Any, Dict, List
5
 
6
  from inference import model_fn, predict_fn
7
- from internals.util.config import set_hf_cache_dir
8
- from internals.util.model_downloader import BaseModelDownloader
9
 
10
 
11
  class EndpointHandler:
@@ -13,27 +13,6 @@ class EndpointHandler:
13
  set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
14
  self.model_dir = path
15
 
16
- if os.path.exists(path + "/inference.json"):
17
- with open(path + "/inference.json", "r") as f:
18
- config = json.loads(f.read())
19
- if config.get("model_type") == "huggingface":
20
- self.model_dir = config["model_path"]
21
- if config.get("model_type") == "s3":
22
- s3_config = config["model_path"]["s3"]
23
- base_url = s3_config["base_url"]
24
-
25
- urls = [base_url + item for item in s3_config["paths"]]
26
- out_dir = Path.home() / ".cache" / "base_model"
27
- if out_dir.exists():
28
- print("Model already exist")
29
- else:
30
- print("Downloading model")
31
- BaseModelDownloader(
32
- urls, s3_config["paths"], out_dir
33
- ).download()
34
-
35
- self.model_dir = str(out_dir)
36
-
37
  return model_fn(self.model_dir)
38
 
39
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
 
4
  from typing import Any, Dict, List
5
 
6
  from inference import model_fn, predict_fn
7
+ from internals.util.config import set_hf_cache_dir, set_model_config
8
+ from internals.util.model_loader import load_model_from_config
9
 
10
 
11
  class EndpointHandler:
 
13
  set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
14
  self.model_dir = path
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return model_fn(self.model_dir)
17
 
18
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
inference.py CHANGED
@@ -21,10 +21,11 @@ from internals.util.avatar import Avatar
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
  from internals.util.commons import download_image, upload_image, upload_images
23
  from internals.util.config import (get_model_dir, num_return_sequences,
24
- set_configs_from_task, set_model_dir,
25
  set_root_dir)
26
  from internals.util.failure_hander import FailureHandler
27
  from internals.util.lora_style import LoraStyle
 
28
  from internals.util.slack import Slack
29
 
30
  torch.backends.cudnn.benchmark = True
@@ -496,13 +497,14 @@ def load_model_by_task(task: Task):
496
  ):
497
  text2img_pipe.load(get_model_dir())
498
  img2img_pipe.create(text2img_pipe)
499
- inpainter.create(text2img_pipe)
500
  high_res.load(img2img_pipe)
501
 
502
  safety_checker.apply(text2img_pipe)
503
  safety_checker.apply(img2img_pipe)
 
504
  elif task.get_type() == TaskType.REPLACE_BG:
505
- replace_background.load(controlnet=controlnet, high_res=high_res)
506
  else:
507
  if task.get_type() == TaskType.TILE_UPSCALE:
508
  controlnet.load_tile_upscaler()
@@ -521,7 +523,8 @@ def load_model_by_task(task: Task):
521
  def model_fn(model_dir):
522
  print("Logs: model loaded .... starts")
523
 
524
- set_model_dir(model_dir)
 
525
  set_root_dir(__file__)
526
 
527
  FailureHandler.register()
 
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
  from internals.util.commons import download_image, upload_image, upload_images
23
  from internals.util.config import (get_model_dir, num_return_sequences,
24
+ set_configs_from_task, set_model_config,
25
  set_root_dir)
26
  from internals.util.failure_hander import FailureHandler
27
  from internals.util.lora_style import LoraStyle
28
+ from internals.util.model_loader import load_model_from_config
29
  from internals.util.slack import Slack
30
 
31
  torch.backends.cudnn.benchmark = True
 
497
  ):
498
  text2img_pipe.load(get_model_dir())
499
  img2img_pipe.create(text2img_pipe)
500
+ inpainter.load()
501
  high_res.load(img2img_pipe)
502
 
503
  safety_checker.apply(text2img_pipe)
504
  safety_checker.apply(img2img_pipe)
505
+ safety_checker.apply(inpainter)
506
  elif task.get_type() == TaskType.REPLACE_BG:
507
+ replace_background.load(inpainter=inpainter, high_res=high_res)
508
  else:
509
  if task.get_type() == TaskType.TILE_UPSCALE:
510
  controlnet.load_tile_upscaler()
 
523
  def model_fn(model_dir):
524
  print("Logs: model loaded .... starts")
525
 
526
+ config = load_model_from_config(model_dir)
527
+ set_model_config(config)
528
  set_root_dir(__file__)
529
 
530
  FailureHandler.register()
inference2.py CHANGED
@@ -23,9 +23,10 @@ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
23
  from internals.util.commons import (construct_default_s3_url, upload_image,
24
  upload_images)
25
  from internals.util.config import (num_return_sequences, set_configs_from_task,
26
- set_model_dir, set_root_dir)
27
  from internals.util.failure_hander import FailureHandler
28
  from internals.util.lora_style import LoraStyle
 
29
  from internals.util.slack import Slack
30
 
31
  torch.backends.cudnn.benchmark = True
@@ -214,7 +215,8 @@ def upscale_image(task: Task):
214
  def model_fn(model_dir):
215
  print("Logs: model loaded .... starts")
216
 
217
- set_model_dir(model_dir)
 
218
  set_root_dir(__file__)
219
 
220
  FailureHandler.register()
 
23
  from internals.util.commons import (construct_default_s3_url, upload_image,
24
  upload_images)
25
  from internals.util.config import (num_return_sequences, set_configs_from_task,
26
+ set_model_config, set_root_dir)
27
  from internals.util.failure_hander import FailureHandler
28
  from internals.util.lora_style import LoraStyle
29
+ from internals.util.model_loader import load_model_from_config
30
  from internals.util.slack import Slack
31
 
32
  torch.backends.cudnn.benchmark = True
 
215
  def model_fn(model_dir):
216
  print("Logs: model loaded .... starts")
217
 
218
+ config = load_model_from_config(model_dir)
219
+ set_model_config(config)
220
  set_root_dir(__file__)
221
 
222
  FailureHandler.register()
internals/pipelines/controlnets.py CHANGED
@@ -4,15 +4,11 @@ import cv2
4
  import numpy as np
5
  import torch
6
  from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
7
- from diffusers import (
8
- ControlNetModel,
9
- DiffusionPipeline,
10
- StableDiffusionControlNetPipeline,
11
- UniPCMultistepScheduler,
12
- )
13
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
14
- MultiControlNetModel,
15
- )
16
  from PIL import Image
17
  from torch.nn import Linear
18
  from tqdm import gui
@@ -22,9 +18,8 @@ import internals.util.image as ImageUtil
22
  from external.midas import apply_midas
23
  from internals.data.result import Result
24
  from internals.pipelines.commons import AbstractPipeline
25
- from internals.pipelines.tileUpscalePipeline import (
26
- StableDiffusionControlNetImg2ImgPipeline,
27
- )
28
  from internals.util.cache import clear_cuda_and_gc
29
  from internals.util.commons import download_image
30
  from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
@@ -86,7 +81,7 @@ class ControlNet(AbstractPipeline):
86
  if self.__current_task_name == "pose":
87
  return
88
  pose = ControlNetModel.from_pretrained(
89
- "lllyasviel/sd-controlnet-openpose",
90
  torch_dtype=torch.float16,
91
  cache_dir=get_hf_cache_dir(),
92
  ).to("cuda")
 
4
  import numpy as np
5
  import torch
6
  from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
7
+ from diffusers import (ControlNetModel, DiffusionPipeline,
8
+ StableDiffusionControlNetPipeline,
9
+ UniPCMultistepScheduler)
10
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import \
11
+ MultiControlNetModel
 
 
 
 
12
  from PIL import Image
13
  from torch.nn import Linear
14
  from tqdm import gui
 
18
  from external.midas import apply_midas
19
  from internals.data.result import Result
20
  from internals.pipelines.commons import AbstractPipeline
21
+ from internals.pipelines.tileUpscalePipeline import \
22
+ StableDiffusionControlNetImg2ImgPipeline
 
23
  from internals.util.cache import clear_cuda_and_gc
24
  from internals.util.commons import download_image
25
  from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
 
81
  if self.__current_task_name == "pose":
82
  return
83
  pose = ControlNetModel.from_pretrained(
84
+ "lllyasviel/control_v11p_sd15_openpose",
85
  torch_dtype=torch.float16,
86
  cache_dir=get_hf_cache_dir(),
87
  ).to("cuda")
internals/pipelines/inpainter.py CHANGED
@@ -5,18 +5,28 @@ from diffusers import StableDiffusionInpaintPipeline
5
 
6
  from internals.pipelines.commons import AbstractPipeline
7
  from internals.util.commons import disable_safety_checker, download_image
8
- from internals.util.config import get_hf_cache_dir
 
9
 
10
 
11
  class InPainter(AbstractPipeline):
 
 
12
  def load(self):
 
 
 
13
  self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
14
- "jayparmr/icbinp_v8_inpaint_v2",
15
  torch_dtype=torch.float16,
16
  cache_dir=get_hf_cache_dir(),
 
17
  ).to("cuda")
 
18
  disable_safety_checker(self.pipe)
19
 
 
 
20
  def create(self, pipeline: AbstractPipeline):
21
  self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
22
  "cuda"
 
5
 
6
  from internals.pipelines.commons import AbstractPipeline
7
  from internals.util.commons import disable_safety_checker, download_image
8
+ from internals.util.config import (get_hf_cache_dir, get_hf_token,
9
+ get_inpaint_model_path)
10
 
11
 
12
  class InPainter(AbstractPipeline):
13
+ __loaded = False
14
+
15
  def load(self):
16
+ if self.__loaded:
17
+ return
18
+
19
  self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
20
+ get_inpaint_model_path(),
21
  torch_dtype=torch.float16,
22
  cache_dir=get_hf_cache_dir(),
23
+ use_auth_token=get_hf_token(),
24
  ).to("cuda")
25
+
26
  disable_safety_checker(self.pipe)
27
 
28
+ self.__loaded = True
29
+
30
  def create(self, pipeline: AbstractPipeline):
31
  self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
32
  "cuda"
internals/pipelines/replace_background.py CHANGED
@@ -2,6 +2,7 @@ from io import BytesIO
2
  from typing import List, Optional, Union
3
 
4
  import torch
 
5
  from diffusers import (ControlNetModel,
6
  StableDiffusionControlNetInpaintPipeline,
7
  StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
@@ -12,10 +13,12 @@ from internals.data.result import Result
12
  from internals.pipelines.commons import AbstractPipeline
13
  from internals.pipelines.controlnets import ControlNet
14
  from internals.pipelines.high_res import HighRes
 
15
  from internals.pipelines.remove_background import RemoveBackgroundV2
16
  from internals.pipelines.upscaler import Upscaler
17
  from internals.util.commons import download_image
18
- from internals.util.config import get_hf_cache_dir, get_model_dir
 
19
 
20
 
21
  class ReplaceBackground(AbstractPipeline):
@@ -25,7 +28,7 @@ class ReplaceBackground(AbstractPipeline):
25
  self,
26
  upscaler: Optional[Upscaler] = None,
27
  remove_background: Optional[RemoveBackgroundV2] = None,
28
- controlnet: Optional[ControlNet] = None,
29
  high_res: Optional[HighRes] = None,
30
  ):
31
  if self.__loaded:
@@ -35,18 +38,19 @@ class ReplaceBackground(AbstractPipeline):
35
  torch_dtype=torch.float16,
36
  cache_dir=get_hf_cache_dir(),
37
  ).to("cuda")
38
- if controlnet:
39
- controlnet.load_linearart()
40
  pipe = StableDiffusionControlNetInpaintPipeline(
41
- **controlnet.pipe.components
 
42
  )
43
- pipe.controlnet = controlnet_model
44
  else:
45
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
46
  "runwayml/stable-diffusion-inpainting",
47
  controlnet=controlnet_model,
48
  torch_dtype=torch.float16,
49
  cache_dir=get_hf_cache_dir(),
 
50
  )
51
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
52
  pipe.to("cuda")
@@ -104,14 +108,14 @@ class ReplaceBackground(AbstractPipeline):
104
 
105
  print(width, height, n_width, n_height)
106
 
 
107
  if extend_object:
108
- condition_image = ControlNet.linearart_condition_image(image).resize(
109
- (n_width, n_height)
110
- )
111
  condition_image = ImageUtil.padd_image(condition_image, width, height)
112
  condition_image = condition_image.convert("RGB")
113
 
114
- image = image.resize((n_width, n_height))
115
  image = ImageUtil.padd_image(image, width, height)
116
 
117
  mask = image.copy()
@@ -130,46 +134,20 @@ class ReplaceBackground(AbstractPipeline):
130
  condition_image = ControlNet.linearart_condition_image(image)
131
  mask = mask.convert("RGB")
132
 
133
- if apply_high_res and hasattr(self, "high_res"):
134
- w, h = HighRes.get_intermediate_dimension(width, height)
135
- result = self.pipe.__call__(
136
- prompt=prompt,
137
- negative_prompt=negative_prompt,
138
- image=image,
139
- mask_image=mask,
140
- control_image=condition_image,
141
- controlnet_conditioning_scale=conditioning_scale,
142
- guidance_scale=9,
143
- strength=1,
144
- num_inference_steps=steps,
145
- height=w,
146
- width=h,
147
- )
148
- for i, _ in enumerate(result.images):
149
- out_bytes = self.upscaler.upscale(
150
- image=result.images[i],
151
- width=w,
152
- height=h,
153
- face_enhance=False,
154
- resize_dimension=max(width, height),
155
- )
156
- result.images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")
157
- result = Result.from_result(result)
158
- else:
159
- result = self.pipe.__call__(
160
- prompt=prompt,
161
- negative_prompt=negative_prompt,
162
- image=image,
163
- mask_image=mask,
164
- control_image=condition_image,
165
- controlnet_conditioning_scale=conditioning_scale,
166
- guidance_scale=9,
167
- strength=1,
168
- height=height,
169
- num_inference_steps=steps,
170
- width=width,
171
- )
172
- result = Result.from_result(result)
173
 
174
  images, has_nsfw = result
175
 
 
2
  from typing import List, Optional, Union
3
 
4
  import torch
5
+ from cv2 import inpaint
6
  from diffusers import (ControlNetModel,
7
  StableDiffusionControlNetInpaintPipeline,
8
  StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
 
13
  from internals.pipelines.commons import AbstractPipeline
14
  from internals.pipelines.controlnets import ControlNet
15
  from internals.pipelines.high_res import HighRes
16
+ from internals.pipelines.inpainter import InPainter
17
  from internals.pipelines.remove_background import RemoveBackgroundV2
18
  from internals.pipelines.upscaler import Upscaler
19
  from internals.util.commons import download_image
20
+ from internals.util.config import (get_hf_cache_dir, get_hf_token,
21
+ get_inpaint_model_path, get_model_dir)
22
 
23
 
24
  class ReplaceBackground(AbstractPipeline):
 
28
  self,
29
  upscaler: Optional[Upscaler] = None,
30
  remove_background: Optional[RemoveBackgroundV2] = None,
31
+ inpainter: Optional[InPainter] = None,
32
  high_res: Optional[HighRes] = None,
33
  ):
34
  if self.__loaded:
 
38
  torch_dtype=torch.float16,
39
  cache_dir=get_hf_cache_dir(),
40
  ).to("cuda")
41
+ if inpainter:
42
+ inpainter.load()
43
  pipe = StableDiffusionControlNetInpaintPipeline(
44
+ **inpainter.pipe.components,
45
+ controlnet=controlnet_model,
46
  )
 
47
  else:
48
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
49
  "runwayml/stable-diffusion-inpainting",
50
  controlnet=controlnet_model,
51
  torch_dtype=torch.float16,
52
  cache_dir=get_hf_cache_dir(),
53
+ use_auth_token=get_hf_token(),
54
  )
55
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
56
  pipe.to("cuda")
 
108
 
109
  print(width, height, n_width, n_height)
110
 
111
+ resolution = min(n_width, n_height)
112
  if extend_object:
113
+ condition_image = ControlNet.linearart_condition_image(image)
114
+ condition_image = ImageUtil.resize_image(condition_image, resolution)
 
115
  condition_image = ImageUtil.padd_image(condition_image, width, height)
116
  condition_image = condition_image.convert("RGB")
117
 
118
+ image = ImageUtil.resize_image(image, resolution)
119
  image = ImageUtil.padd_image(image, width, height)
120
 
121
  mask = image.copy()
 
134
  condition_image = ControlNet.linearart_condition_image(image)
135
  mask = mask.convert("RGB")
136
 
137
+ result = self.pipe.__call__(
138
+ prompt=prompt,
139
+ negative_prompt=negative_prompt,
140
+ image=image,
141
+ mask_image=mask,
142
+ control_image=condition_image,
143
+ controlnet_conditioning_scale=conditioning_scale,
144
+ guidance_scale=9,
145
+ strength=1,
146
+ height=height,
147
+ num_inference_steps=steps,
148
+ width=width,
149
+ )
150
+ result = Result.from_result(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  images, has_nsfw = result
153
 
internals/util/config.py CHANGED
@@ -3,13 +3,14 @@ from pathlib import Path
3
  from typing import Union
4
 
5
  from internals.data.task import Task
 
6
 
7
  env = "prod"
8
  nsfw_threshold = 0.0
9
  nsfw_access = False
10
  access_token = ""
11
  root_dir = ""
12
- model_dir = ""
13
  hf_token = "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"
14
  hf_cache_dir = "/tmp/hf_hub"
15
 
@@ -28,16 +29,16 @@ def get_hf_cache_dir():
28
  return hf_cache_dir
29
 
30
 
31
- def set_model_dir(dir: str):
32
- global model_dir
33
- model_dir = dir
34
-
35
-
36
  def set_root_dir(main_file: str):
37
  global root_dir
38
  root_dir = os.path.dirname(os.path.abspath(main_file))
39
 
40
 
 
 
 
 
 
41
  def set_configs_from_task(task: Task):
42
  global env, nsfw_threshold, nsfw_access, access_token
43
  name = task.get_queue_name()
@@ -51,8 +52,13 @@ def set_configs_from_task(task: Task):
51
 
52
 
53
  def get_model_dir():
54
- global model_dir
55
- return model_dir
 
 
 
 
 
56
 
57
 
58
  def get_root_dir():
 
3
  from typing import Union
4
 
5
  from internals.data.task import Task
6
+ from internals.util.model_loader import ModelConfig
7
 
8
  env = "prod"
9
  nsfw_threshold = 0.0
10
  nsfw_access = False
11
  access_token = ""
12
  root_dir = ""
13
+ model_config = None
14
  hf_token = "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"
15
  hf_cache_dir = "/tmp/hf_hub"
16
 
 
29
  return hf_cache_dir
30
 
31
 
 
 
 
 
 
32
  def set_root_dir(main_file: str):
33
  global root_dir
34
  root_dir = os.path.dirname(os.path.abspath(main_file))
35
 
36
 
37
+ def set_model_config(config: ModelConfig):
38
+ global model_config
39
+ model_config = config
40
+
41
+
42
  def set_configs_from_task(task: Task):
43
  global env, nsfw_threshold, nsfw_access, access_token
44
  name = task.get_queue_name()
 
52
 
53
 
54
  def get_model_dir():
55
+ global model_config
56
+ return model_config.base_model_path # pyright: ignore
57
+
58
+
59
+ def get_inpaint_model_path():
60
+ global model_config
61
+ return model_config.base_inpaint_model_path # pyright: ignore
62
 
63
 
64
  def get_root_dir():
internals/util/model_loader.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from threading import Thread
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ import requests
10
+ from tqdm import tqdm
11
+
12
+
13
+ @dataclass
14
+ class ModelConfig:
15
+ base_model_path: str
16
+ base_inpaint_model_path: str
17
+
18
+
19
+ def load_model_from_config(path):
20
+ m_config = ModelConfig(path, path)
21
+ if os.path.exists(path + "/inference.json"):
22
+ with open(path + "/inference.json", "r") as f:
23
+ config = json.loads(f.read())
24
+ model_path = config.get("model_path", path)
25
+ inpaint_model_path = config.get("inpaint_model_path", path)
26
+
27
+ m_config.base_model_path = model_path
28
+ m_config.base_inpaint_model_path = inpaint_model_path
29
+
30
+ #
31
+ # if config.get("model_type") == "huggingface":
32
+ # model_dir = config["model_path"]
33
+ # if config.get("model_type") == "s3":
34
+ # s3_config = config["model_path"]["s3"]
35
+ # base_url = s3_config["base_url"]
36
+ #
37
+ # urls = [base_url + item for item in s3_config["paths"]]
38
+ # out_dir = Path.home() / ".cache" / "base_model"
39
+ # if out_dir.exists():
40
+ # print("Model already exist")
41
+ # else:
42
+ # print("Downloading model")
43
+ # BaseModelDownloader(urls, s3_config["paths"], out_dir).download()
44
+ # model_dir = str(out_dir)
45
+ return m_config
46
+
47
+
48
+ class BaseModelDownloader:
49
+ """
50
+ A utility for fast download of base model from S3 or any CDN served storage.
51
+ Works by downloading multiple files in parallel and dividing large files
52
+ into smaller chunks and combining them at the end.
53
+
54
+ Currently it uses multithreading (not multiprocessing) assuming GIL won't
55
+ interfere with network/disk IO.
56
+
57
+ Created by: KP
58
+ """
59
+
60
+ def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
61
+ self.urls = urls
62
+ self.url_paths = url_paths
63
+ shutil.rmtree(out_dir, ignore_errors=True)
64
+ out_dir.mkdir(parents=True, exist_ok=True)
65
+ self.out_dir = out_dir
66
+
67
+ def download(self):
68
+ threads = []
69
+ batch_urls = {}
70
+
71
+ for url, url_path in zip(self.urls, self.url_paths):
72
+ out_dir = self.out_dir / url_path
73
+ self.out_dir.parent.mkdir(parents=True, exist_ok=True)
74
+ if url.endswith(".bin"):
75
+ if "unet/" in url_path:
76
+ thread = Thread(
77
+ target=self.__download_parallel, args=(url, out_dir, 6)
78
+ )
79
+ thread.start()
80
+ threads.append(thread)
81
+ else:
82
+ thread = Thread(
83
+ target=self.__download_files, args=([url], [out_dir])
84
+ )
85
+ thread.start()
86
+ threads.append(thread)
87
+ pass
88
+ else:
89
+ batch_urls[url] = out_dir
90
+
91
+ if batch_urls:
92
+ thread = Thread(
93
+ target=self.__download_files,
94
+ args=(list(batch_urls.keys()), list(batch_urls.values())),
95
+ )
96
+ thread.start()
97
+ threads.append(thread)
98
+ pass
99
+
100
+ for thread in threads:
101
+ thread.join()
102
+
103
+ def __download_parallel(self, url, output_filename, num_parts=4):
104
+ response = requests.head(url)
105
+ total_size = int(response.headers.get("content-length", 0))
106
+ print("total_size", total_size)
107
+
108
+ chunk_size = total_size // num_parts
109
+ ranges = [
110
+ (i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
111
+ ]
112
+ ranges.append((ranges[-1][1] + 1, total_size))
113
+
114
+ print(ranges)
115
+
116
+ save_dir = Path.home() / ".cache" / "download_parts"
117
+ os.makedirs(save_dir, exist_ok=True)
118
+
119
+ threads = []
120
+ for i, (start, end) in enumerate(ranges):
121
+ thread = Thread(
122
+ target=self.__download_part, args=(url, start, end, i, save_dir)
123
+ )
124
+ thread.start()
125
+ threads.append(thread)
126
+
127
+ for thread in threads:
128
+ thread.join()
129
+
130
+ self.__combine_parts(save_dir, output_filename, num_parts)
131
+ os.rmdir(save_dir)
132
+
133
+ def __combine_parts(self, save_dir, output_filename, num_parts):
134
+ part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]
135
+
136
+ output_filename.parent.mkdir(parents=True, exist_ok=True)
137
+ with open(output_filename, "wb") as output_file:
138
+ for part_file in part_files:
139
+ print("combining: ", part_file)
140
+ with open(part_file, "rb") as part:
141
+ output_file.write(part.read())
142
+
143
+ out_file_size = output_file.tell()
144
+ print("out_file_size", out_file_size)
145
+
146
+ for part_file in part_files:
147
+ os.remove(part_file)
148
+
149
+ def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
150
+ headers = {"Range": f"bytes={start_byte}-{end_byte}"}
151
+ response = requests.get(url, headers=headers, stream=True)
152
+
153
+ part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
154
+ print("Downloading part: ", url, part_filename, end_byte - start_byte)
155
+
156
+ with open(part_filename, "wb") as part_file, tqdm(
157
+ desc=str(part_filename),
158
+ total=end_byte - start_byte,
159
+ unit="B",
160
+ unit_scale=True,
161
+ unit_divisor=1024,
162
+ ) as bar:
163
+ for chunk in response.iter_content(chunk_size=8192):
164
+ if chunk:
165
+ size = part_file.write(chunk)
166
+ bar.update(size)
167
+
168
+ return part_filename
169
+
170
+ def __download_files(self, urls, out_paths: List[Path]):
171
+ for url, out_path in zip(urls, out_paths):
172
+ out_path.parent.mkdir(parents=True, exist_ok=True)
173
+ with requests.get(url, stream=True) as r:
174
+ print("Downloading: ", url)
175
+ total_size = int(r.headers.get("content-length", 0))
176
+ chunk_size = 8192
177
+ r.raise_for_status()
178
+ with open(out_path, "wb") as f, tqdm(
179
+ desc=str(out_path),
180
+ total=total_size,
181
+ unit="B",
182
+ unit_scale=True,
183
+ unit_divisor=1024,
184
+ ) as bar:
185
+ for data in r.iter_content(chunk_size=chunk_size):
186
+ size = f.write(data)
187
+ bar.update(size)
requirements.txt CHANGED
@@ -5,7 +5,7 @@ fastapi==0.87.0
5
  Pillow==9.3.0
6
  redis==4.3.4
7
  requests==2.28.1
8
- transformers
9
  rembg==2.0.30
10
  gfpgan==1.3.8
11
  rembg==2.0.30
 
5
  Pillow==9.3.0
6
  redis==4.3.4
7
  requests==2.28.1
8
+ transformers==4.34.1
9
  rembg==2.0.30
10
  gfpgan==1.3.8
11
  rembg==2.0.30