PHBJT commited on
Commit
c36a256
1 Parent(s): c2a0e83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -15,7 +15,8 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
  repo_id = "ylacombe/p-m-e"
16
 
17
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
18
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
 
19
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
20
 
21
 
@@ -97,8 +98,8 @@ def preprocess(text):
97
 
98
  @spaces.GPU
99
  def gen_tts(text, description):
100
- inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
101
- prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
102
 
103
  set_seed(SEED)
104
  generation = model.generate(
 
15
  repo_id = "ylacombe/p-m-e"
16
 
17
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
18
+ text_tokenizer = AutoTokenizer.from_pretrained(repo_id)
19
+ description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
20
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
21
 
22
 
 
98
 
99
  @spaces.GPU
100
  def gen_tts(text, description):
101
+ inputs = description_tokenizer(description.strip(), return_tensors="pt").to(device)
102
+ prompt = text_tokenizer(preprocess(text), return_tensors="pt").to(device)
103
 
104
  set_seed(SEED)
105
  generation = model.generate(