pan-yl commited on
Commit
6cad0b7
β€’
1 Parent(s): 18aff81
Files changed (5) hide show
  1. .gitignore +20 -0
  2. app.py +6 -8
  3. config/chatbot_ui.yaml +3 -1
  4. infer.py +38 -7
  5. requirements.txt +4 -2
.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ *.pth
3
+ *.pt
4
+ *.pkl
5
+ *.ckpt
6
+ *.DS_Store
7
+ *__pycache__*
8
+ *.cache*
9
+ *.bin
10
+ *.idea
11
+ *.csv
12
+ cache
13
+ build
14
+ dist
15
+ dev
16
+ scepter.egg-info
17
+ .readthedocs.yml
18
+ *resources
19
+ *.ipynb_checkpoints*
20
+ *.vscode
app.py CHANGED
@@ -27,9 +27,9 @@ from scepter.modules.utils.directory import get_md5
27
  from scepter.modules.utils.file_system import FS
28
  from scepter.studio.utils.env import init_env
29
 
30
- from .infer import ACEInference
31
- from .example import get_examples
32
- from .utils import load_image
33
 
34
 
35
  refresh_sty = '\U0001f504' # πŸ”„
@@ -44,10 +44,9 @@ lock = threading.Lock()
44
 
45
  class ChatBotUI(object):
46
  def __init__(self,
47
- cfg_general_file,
48
  root_work_dir='./'):
49
 
50
- cfg = Config(cfg_file=cfg_general_file)
51
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
52
  if not FS.exists(cfg.WORK_DIR):
53
  FS.make_dir(cfg.WORK_DIR)
@@ -1189,11 +1188,10 @@ class ChatBotUI(object):
1189
 
1190
 
1191
  if __name__ == '__main__':
1192
- cfg = Config(cfg_file="config/chatbot_ui.yaml")
1193
 
 
1194
  with gr.Blocks() as demo:
1195
  chatbot = ChatBotUI(cfg)
1196
- chatbot.create_bot_ui()
1197
  chatbot.set_callbacks()
1198
-
1199
  demo.launch()
 
27
  from scepter.modules.utils.file_system import FS
28
  from scepter.studio.utils.env import init_env
29
 
30
+ from infer import ACEInference
31
+ from example import get_examples
32
+ from utils import load_image
33
 
34
 
35
  refresh_sty = '\U0001f504' # πŸ”„
 
44
 
45
  class ChatBotUI(object):
46
  def __init__(self,
47
+ cfg,
48
  root_work_dir='./'):
49
 
 
50
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
51
  if not FS.exists(cfg.WORK_DIR):
52
  FS.make_dir(cfg.WORK_DIR)
 
1188
 
1189
 
1190
  if __name__ == '__main__':
 
1191
 
1192
+ cfg = Config(cfg_file="config/chatbot_ui.yaml")
1193
  with gr.Blocks() as demo:
1194
  chatbot = ChatBotUI(cfg)
1195
+ chatbot.create_ui()
1196
  chatbot.set_callbacks()
 
1197
  demo.launch()
config/chatbot_ui.yaml CHANGED
@@ -2,10 +2,12 @@ WORK_DIR: ./cache/chatbot
2
  FILE_SYSTEM:
3
  - NAME: LocalFs
4
  TEMP_DIR: ./cache
5
- - NAME: ModelscopeFs
6
  TEMP_DIR: ./cache
7
  - NAME: HuggingfaceFs
8
  TEMP_DIR: ./cache
 
 
9
  #
10
  ENABLE_I2V: False
11
  #
 
2
  FILE_SYSTEM:
3
  - NAME: LocalFs
4
  TEMP_DIR: ./cache
5
+ - NAME: HttpFs
6
  TEMP_DIR: ./cache
7
  - NAME: HuggingfaceFs
8
  TEMP_DIR: ./cache
9
+ - NAME: ModelscopeFs
10
+ TEMP_DIR: ./cache
11
  #
12
  ENABLE_I2V: False
13
  #
infer.py CHANGED
@@ -12,17 +12,48 @@ import torch.nn.functional as F
12
  import torchvision.transforms.functional as TF
13
 
14
  from scepter.modules.model.registry import DIFFUSIONS
15
- from scepter.modules.model.utils.basic_utils import (
16
- check_list_of_list,
17
- pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
18
- to_device,
19
- unpack_tensor_into_imagelist
20
- )
21
  from scepter.modules.utils.distribute import we
22
  from scepter.modules.utils.logger import get_logger
23
-
24
  from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def process_edit_image(images,
28
  masks,
 
12
  import torchvision.transforms.functional as TF
13
 
14
  from scepter.modules.model.registry import DIFFUSIONS
 
 
 
 
 
 
15
  from scepter.modules.utils.distribute import we
16
  from scepter.modules.utils.logger import get_logger
 
17
  from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
18
 
19
+ def check_list_of_list(ll):
20
+ return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
21
+
22
+ def pack_imagelist_into_tensor(image_list):
23
+ # allow None
24
+ example = None
25
+ image_tensor, shapes = [], []
26
+ for img in image_list:
27
+ if img is None:
28
+ example = find_example(image_tensor,
29
+ image_list) if example is None else example
30
+ image_tensor.append(example)
31
+ shapes.append(None)
32
+ continue
33
+ _, c, h, w = img.size()
34
+ image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c
35
+ shapes.append((h, w))
36
+
37
+ image_tensor = pad_sequence(image_tensor,
38
+ batch_first=True).permute(0, 2, 1) # b, c, l
39
+ return image_tensor, shapes
40
+
41
+ def to_device(inputs, strict=True):
42
+ if inputs is None:
43
+ return None
44
+ if strict:
45
+ assert all(isinstance(i, torch.Tensor) for i in inputs)
46
+ return [i.to(we.device_id) if i is not None else None for i in inputs]
47
+
48
+
49
+ def unpack_tensor_into_imagelist(image_tensor, shapes):
50
+ image_list = []
51
+ for img, shape in zip(image_tensor, shapes):
52
+ h, w = shape[0], shape[1]
53
+ image_list.append(img[:, :h * w].view(1, -1, h, w))
54
+
55
+ return image_list
56
+
57
 
58
  def process_edit_image(images,
59
  masks,
requirements.txt CHANGED
@@ -1,5 +1,7 @@
1
- huggingface_hub==0.25.2
2
  scepter>=1.1.0
3
  diffusers
4
  transformers
5
- gradio_imageslider
 
 
 
1
+ huggingface_hub
2
  scepter>=1.1.0
3
  diffusers
4
  transformers
5
+ gradio_imageslider
6
+ torch>=2.4.0
7
+ torchvision