import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import data
from models import imagebind_model
from models.imagebind_model import ModalityType
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from extract.getim import load_image
from image2vidimg import cobtwoten, cobtwoten256

# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
device = torch.device("cuda")

torch.cuda.empty_cache()

transform = transforms.Compose([
    transforms.ToTensor(),  # converts a numpy array / PIL image to a (C, H, W) tensor, scaled by 1/255 into [0, 1]
])
unloader = transforms.ToPILImage()


def imagebind_out(audio_paths, model):
    """Embed a list of .wav paths with ImageBind and return the modality-keyed embedding dict."""
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    return embeddings
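
# Usage note (shapes taken from the ImageBind demo, stated here as an assumption):
# for N audio paths, embeddings[ModalityType.AUDIO] is an (N, 1024) tensor, which
# the training loop below unsqueezes to (N, 1, 1024) before feeding encode_audio, e.g.
#   imagebind_out(["./extract/audio/0.wav"], model_imageb)[ModalityType.AUDIO].shape
#   # -> torch.Size([1, 1024])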

class encode_audio(nn.Module):
    """Expand a single ImageBind audio embedding into a (B, 344, 1024) pseudo text-embed sequence."""

    def __init__(self):
        super().__init__()
        # self.link1 = nn.Linear(1024, 768)
        self.link2 = nn.Linear(1024, 343)
        # self.link3 = nn.Linear(1024, 768)

    def forward(self, embeddings):
        l1 = embeddings                            # (B, 1, 1024)
        l2 = self.link2(embeddings)                # (B, 1, 343)
        # l3 = self.link3(embeddings)
        l3 = torch.matmul(l2.transpose(1, 2), l1)  # (B, 343, 1) @ (B, 1, 1024) -> (B, 343, 1024)
        return torch.cat([l1, l3], dim=1)          # (B, 344, 1024), matching max_text_len=344 below
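

# A minimal shape sanity check for encode_audio (a hypothetical helper, not part
# of the training flow): one 1024-d vector per clip in, a 344-token sequence out.
def _check_encode_audio_shapes():
    dummy = torch.randn(2, 1, 1024)  # a fake batch of two ImageBind audio embeddings
    out = encode_audio()(dummy)      # runs on CPU; the real model is moved to GPU later
    assert out.shape == (2, 344, 1024), out.shape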


# list the file names in a directory via os.listdir()
def getAllFiles(targetDir):
    return os.listdir(targetDir)

# unets for the imagen cascade
unet1 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=64, dim_mults=(1, 2, 4, 8)).to(device)
unet2 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=128, dim_mults=(1, 2, 4, 8)).to(device)
# unet3 = Unet3D(dim=256, dim_mults=(1, 2, 4, 8)).cuda()
# unet1 = NullUnet()  # add a placeholder "null" unet for the base unet

imagen = ElucidatedImagen(
    text_embed_dim=1024,
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    random_crop_sizes = (None, 64),
    temporal_downsample_factor = (2, 1),        # in this example, the first unet would receive the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).to(device)
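
# For reference, these are the EDM knobs from Karras et al. 2022 ("Elucidating the
# Design Space of Diffusion-Based Generative Models"): sampling walks N = num_sample_steps
# noise levels
#   sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
# training noise is drawn log-normally, ln(sigma) ~ Normal(P_mean, P_std**2), and
# S_churn / S_tmin / S_tmax / S_noise control the stochastic-sampling churn.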

trainer = ImagenTrainer(imagen).to(device)
# trainer.load("./checkpoint.pt")  # uncomment to resume from a checkpoint
# Instantiate the pretrained ImageBind encoder (kept frozen in eval mode)
model_imageb = imagebind_model.imagebind_huge(pretrained=True)
model_imageb = model_imageb.to(device)
model_imageb.eval()

epo = 31  # number of training epochs
p = 1     # clips per training step (effective batch size)
files = getAllFiles("./extract/audio")

outloss = 0
model1 = encode_audio().to(device)
# model1.load_state_dict(torch.load("wlc.pt").state_dict())  # uncomment to resume
optimizer = optim.Adam(model1.parameters(), lr=1e-5,
                       betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)
model1.train()
torch.cuda.empty_cache()

for k in range(epo):
    for nm in range(0, len(files) + 1 - p, p):
        # first clip of the step: pair the .wav with its still frame
        front0, _ = os.path.splitext(files[nm])
        audio_pat = ["./extract/audio/" + front0 + ".wav"]
        # turn the paired still image into a short video-clip tensor;
        # clips are concatenated along dim -5 (the batch dim)
        fcontent = cobtwoten("./extract/image/" + front0 + ".jpg")
        # fcontent = load_image("./extract/image/" + front0 + ".jpg", transform, shape=[256, 256])
        # remaining clips of the step (no-op while p == 1)
        for ni in range(1, p):
            front, _ = os.path.splitext(files[nm + ni])
            content = cobtwoten("./extract/image/" + front + ".jpg")
            fcontent = torch.cat((fcontent, content), -5)
            audio_pat.append("./extract/audio/" + front + ".wav")

        imageb_out = imagebind_out(audio_pat, model_imageb)
        fmusic = model1(imageb_out[ModalityType.AUDIO].unsqueeze(1))  # (p, 1, 1024) -> (p, 344, 1024)
        fmusic = fmusic.to(device)
        fcontent = fcontent.to(device)
        # the trainer's forward computes the diffusion loss and backpropagates;
        # trainer.update() steps the unet, optimizer.step() updates encode_audio
        loss = trainer(fcontent, text_embeds=fmusic, unet_number=2, ignore_time=False, max_batch_size=p)
        trainer.update(unet_number=2)
        optimizer.step()
        optimizer.zero_grad()
        print(loss)
        outloss = outloss + loss

    print("epoch" + str(k) + " loss: " + str(outloss))

    outloss = 0
    # checkpoint the audio projector and the diffusion trainer every third epoch
    if k % 3 == 2:
        torch.save(model1, "wlc.pt")
        trainer.save('./checkpoint.pt')
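

# A minimal inference sketch (hypothetical helper, not called anywhere above): embed
# one clip with ImageBind, expand it with the trained projector, and sample from the
# cascade. The path, video_frames and cond_scale values are illustrative assumptions.
def sample_from_audio(audio_path, video_frames=10, cond_scale=3.):
    model1.eval()
    with torch.no_grad():
        emb = imagebind_out([audio_path], model_imageb)
        text_embeds = model1(emb[ModalityType.AUDIO].unsqueeze(1)).to(device)  # (1, 344, 1024)
        videos = trainer.sample(text_embeds=text_embeds,
                                video_frames=video_frames,
                                cond_scale=cond_scale)
    model1.train()
    return videos  # e.g. videos = sample_from_audio("./extract/audio/0.wav")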




