Spaces: Running on Zero
v1.2 edit

Changed files:
- app.py +5 -1
- diffrhythm/model/cfm.py +10 -16
- diffrhythm/model/dit.py +7 -3
app.py CHANGED

```diff
@@ -232,8 +232,12 @@ with gr.Blocks(css=css) as demo:
 3. **Supported Languages**
 - **Chinese and English**
 - More languages comming soon
+
+4. **Editing Function in Advanced Settings**
+- Using full-length audio as reference is recommended for best results.
+- Use -1 to represent the start/end of audio (e.g. [[-1,25], [50,-1]] means "from start to 25s" and "from 50s to end").
 
-
+5. **Others**
 - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
 
 """)
```
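The `[[-1,25], [50,-1]]` convention in the new help text is easy to mis-read, so here is a minimal sketch of how such spans resolve to absolute times. The helper name `resolve_edit_spans` is hypothetical and not part of this Space; it only illustrates the sentinel rule the help text states.

```python
# Hypothetical helper (not part of this Space): resolve the -1 sentinels in
# edit spans like [[-1, 25], [50, -1]] to absolute times for a clip that is
# `duration` seconds long.
def resolve_edit_spans(spans, duration):
    resolved = []
    for start, end in spans:
        start = 0.0 if start == -1 else float(start)        # -1 -> "from the start"
        end = float(duration) if end == -1 else float(end)  # -1 -> "to the end"
        if not 0.0 <= start < end <= duration:
            raise ValueError(f"invalid span: ({start}, {end})")
        resolved.append((start, end))
    return resolved

# [[-1, 25], [50, -1]] on a 95-second song -> [(0.0, 25.0), (50.0, 95.0)]
print(resolve_edit_spans([[-1, 25], [50, -1]], duration=95.0))
```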
diffrhythm/model/cfm.py CHANGED

```diff
@@ -208,27 +208,21 @@ class CFM(nn.Module):
         negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
         start_time = start_time.repeat(batch_infer_num)
         fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
-
-        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
-        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
-
-        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
-        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
-        step_cond = torch.cat([step_cond, step_cond], 0)
-        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
-        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
         def fn(t, x):
-
+            # predict flow
            pred = self.transformer(
-                x=x,
-
            )
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
+                style_prompt=style_prompt, start_time=start_time
+            )
+            if cfg_strength < 1e-5:
+                return pred
 
-
-
-
-
+            null_pred = self.transformer(
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
+                style_prompt=negative_style_prompt, start_time=start_time
+            )
+            return pred + (pred - null_pred) * cfg_strength
 
         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
```
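The rewritten `fn` drops the old scheme of concatenating positive and negative conditions along the batch dimension in favor of two separate transformer calls: simpler bookkeeping, at the cost of a second forward pass whenever guidance is active (the early return skips that pass when `cfg_strength` is effectively zero). The combination rule is standard classifier-free guidance. A self-contained sketch, where `model`, `cond_kwargs`, and `null_kwargs` are generic placeholders rather than the repo's API:

```python
import torch

def cfg_predict(model, x, t, cond_kwargs, null_kwargs, cfg_strength):
    # Conditional flow prediction.
    pred = model(x=x, time=t, **cond_kwargs)
    # Guidance effectively off: return the plain conditional prediction.
    if cfg_strength < 1e-5:
        return pred
    # Prediction with conditions dropped, then extrapolate away from it,
    # mirroring `pred + (pred - null_pred) * cfg_strength` in the diff above.
    null_pred = model(x=x, time=t, **null_kwargs)
    return pred + (pred - null_pred) * cfg_strength

# Toy usage with a stand-in "model":
model = lambda x, time, scale=1.0: x * scale
x = torch.ones(2, 4)
out = cfg_predict(model, x, t=0.5,
                  cond_kwargs={"scale": 2.0}, null_kwargs={"scale": 1.0},
                  cfg_strength=1.0)
# out = 2x + (2x - x) * 1.0 = 3x
```

Since `fn(t, x)` maps a time and state to a predicted flow, it has exactly the shape an ODE solver needs to integrate from the noise input mentioned in the trailing comments.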
diffrhythm/model/dit.py CHANGED

```diff
@@ -162,21 +162,25 @@ class DiT(nn.Module):
     def forward(
         self,
         x: float["b n d"], # nosied input audio # noqa: F722
-        text_embed: int["b nt"], # text # noqa: F722
-        text_residuals,
         cond: float["b n d"], # masked cond audio # noqa: F722
+        text: int["b nt"], # text # noqa: F722
         time: float["b"] | float[""], # time step # noqa: F821 F722
         drop_audio_cond, # cfg for cond audio
+        drop_text, # cfg for text
         drop_prompt=False,
         style_prompt=None, # [b d t]
         start_time=None,
     ):
+
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)
 
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-
+        s_t = self.start_time_embed(start_time)
+        c = t + s_t
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
```