|
""" |
|
Example script for using the T5 Spotify Features model |
|
""" |
|
import json
import re
from functools import lru_cache

from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
def predict_spotify_features(prompt_text, model_name="afsagag/t5-spotify-features"): |
|
""" |
|
Generate Spotify audio features from a text prompt |
|
|
|
Args: |
|
prompt_text (str): Natural language description of music preferences |
|
model_name (str): Hugging Face model name |
|
|
|
Returns: |
|
dict: Spotify audio features or None if JSON parsing fails |
|
""" |
|
|
|
model = T5ForConditionalGeneration.from_pretrained(model_name) |
|
tokenizer = T5Tokenizer.from_pretrained(model_name) |
|
|
|
|
|
input_text = f"prompt: {prompt_text}" |
|
|
|
|
|
input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True).input_ids |
|
outputs = model.generate( |
|
input_ids, |
|
max_length=256, |
|
num_beams=4, |
|
early_stopping=True, |
|
do_sample=False |
|
) |
|
|
|
|
|
result = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
cleaned_result = result.replace("ll", "null").replace("nu", "null") |
|
|
|
try: |
|
return json.loads(cleaned_result) |
|
except json.JSONDecodeError: |
|
print(f"Failed to parse JSON: {cleaned_result}") |
|
return None |
|
|
|
if __name__ == "__main__":

    # Demo prompts spanning a range of moods and activities.
    sample_prompts = (
        "I want energetic dance music",
        "Play some calm acoustic songs",
        "Upbeat pop music for working out",
        "Sad slow songs for rainy days",
    )

    for sample in sample_prompts:
        print(f"\nPrompt: {sample}")
        predicted = predict_spotify_features(sample)
        if predicted:
            print(f"Features: {json.dumps(predicted, indent=2)}")
|