BackTo2014 commited on
Commit
5135265
1 Parent(s): 41a5094

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -4
README.md CHANGED
@@ -5,9 +5,67 @@ This is a simple attempt. I trained with CIFAR-10 dataset.
5
  ## Usage
6
 
7
  ```python
8
- from diffusers import DDPMPipeline
9
 
10
- pipeline = DDPMPipeline.from_pretrained('BackTo2014/DDPM-test')
11
- image = pipeline().images[0]
12
- image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ```
 
5
  ## Usage
6
 
7
  ```python
8
+ # 生成图像有误...以下代码需修改!!!
9
 
10
+ import torch
11
+ from diffusers import DDPMPipeline, DDPMScheduler
12
+ from diffusers.models import UNet2DModel
13
+ from PIL import Image
14
+ import matplotlib.pyplot as plt
15
+
16
+ # 模型ID
17
+ model_id = "BackTo2014/DDPM-test"
18
+
19
+ # 检查设备
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+ # 加载UNet模型和配置文件
23
+ try:
24
+ unet = UNet2DModel.from_pretrained(
25
+ model_id,
26
+ ignore_mismatched_sizes=True,
27
+ low_cpu_mem_usage=False,
28
+ ).to(device) # 将模型移动到GPU上
29
+ except ValueError as e:
30
+ print(f"Error loading model: {e}")
31
+
32
+ # 获取模型的state_dict
33
+ state_dict = unet.state_dict()
34
+
35
+ # 手动初始化缺失的权重
36
+ for key in e.args[0].split(': ')[1].split(', '):
37
+ name, size = key.split('.')
38
+ size = tuple(map(int, size.replace(')', '').replace('(', '').split(',')))
39
+
40
+ # 创建随机权重
41
+ new_weight = torch.randn(size).to(device) # 将权重移动到GPU上
42
+
43
+ # 更新state_dict
44
+ state_dict[name] = new_weight
45
+
46
+ # 加载更新后的state_dict
47
+ unet.load_state_dict(state_dict).to(device) # 将模型移动到GPU上
48
+
49
+ # 如果sample_size未定义,则手动设置
50
+ if unet.config.sample_size is None:
51
+ # 假设样本尺寸为 32x32
52
+ unet.config.sample_size = (32, 32)
53
+
54
+ # 初始化Scheduler
55
+ scheduler = DDPMScheduler.from_config(model_id)
56
+
57
+ # 创建DDPMPipeline
58
+ pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)
59
+
60
+ # 生成图像
61
+ generator = torch.manual_seed(0)
62
+ image = pipeline(num_inference_steps=1000, generator=generator).images[0]
63
+
64
+ # 使用matplotlib显示图像
65
+ plt.imshow(image)
66
+ plt.axis('off') # 不显示坐标轴
67
+ plt.show()
68
+
69
+ # 保存图像
70
+ image.save("generated_image.png")
71
  ```