ajayarora1235
commited on
Commit
·
f9f05d9
1
Parent(s):
1cc40d5
new hubert method
Browse files- .gitattributes +1 -0
- app.py +1 -1
- hubert.pth +3 -0
- vc_infer_pipeline.py +3 -2
.gitattributes
CHANGED
@@ -2,3 +2,4 @@ ilariasuitewallpaper.jpg filter=lfs diff=lfs merge=lfs -text
|
|
2 |
ilariaaisuite.png filter=lfs diff=lfs merge=lfs -text
|
3 |
pretrained_models/giga330M.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
pretrained_models/encodec_4cb2048_giga.th filter=lfs diff=lfs merge=lfs -text
|
|
|
|
2 |
ilariaaisuite.png filter=lfs diff=lfs merge=lfs -text
|
3 |
pretrained_models/giga330M.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
pretrained_models/encodec_4cb2048_giga.th filter=lfs diff=lfs merge=lfs -text
|
5 |
+
hubert.pth filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -246,7 +246,7 @@ associated_links = {}
|
|
246 |
def load_hubert():
|
247 |
global hubert_model
|
248 |
# Load the model
|
249 |
-
hubert_model = torch.load("hubert_base.
|
250 |
|
251 |
# Prepare the model
|
252 |
hubert_model = hubert_model.to(config.device)
|
|
|
246 |
def load_hubert():
|
247 |
global hubert_model
|
248 |
# Load the model
|
249 |
+
hubert_model = torch.load("hubert_base.pth", map_location=config.device)
|
250 |
|
251 |
# Prepare the model
|
252 |
hubert_model = hubert_model.to(config.device)
|
hubert.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e579cfcfb99bfca12e392d89854f7ed722ebb08c74daa8d54b4b4165436e8f7
|
3 |
+
size 377560373
|
vc_infer_pipeline.py
CHANGED
@@ -396,8 +396,9 @@ class VC(object):
|
|
396 |
}
|
397 |
t0 = ttime()
|
398 |
with torch.no_grad():
|
399 |
-
|
400 |
-
|
|
|
401 |
if protect < 0.5 and pitch != None and pitchf != None:
|
402 |
feats0 = feats.clone()
|
403 |
if (
|
|
|
396 |
}
|
397 |
t0 = ttime()
|
398 |
with torch.no_grad():
|
399 |
+
feats = model(inputs["source"])["last_hidden_state"]
|
400 |
+
# logits = model.extract_features(**inputs)
|
401 |
+
# feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
402 |
if protect < 0.5 and pitch != None and pitchf != None:
|
403 |
feats0 = feats.clone()
|
404 |
if (
|