Artiprocher committed
Commit e82be26
Parent: bd932b0

support new model

Files changed (3)
  1. LdmZhPipeline.py +0 -1036
  2. app.py +10 -19
  3. requirements.txt +1 -1
LdmZhPipeline.py DELETED
@@ -1,1036 +0,0 @@
- # coding=utf-8
-
- import importlib
- import inspect
- import os
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Union
- from collections import OrderedDict
-
- import numpy as np
- import torch
- from torch import nn
- import functools
-
- import diffusers
- import PIL
- from accelerate.utils.versions import is_torch_version
- from huggingface_hub import snapshot_download
- from packaging import version
- from PIL import Image
- from tqdm.auto import tqdm
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.dynamic_modules_utils import get_class_from_dynamic_module
- from diffusers.modeling_utils import ModelMixin
- from diffusers.hub_utils import http_user_agent
- from diffusers.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
- from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
- from diffusers.utils import (
-     CONFIG_NAME,
-     DIFFUSERS_CACHE,
-     ONNX_WEIGHTS_NAME,
-     WEIGHTS_NAME,
-     BaseOutput,
-     deprecate,
-     is_transformers_available,
-     logging,
- )
-
-
- if is_transformers_available():
-     import transformers
-     from transformers import PreTrainedModel
-
-
- INDEX_FILE = "diffusion_pytorch_model.bin"
- CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
- DUMMY_MODULES_FOLDER = "diffusers.utils"
-
-
- logger = logging.get_logger(__name__)
-
-
- LOADABLE_CLASSES = {
-     "diffusers": {
-         "ModelMixin": ["save_pretrained", "from_pretrained"],
-         "SchedulerMixin": ["save_config", "from_config"],
-         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
-         "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
-     },
-     "transformers": {
-         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
-         "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
-         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-         "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
-     },
-     "LdmZhPipeline": {
-         "WukongClipTextEncoder": ["save_pretrained", "from_pretrained"],
-         "ESRGAN": ["save_pretrained", "from_pretrained"],
-     },
- }
-
- ALL_IMPORTABLE_CLASSES = {}
- for library in LOADABLE_CLASSES:
-     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
-
-
- @dataclass
- class ImagePipelineOutput(BaseOutput):
-     """
-     Output class for image pipelines.
-
-     Args:
-         images (`List[PIL.Image.Image]` or `np.ndarray`)
-             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-     """
-
-     images: Union[List[PIL.Image.Image], np.ndarray]
-
-
- @dataclass
- class AudioPipelineOutput(BaseOutput):
-     """
-     Output class for audio pipelines.
-
-     Args:
-         audios (`np.ndarray`)
-             List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
-             denoised audio samples of the diffusion pipeline.
-     """
-
-     audios: np.ndarray
-
-
- class DiffusionPipeline(ConfigMixin):
-     r"""
-     Base class for all models.
-
-     [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
-     and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
-
-         - move all PyTorch modules to the device of your choice
-         - enabling/disabling the progress bar for the denoising iteration
-
-     Class attributes:
-
-         - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
-           components of the diffusion pipeline.
-     """
-     config_name = "model_index.json"
-
-     def register_modules(self, **kwargs):
-         # import it here to avoid circular import
-         from diffusers import pipelines
-
-         for name, module in kwargs.items():
-             # retrieve library
-             if module is None:
-                 register_dict = {name: (None, None)}
-             else:
-                 library = module.__module__.split(".")[0]
-
-                 # check if the module is a pipeline module
-                 pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
-                 path = module.__module__.split(".")
-                 is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
-
-                 # if library is not in LOADABLE_CLASSES, then it is a custom module.
-                 # Or if it's a pipeline module, then the module is inside the pipeline
-                 # folder so we set the library to module name.
-                 if library not in LOADABLE_CLASSES or is_pipeline_module:
-                     library = pipeline_dir
-
-                 # retrieve class_name
-                 class_name = module.__class__.__name__
-
-                 register_dict = {name: (library, class_name)}
-
-             # save model index config
-             self.register_to_config(**register_dict)
-
-             # set models
-             setattr(self, name, module)
-
-     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
-         """
-         Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
-         a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
-         method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
-
-         Arguments:
-             save_directory (`str` or `os.PathLike`):
-                 Directory to which to save. Will be created if it doesn't exist.
-         """
-         self.save_config(save_directory)
-
-         model_index_dict = dict(self.config)
-         model_index_dict.pop("_class_name")
-         model_index_dict.pop("_diffusers_version")
-         model_index_dict.pop("_module", None)
-
-         for pipeline_component_name in model_index_dict.keys():
-             sub_model = getattr(self, pipeline_component_name)
-             if sub_model is None:
-                 # edge case for saving a pipeline with safety_checker=None
-                 continue
-
-             model_cls = sub_model.__class__
-
-             save_method_name = None
-             # search for the model's base class in LOADABLE_CLASSES
-             for library_name, library_classes in LOADABLE_CLASSES.items():
-                 library = importlib.import_module(library_name)
-                 for base_class, save_load_methods in library_classes.items():
-                     class_candidate = getattr(library, base_class)
-                     if issubclass(model_cls, class_candidate):
-                         # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
-                         save_method_name = save_load_methods[0]
-                         break
-                 if save_method_name is not None:
-                     break
-
-             save_method = getattr(sub_model, save_method_name)
-             save_method(os.path.join(save_directory, pipeline_component_name))
-
-     def to(self, torch_device: Optional[Union[str, torch.device]] = None):
-         if torch_device is None:
-             return self
-
-         module_names, _ = self.extract_init_dict(dict(self.config))
-         for name in module_names.keys():
-             module = getattr(self, name)
-             if isinstance(module, torch.nn.Module):
-                 if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                     logger.warning(
-                         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
-                         " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
-                         " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
-                         " `float16` operations on those devices in PyTorch. Please remove the"
-                         " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
-                     )
-                 module.to(torch_device)
-         return self
-
-     @property
-     def device(self) -> torch.device:
-         r"""
-         Returns:
-             `torch.device`: The torch device on which the pipeline is located.
-         """
-         module_names, _ = self.extract_init_dict(dict(self.config))
-         for name in module_names.keys():
-             module = getattr(self, name)
-             if isinstance(module, torch.nn.Module):
-                 # if module.device == torch.device("meta"):
-                 #     return torch.device("cpu")
-                 return module.device
-         return torch.device("cpu")
-
-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
-         r"""
-         Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
-
-         The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
-
-         The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-         pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-         task.
-
-         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-         weights are discarded.
-
-         Parameters:
-             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
-                 Can be either:
-
-                     - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
-                       https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
-                       `CompVis/ldm-text2im-large-256`.
-                     - A path to a *directory* containing pipeline weights saved using
-                       [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
-             torch_dtype (`str` or `torch.dtype`, *optional*):
-                 Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                 will be automatically derived from the model's weights.
-             custom_pipeline (`str`, *optional*):
-
-                 <Tip warning={true}>
-
-                 This is an experimental feature and is likely to change in the future.
-
-                 </Tip>
-
-                 Can be either:
-
-                     - A string, the *repo id* of a custom pipeline hosted inside a model repo on
-                       https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
-                       like `hf-internal-testing/diffusers-dummy-pipeline`.
-
-                         <Tip>
-
-                         It is required that the model repo has a file, called `pipeline.py` that defines the custom
-                         pipeline.
-
-                         </Tip>
-
-                     - A string, the *file name* of a community pipeline hosted on GitHub under
-                       https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
-                       match exactly the file name without `.py` located under the above link, *e.g.*
-                       `clip_guided_stable_diffusion`.
-
-                         <Tip>
-
-                         Community pipelines are always loaded from the current `main` branch of GitHub.
-
-                         </Tip>
-
-                     - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
-
-                         <Tip>
-
-                         It is required that the directory has a file, called `pipeline.py` that defines the custom
-                         pipeline.
-
-                         </Tip>
-
-                 For more information on how to load and create custom pipelines, please have a look at [Loading and
-                 Creating Custom
-                 Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
-
-             torch_dtype (`str` or `torch.dtype`, *optional*):
-             force_download (`bool`, *optional*, defaults to `False`):
-                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
-                 cached versions if they exist.
-             resume_download (`bool`, *optional*, defaults to `False`):
-                 Whether or not to delete incompletely received files. Will attempt to resume the download if such a
-                 file exists.
-             proxies (`Dict[str, str]`, *optional*):
-                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-             output_loading_info(`bool`, *optional*, defaults to `False`):
-                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
-             local_files_only(`bool`, *optional*, defaults to `False`):
-                 Whether or not to only look at local files (i.e., do not try to download the model).
-             use_auth_token (`str` or *bool*, *optional*):
-                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                 when running `huggingface-cli login` (stored in `~/.huggingface`).
-             revision (`str`, *optional*, defaults to `"main"`):
-                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-                 identifier allowed by git.
-             mirror (`str`, *optional*):
-                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
-                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
-                 Please refer to the mirror site for more information. specify the folder name here.
-             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
-                 A map that specifies where each submodule should go. It doesn't need to be refined to each
-                 parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
-                 same device.
-
-                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
-                 more information about each option see [designing a device
-                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
-                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
-                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                 setting this argument to `True` will raise an error.
-
-             kwargs (remaining dictionary of keyword arguments, *optional*):
-                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
-                 specific pipeline class. The overwritten components are then directly passed to the pipelines
-                 `__init__` method. See example below for more information.
-
-         <Tip>
-
-         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
-         models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
-
-         </Tip>
-
-         <Tip>
-
-         Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-         this method in a firewalled environment.
-
-         </Tip>
-
-         Examples:
-
-         ```py
-         >>> from diffusers import DiffusionPipeline
-
-         >>> # Download pipeline from huggingface.co and cache.
-         >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-
-         >>> # Download pipeline that requires an authorization token
-         >>> # For more information on access tokens, please refer to this section
-         >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
-         >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
-
-         >>> # Download pipeline, but overwrite scheduler
-         >>> from diffusers import LMSDiscreteScheduler
-
-         >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
-         >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
-         ```
-         """
-         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-         resume_download = kwargs.pop("resume_download", False)
-         force_download = kwargs.pop("force_download", False)
-         proxies = kwargs.pop("proxies", None)
-         local_files_only = kwargs.pop("local_files_only", False)
-         use_auth_token = kwargs.pop("use_auth_token", None)
-         revision = kwargs.pop("revision", None)
-         torch_dtype = kwargs.pop("torch_dtype", None)
-         custom_pipeline = kwargs.pop("custom_pipeline", None)
-         provider = kwargs.pop("provider", None)
-         sess_options = kwargs.pop("sess_options", None)
-         device_map = kwargs.pop("device_map", None)
-         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
-         if device_map is not None and not is_torch_version(">=", "1.9.0"):
-             raise NotImplementedError(
-                 "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                 " `device_map=None`."
-             )
-
-         if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
-             raise NotImplementedError(
-                 "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                 " `low_cpu_mem_usage=False`."
-             )
-
-         if low_cpu_mem_usage is False and device_map is not None:
-             raise ValueError(
-                 f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
-                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
-             )
-
-         # 1. Download the checkpoints and configs
-         # use snapshot download here to get it working from from_pretrained
-         if not os.path.isdir(pretrained_model_name_or_path):
-             config_dict = cls.get_config_dict(
-                 pretrained_model_name_or_path,
-                 cache_dir=cache_dir,
-                 resume_download=resume_download,
-                 force_download=force_download,
-                 proxies=proxies,
-                 local_files_only=local_files_only,
-                 use_auth_token=use_auth_token,
-                 revision=revision,
-             )
-             # make sure we only download sub-folders and `diffusers` filenames
-             folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
-             allow_patterns = [os.path.join(k, "*") for k in folder_names]
-             allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
-
-             # make sure we don't download flax weights
-             ignore_patterns = "*.msgpack"
-
-             if custom_pipeline is not None:
-                 allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
-
-             if cls != DiffusionPipeline:
-                 requested_pipeline_class = cls.__name__
-             else:
-                 requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
-             user_agent = {"pipeline_class": requested_pipeline_class}
-             if custom_pipeline is not None:
-                 user_agent["custom_pipeline"] = custom_pipeline
-             user_agent = http_user_agent(user_agent)
-
-             # download all allow_patterns
-             cached_folder = snapshot_download(
-                 pretrained_model_name_or_path,
-                 cache_dir=cache_dir,
-                 resume_download=resume_download,
-                 proxies=proxies,
-                 local_files_only=local_files_only,
-                 use_auth_token=use_auth_token,
-                 revision=revision,
-                 allow_patterns=allow_patterns,
-                 ignore_patterns=ignore_patterns,
-                 user_agent=user_agent,
-             )
-         else:
-             cached_folder = pretrained_model_name_or_path
-
-         config_dict = cls.get_config_dict(cached_folder)
-
-         # 2. Load the pipeline class, if using custom module then load it from the hub
-         # if we load from explicit class, let's use it
-         if custom_pipeline is not None:
-             pipeline_class = get_class_from_dynamic_module(
-                 custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
-             )
-         elif cls != DiffusionPipeline:
-             pipeline_class = cls
-         else:
-             diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-             pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
-
-         # To be removed in 1.0.0
-         if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
-             version.parse(config_dict["_diffusers_version"]).base_version
-         ) <= version.parse("0.5.1"):
-             from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
-
-             pipeline_class = StableDiffusionInpaintPipelineLegacy
-
-             deprecation_message = (
-                 "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-                 f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
-                 " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
-                 " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
-                 f" checkpoint {pretrained_model_name_or_path} to the format of"
-                 " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
-                 " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
-             )
-             deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
-
-         # some modules can be passed directly to the init
-         # in this case they are already instantiated in `kwargs`
-         # extract them here
-         expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
-         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
-
-         init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
-
-         if len(unused_kwargs) > 0:
-             logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
-
-         init_kwargs = {}
-
-         # import it here to avoid circular import
-         from diffusers import pipelines
-
-         # 3. Load each module in the pipeline
-         for name, (library_name, class_name) in init_dict.items():
-             if class_name is None:
-                 # edge case for when the pipeline was saved with safety_checker=None
-                 init_kwargs[name] = None
-                 continue
-
-             # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
-             if class_name.startswith("Flax"):
-                 class_name = class_name[4:]
-
-             is_pipeline_module = hasattr(pipelines, library_name)
-             loaded_sub_model = None
-             sub_model_should_be_defined = True
-
-             # if the model is in a pipeline module, then we load it from the pipeline
-             if name in passed_class_obj:
-                 # 1. check that passed_class_obj has correct parent class
-                 if not is_pipeline_module:
-                     library = importlib.import_module(library_name)
-                     class_obj = getattr(library, class_name)
-                     importable_classes = LOADABLE_CLASSES[library_name]
-                     class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
-
-                     expected_class_obj = None
-                     for class_name, class_candidate in class_candidates.items():
-                         if issubclass(class_obj, class_candidate):
-                             expected_class_obj = class_candidate
-
-                     if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
-                         raise ValueError(
-                             f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
-                             f" {expected_class_obj}"
-                         )
-                 elif passed_class_obj[name] is None:
-                     logger.warn(
-                         f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
-                         f" that this might lead to problems when using {pipeline_class} and is not recommended."
-                     )
-                     sub_model_should_be_defined = False
-                 else:
-                     logger.warn(
-                         f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
-                         " has the correct type"
-                     )
-
-                 # set passed class object
-                 loaded_sub_model = passed_class_obj[name]
-             elif is_pipeline_module:
-                 pipeline_module = getattr(pipelines, library_name)
-                 class_obj = getattr(pipeline_module, class_name)
-                 importable_classes = ALL_IMPORTABLE_CLASSES
-                 class_candidates = {c: class_obj for c in importable_classes.keys()}
-             else:
-                 # else we just import it from the library.
-                 library = importlib.import_module(library_name)
-                 class_obj = getattr(library, class_name)
-                 importable_classes = LOADABLE_CLASSES[library_name]
-                 class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
-
-             if loaded_sub_model is None and sub_model_should_be_defined:
-                 load_method_name = None
-                 for class_name, class_candidate in class_candidates.items():
-                     if issubclass(class_obj, class_candidate):
-                         load_method_name = importable_classes[class_name][1]
-
-                 if load_method_name is None:
-                     none_module = class_obj.__module__
-                     if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
-                         # call class_obj for nice error message of missing requirements
-                         class_obj()
-
-                     raise ValueError(
-                         f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
-                         f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
-                     )
-
-                 load_method = getattr(class_obj, load_method_name)
-                 loading_kwargs = {}
-
-                 if issubclass(class_obj, torch.nn.Module):
-                     loading_kwargs["torch_dtype"] = torch_dtype
-                 if issubclass(class_obj, diffusers.OnnxRuntimeModel):
-                     loading_kwargs["provider"] = provider
-                     loading_kwargs["sess_options"] = sess_options
-
-                 is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
-                 is_transformers_model = (
-                     is_transformers_available()
-                     and issubclass(class_obj, PreTrainedModel)
-                     and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
-                 )
-
-                 # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
-                 # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
-                 # This makes sure that the weights won't be initialized which significantly speeds up loading.
-                 if is_diffusers_model or is_transformers_model:
-                     loading_kwargs["device_map"] = device_map
-                     loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-                 # check if the module is in a subdirectory
-                 if os.path.isdir(os.path.join(cached_folder, name)):
-                     loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
-                 else:
-                     # else load from the root directory
-                     loaded_sub_model = load_method(cached_folder, **loading_kwargs)
-
-             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
-
-         # 4. Potentially add passed objects if expected
-         missing_modules = set(expected_modules) - set(init_kwargs.keys())
-         if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
-             for module in missing_modules:
-                 init_kwargs[module] = passed_class_obj[module]
-         elif len(missing_modules) > 0:
-             passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
-             raise ValueError(
-                 f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
-             )
-
-         # 5. Instantiate the pipeline
-         model = pipeline_class(**init_kwargs)
-         return model
-
-     @property
-     def components(self) -> Dict[str, Any]:
-         r"""
-
-         The `self.components` property can be useful to run different pipelines with the same weights and
-         configurations to not have to re-allocate memory.
-
-         Examples:
-
-         ```py
-         >>> from diffusers import (
-         ...     StableDiffusionPipeline,
-         ...     StableDiffusionImg2ImgPipeline,
-         ...     StableDiffusionInpaintPipeline,
-         ... )
-
-         >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
-         >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
-         >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
-         ```
-
-         Returns:
-             A dictionaly containing all the modules needed to initialize the pipeline.
-         """
-         components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
-         expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
-
-         if set(components.keys()) != expected_modules:
-             raise ValueError(
-                 f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
-                 f" {expected_modules} to be defined, but {components} are defined."
-             )
-
-         return components
-
-     @staticmethod
-     def numpy_to_pil(images):
-         """
-         Convert a numpy image or a batch of images to a PIL image.
-         """
-         if images.ndim == 3:
-             images = images[None, ...]
-         images = (images * 255).round().astype("uint8")
-         if images.shape[-1] == 1:
-             # special case for grayscale (single channel) images
-             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-         else:
-             pil_images = [Image.fromarray(image) for image in images]
-
-         return pil_images
-
-     def progress_bar(self, iterable):
-         if not hasattr(self, "_progress_bar_config"):
-             self._progress_bar_config = {}
-         elif not isinstance(self._progress_bar_config, dict):
-             raise ValueError(
-                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
-             )
-
-         return tqdm(iterable, **self._progress_bar_config)
-
-     def set_progress_bar_config(self, **kwargs):
-         self._progress_bar_config = kwargs
-
-
- class LDMZhTextToImagePipeline(DiffusionPipeline):
-
-     def __init__(
-         self,
-         vqvae,
-         bert,
-         tokenizer,
-         unet,
-         scheduler,
-         sr,
-     ):
-         super().__init__()
-         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler, sr=sr)
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]],
-         height: Optional[int] = 256,
-         width: Optional[int] = 256,
-         num_inference_steps: Optional[int] = 50,
-         guidance_scale: Optional[float] = 5.0,
-         eta: Optional[float] = 0.0,
-         generator: Optional[torch.Generator] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         use_sr: bool = False,
-         **kwargs,
-     ):
-         r"""
-         Args:
-             prompt (`str` or `List[str]`):
-                 The prompt or prompts to guide the image generation.
-             height (`int`, *optional*, defaults to 256):
-                 The height in pixels of the generated image.
-             width (`int`, *optional*, defaults to 256):
-                 The width in pixels of the generated image.
-             num_inference_steps (`int`, *optional*, defaults to 50):
-                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                 expense of slower inference.
-             guidance_scale (`float`, *optional*, defaults to 1.0):
-                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at
-                 the, usually at the expense of lower image quality.
-             generator (`torch.Generator`, *optional*):
-                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
-                 deterministic.
-             output_type (`str`, *optional*, defaults to `"pil"`):
-                 The output format of the generate image. Choose between
-                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-             return_dict (`bool`, *optional*):
-                 Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
-
-         Returns:
-             [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-             generated images.
-         """
-
-         if isinstance(prompt, str):
-             batch_size = 1
-         elif isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-         if height % 8 != 0 or width % 8 != 0:
-             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-         # get unconditional embeddings for classifier free guidance
-         if guidance_scale != 1.0:
-             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=32, return_tensors="pt")
-             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))
-
-         # get prompt text embeddings
-         text_input = self.tokenizer(prompt, padding="max_length", max_length=32, return_tensors="pt")
-         text_embeddings = self.bert(text_input.input_ids.to(self.device))
-
-         latents = torch.randn(
-             (batch_size, self.unet.in_channels, height // 8, width // 8),
-             generator=generator,
-         )
-         latents = latents.to(self.device)
-
-         self.scheduler.set_timesteps(num_inference_steps)
-
-         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-
-         extra_kwargs = {}
-         if accepts_eta:
-             extra_kwargs["eta"] = eta
-
-         for t in self.progress_bar(self.scheduler.timesteps):
-             if guidance_scale == 1.0:
-                 # guidance_scale of 1 means no guidance
-                 latents_input = latents
-                 context = text_embeddings
-             else:
-                 # For classifier free guidance, we need to do two forward passes.
-                 # Here we concatenate the unconditional and text embeddings into a single batch
-                 # to avoid doing two forward passes
-                 latents_input = torch.cat([latents] * 2)
-                 context = torch.cat([uncond_embeddings, text_embeddings])
-
-             # predict the noise residual
-             noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
-             # perform guidance
-             if guidance_scale != 1.0:
-                 noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
-                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
-
-             # compute the previous noisy sample x_t -> x_t-1
-             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
-
-         # scale and decode the image latents with vae
-         latents = 1 / 0.18215 * latents
-         image = self.vqvae.decode(latents).sample
-
-         image = (image / 2 + 0.5).clamp(0, 1)
-         if use_sr:
-             image = self.sr(image)
-         image = image.cpu().permute(0, 2, 3, 1).numpy()
-         if output_type == "pil":
-             image = self.numpy_to_pil(image)
-
-         if not return_dict:
-             return (image,)
-
-         return ImagePipelineOutput(images=image)
-
-
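The heart of the deleted `__call__` above is standard classifier-free guidance: the U-Net runs once on a doubled batch (unconditional plus text-conditioned), and the two predictions are recombined. A minimal sketch of just that update rule, with a random tensor standing in for the U-Net output (shapes are illustrative, not taken from the model):

```py
import torch

guidance_scale = 5.0  # the pipeline's default
# Stand-in for unet(latents_input, t, encoder_hidden_states=context).sample,
# where latents_input = torch.cat([latents] * 2).
noise_pred = torch.randn(2, 3, 32, 32)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# guided = uncond + w * (text - uncond); w > 1 pushes samples toward the prompt.
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```

The deletion continues with the text-encoder classes: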
- class QuickGELU(nn.Module):
-     def forward(self, x: torch.Tensor):
-         return x * torch.sigmoid(1.702 * x)
-
-
- class ResidualAttentionBlock(nn.Module):
-     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
-         super().__init__()
-         self.attn = nn.MultiheadAttention(d_model, n_head)
-         self.ln_1 = nn.LayerNorm(d_model, eps=1e-07)
-         self.mlp = nn.Sequential(OrderedDict([
-             ("c_fc", nn.Linear(d_model, d_model * 4)),
-             ("gelu", QuickGELU()),
-             ("c_proj", nn.Linear(d_model * 4, d_model))
-         ]))
-         self.ln_2 = nn.LayerNorm(d_model, eps=1e-07)
-         self.attn_mask = attn_mask
-     def attention(self, x: torch.Tensor):
-         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
-         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
-     def forward(self, x: torch.Tensor):
-         x = x + self.attention(self.ln_1(x))
-         x = x + self.mlp(self.ln_2(x))
-         return x
-
-
- class Transformer(nn.Module):
-     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
-         super().__init__()
-         self.width = width
-         self.layers = layers
-         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
-
-     def forward(self, x: torch.Tensor):
-         return self.resblocks(x)
-
-
- class TextTransformer(nn.Module):
-     def __init__(self,
-                  context_length=32,
-                  vocab_size=21128,
-                  output_dim=768,
-                  width=768,
-                  layers=12,
-                  heads=12,
-                  return_full_embed=False):
-         super(TextTransformer, self).__init__()
-         self.width = width
-         self.layers = layers
-         self.vocab_size = vocab_size
-         self.return_full_embed = return_full_embed
-
-         self.transformer = Transformer(width, layers, heads, self.build_attntion_mask(context_length))
-         self.text_projection = torch.nn.Parameter(
-             torch.tensor(np.random.normal(0, self.width ** -0.5, size=(self.width, output_dim)).astype(np.float32)))
-         self.ln_final = nn.LayerNorm(width, eps=1e-07)
-
-         # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/27
-         # https://github.com/pytorch/pytorch/blob/a40812de534b42fcf0eb57a5cecbfdc7a70100cf/torch/nn/init.py#L22
-         self.embedding_table = nn.Parameter(nn.init.trunc_normal_(torch.empty(vocab_size, width), std=0.02))
-         # self.embedding_table = nn.Embedding.from_pretrained(nn.init.trunc_normal_(torch.empty(vocab_size, width), std=0.02))
-         self.positional_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(context_length, width), std=0.01))
-         # self.positional_embedding = nn.Embedding.from_pretrained(nn.init.trunc_normal_(torch.empty(context_length, width), std=0.01))
-
-         self.index_select = torch.index_select
-         self.reshape = torch.reshape
-
-     @staticmethod
-     def build_attntion_mask(context_length):
-         mask = np.triu(np.full((context_length, context_length), -np.inf).astype(np.float32), 1)
-         mask = torch.tensor(mask)
-         return mask
-
-     def forward(self, x: torch.Tensor):
-
-         tail_token = (x == 102).nonzero(as_tuple=True)
-         bsz, ctx_len = x.shape
-         flatten_id = x.flatten()
-         index_select_result = self.index_select(self.embedding_table, 0, flatten_id)
-         x = self.reshape(index_select_result, (bsz, ctx_len, -1))
-         x = x + self.positional_embedding
-         x = x.permute(1, 0, 2)  # NLD -> LND
-         x = self.transformer(x)
-         x = x.permute(1, 0, 2)  # LND -> NLD
-         x = self.ln_final(x)
-         x = x[tail_token]
-         x = x @ self.text_projection
-         return x
-
-
- class WukongClipTextEncoder(ModelMixin, ConfigMixin):
-
-     @register_to_config
-     def __init__(
-         self,
-     ):
-         super().__init__()
-         self.model = TextTransformer()
-
-     def forward(
-         self,
-         tokens
-     ):
-         z = self.model(tokens)
-         z = z / torch.linalg.norm(z, dim=-1, keepdim=True)
-         if z.ndim == 2:
-             z = z.view((z.shape[0], 1, z.shape[1]))
-         return z
-
-
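`build_attntion_mask` above (the typo is in the original identifier) builds the usual causal mask: `-inf` strictly above the diagonal, so each position can attend only to itself and earlier tokens. A standalone sketch of the same construction:

```py
import numpy as np
import torch

context_length = 4  # the encoder uses 32; 4 keeps the printout readable
mask = np.triu(np.full((context_length, context_length), -np.inf).astype(np.float32), 1)
print(torch.tensor(mask))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```

The forward pass then pools the sequence at token id 102, the `[SEP]` marker in BERT-style Chinese vocabularies, before applying `text_projection`. The deletion continues with the super-resolution classes: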
- def make_layer(block, n_layers):
-     layers = []
-     for _ in range(n_layers):
-         layers.append(block())
-     return nn.Sequential(*layers)
-
-
- class ResidualDenseBlock_5C(nn.Module):
-     def __init__(self, nf=64, gc=32, bias=True):
-         super(ResidualDenseBlock_5C, self).__init__()
-         # gc: growth channel, i.e. intermediate channels
-         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-         # initialization
-         # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
-     def forward(self, x):
-         x1 = self.lrelu(self.conv1(x))
-         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-         return x5 * 0.2 + x
-
-
- class RRDB(nn.Module):
-     '''Residual in Residual Dense Block'''
-
-     def __init__(self, nf, gc=32):
-         super(RRDB, self).__init__()
-         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
-     def forward(self, x):
-         out = self.RDB1(x)
-         out = self.RDB2(out)
-         out = self.RDB3(out)
-         return out * 0.2 + x
-
-
- class RRDBNet(nn.Module):
-     def __init__(self, in_nc, out_nc, nf, nb, gc=32):
-         super(RRDBNet, self).__init__()
-         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-
-         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-         self.RRDB_trunk = make_layer(RRDB_block_f, nb)
-         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         #### upsampling
-         self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-     def forward(self, x):
-         fea = self.conv_first(x)
-         trunk = self.trunk_conv(self.RRDB_trunk(fea))
-         fea = fea + trunk
-
-         fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
-         fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
-         out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
-         return out
-
-
- class ESRGAN(ModelMixin, ConfigMixin):
-
-     @register_to_config
-     def __init__(
-         self,
-     ):
-         super().__init__()
-         self.model = RRDBNet(3, 3, 64, 23, gc=32)
-
-     def forward(
-         self,
-         img_LR
-     ):
-         img_LR = img_LR[:, [2, 1, 0], :, :]
-         img_LR = img_LR.to(self.device)
-         with torch.no_grad():
-             output = self.model(img_LR)
-         output = output.data.float().clamp_(0, 1)
-         output = output[:, [2, 1, 0], :, :]
-         return output
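The `ESRGAN` wrapper above flips the channel order on the way in and out, presumably because the original ESRGAN weights expect BGR (OpenCV-style) input, and its two nearest-neighbor `scale_factor=2` upsamplings give a fixed 4x super-resolution. A self-contained shape check, as a sketch; the 3x3 / stride 1 / padding 1 convolutions in `RRDBNet` preserve height and width, so the growth comes only from the two interpolations:

```py
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 64, 64)  # stand-in for the feature map after conv_first
x = F.interpolate(x, scale_factor=2, mode="nearest")  # -> 128x128
x = F.interpolate(x, scale_factor=2, mode="nearest")  # -> 256x256
print(x.shape)  # torch.Size([1, 64, 256, 256]): 4x in each spatial dimension
```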
app.py CHANGED
@@ -1,31 +1,22 @@
  import gradio as gr
- from LdmZhPipeline import LDMZhTextToImagePipeline
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
  import torch
- import numpy as np
  from PIL import Image
  
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_id = "alibaba-pai/pai-diffusion-poem-large-zh"
-
- pipe_text2img = LDMZhTextToImagePipeline.from_pretrained(model_id)
- pipe_text2img = pipe_text2img.to(device)
+ model_id = "alibaba-pai/pai-diffusion-general-large-zh"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id)
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to("cuda")
  
  def infer_text2img(prompt, guide, steps):
-     output = pipe_text2img([prompt]*4, guidance_scale=guide, num_inference_steps=steps)
-     images = output.images
-     images = [np.array(images[i]) for i in range(4)]
-     images = np.concatenate([
-         np.concatenate(images[0:2], axis=0),
-         np.concatenate(images[2:4], axis=0),
-     ], axis=1)
-     images = Image.fromarray(images)
-     return images
+     image = pipe([prompt], guidance_scale=guide, num_inference_steps=steps).images[0]
+     return image
  
  with gr.Blocks() as demo:
      examples = [
-         ["远上寒山石径斜 白云深处有人家"],
-         ["停车坐爱枫林晚 霜叶红于二月花"],
-         ["接天莲叶无穷碧 映日荷花别样红"],
+         ["草地上的帐篷,背景是山脉"],
+         ["卧室里有一张床和一张桌子"],
+         ["雾蒙蒙的日出在湖面上"],
      ]
      with gr.Row():
          with gr.Column(scale=0.5, ):
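In short, app.py swaps the custom poem pipeline for the stock diffusers `StableDiffusionPipeline`, switches the checkpoint to the general-purpose Chinese model, and replaces the classical-poetry example prompts with scene descriptions (the first new one means roughly "a tent on the grass, with mountains in the background"). For reference, the new inference path can be exercised outside Gradio; this is a sketch assembled from the diff above, where the CPU fallback and the concrete `guidance_scale`/step values are illustrative additions, not part of the app:

```py
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

model_id = "alibaba-pai/pai-diffusion-general-large-zh"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe(["草地上的帐篷,背景是山脉"], guidance_scale=7.5, num_inference_steps=25).images[0]
image.save("sample.png")
```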
requirements.txt CHANGED
@@ -1,6 +1,6 @@
  --extra-index-url https://download.pytorch.org/whl/cu113
  torch
  torchvision
- diffusers==0.7.2
+ diffusers==0.14.0
  transformers
  accelerate
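The pin jumps from 0.7.2 to 0.14.0 because the new app.py imports `DPMSolverMultistepScheduler`, which was not yet available in diffusers 0.7.2 (it arrived in the 0.8 line, as far as I can tell). A quick environment check, as a sketch; `packaging` is assumed to be installed, which diffusers itself depends on:

```py
import diffusers
from packaging import version

# The app assumes the pinned release; older installs fail on the scheduler import.
assert version.parse(diffusers.__version__) >= version.parse("0.14.0"), diffusers.__version__
from diffusers import DPMSolverMultistepScheduler  # raises ImportError on the old 0.7.2 pin
```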