PHBJT commited on
Commit
03d612a
·
verified ·
1 Parent(s): d7d8798

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -11,8 +11,8 @@ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
11
  # Device setup
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
- # SmolLM setup
15
- checkpoint = "HuggingFaceTB/SmolLM-360M"
16
  smol_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
17
  smol_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
18
 
@@ -49,7 +49,9 @@ def format_description(raw_description, do_format=True):
49
  if not do_format:
50
  return raw_description
51
 
52
- prompt = f"""Format this voice description to match exactly:
 
 
53
  "a [gender] with a [pitch] voice speaks [speed] in a [environment], [delivery style]"
54
  Where:
55
  - gender: man/woman
@@ -57,21 +59,25 @@ Where:
57
  - speed: slowly/moderately/quickly
58
  - environment: close-sounding and clear/distant-sounding and noisy
59
  - delivery style: with monotone delivery/with animated delivery
60
-
61
- Description to format: {raw_description}
62
- Formatted description:"""
63
 
64
- inputs = smol_tokenizer.encode(prompt, return_tensors="pt").to(device)
 
65
  outputs = smol_model.generate(
66
  inputs,
67
- max_length=200,
68
- num_return_sequences=1,
69
  temperature=0.7,
70
- do_sample=True,
71
- pad_token_id=smol_tokenizer.eos_token_id
72
  )
73
  formatted = smol_tokenizer.decode(outputs[0], skip_special_tokens=True)
74
- return formatted.split("Formatted description:")[-1].strip()
 
 
 
 
 
75
 
76
  def preprocess(text):
77
  text = number_normalizer(text).strip()
@@ -109,6 +115,7 @@ def gen_tts(text, description, do_format=True):
109
  audio_arr = generation.cpu().numpy().squeeze()
110
  return formatted_desc, (SAMPLE_RATE, audio_arr)
111
 
 
112
  css = """
113
  #share-btn-container {
114
  display: flex;
 
11
  # Device setup
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
+ # SmolLM Instruct setup
15
+ checkpoint = "HuggingFaceTB/SmolLM-360M-Instruct"
16
  smol_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
17
  smol_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
18
 
 
49
  if not do_format:
50
  return raw_description
51
 
52
+ messages = [{
53
+ "role": "user",
54
+ "content": f"""Format this voice description to match exactly:
55
  "a [gender] with a [pitch] voice speaks [speed] in a [environment], [delivery style]"
56
  Where:
57
  - gender: man/woman
 
59
  - speed: slowly/moderately/quickly
60
  - environment: close-sounding and clear/distant-sounding and noisy
61
  - delivery style: with monotone delivery/with animated delivery
62
+ Description to format: {raw_description}"""
63
+ }]
 
64
 
65
+ input_text = smol_tokenizer.apply_chat_template(messages, tokenize=False)
66
+ inputs = smol_tokenizer.encode(input_text, return_tensors="pt").to(device)
67
  outputs = smol_model.generate(
68
  inputs,
69
+ max_new_tokens=200,
 
70
  temperature=0.7,
71
+ top_p=0.9,
72
+ do_sample=True
73
  )
74
  formatted = smol_tokenizer.decode(outputs[0], skip_special_tokens=True)
75
+
76
+ # Extract the formatted description from the response
77
+ try:
78
+ return formatted.split("a ")[-1].strip()
79
+ except:
80
+ return raw_description
81
 
82
  def preprocess(text):
83
  text = number_normalizer(text).strip()
 
115
  audio_arr = generation.cpu().numpy().squeeze()
116
  return formatted_desc, (SAMPLE_RATE, audio_arr)
117
 
118
+ # Rest of the code remains unchanged
119
  css = """
120
  #share-btn-container {
121
  display: flex;