|
|
import re |
|
|
import os |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from .config import RANDOM_STATE, ASSETS_DIR |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_critic_prompt(genre, artist, title, lyrics, max_lyric_length=300): |
|
|
"""Constructs the instruction-tuning prompt for the Critic persona.""" |
|
|
lyrics_snippet = lyrics[:max_lyric_length] |
|
|
|
|
|
instruction = ( |
|
|
"You are a professional music critic. Provide specific feedback on how to improve " |
|
|
"the lyrics based on the genre and artist style. \n" |
|
|
"Formatting Rules: \n" |
|
|
"1. Use plain text with clear line breaks.\n" |
|
|
"2. Ensure all song titles and words have proper spacing." |
|
|
) |
|
|
|
|
|
context = ( |
|
|
f"Target Genre: {genre}\n" |
|
|
f"Target Artist: {artist}\n" |
|
|
f"Target Title: {title}\n\n" |
|
|
f"Lyrics to Evaluate:\n{lyrics_snippet}" |
|
|
) |
|
|
|
|
|
return f"<start_of_turn>user\n{instruction}\n\n{context}<end_of_turn>\n<start_of_turn>model\n" |
|
|
|
|
|
def build_revision_prompt(genre, artist, title, draft, critiques): |
|
|
"""Constructs the prompt for the 'Revise' step of the refinement loop.""" |
|
|
instruction = ( |
|
|
"You are an expert songwriter. Revise the provided lyrics by incorporating " |
|
|
"the specific feedback from the critic while maintaining the genre and artist style." |
|
|
) |
|
|
|
|
|
context = ( |
|
|
f"Genre: {genre}\n" |
|
|
f"Artist Style: {artist}\n" |
|
|
f"Title: {title}\n\n" |
|
|
f"Current Draft:\n{draft}\n\n" |
|
|
f"Critic Feedback:\n{critiques}" |
|
|
) |
|
|
|
|
|
return f"<start_of_turn>user\n{instruction}\n\n{context}<end_of_turn>\n<start_of_turn>model\n" |
|
|
|
|
|
def build_inference_prompt(genre, artist, title): |
|
|
"""Reconstructs the prompt format used during v1.0 training.""" |
|
|
instruction = "Generate lyrics for a song based on these details." |
|
|
input_context = f"Genre: {genre}\nArtist: {artist}\nTitle: {title}" |
|
|
|
|
|
return ( |
|
|
f"<start_of_turn>user\n{instruction}\n\n{input_context}<end_of_turn>\n" |
|
|
f"<start_of_turn>model\n" |
|
|
) |
|
|
|
|
|
def format_prompt(row): |
|
|
"""Converts a dataframe row into a structured Gemma control-token prompt.""" |
|
|
instruction = "Generate lyrics for a song based on these details." |
|
|
input_context = f"Genre: {row['tag']}\nArtist: {row['artist']}\nTitle: {row['title']}" |
|
|
response = row['lyrics'] |
|
|
|
|
|
return ( |
|
|
f"<start_of_turn>user\n{instruction}\n\n{input_context}<end_of_turn>\n" |
|
|
f"<start_of_turn>model\n{response}<end_of_turn>" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_lyrics(text): |
|
|
"""Cleans up raw model output by enforcing structural newlines and spacing.""" |
|
|
|
|
|
text = re.sub(r'(\[.*?\])', r'\n\n\1\n', text) |
|
|
|
|
|
|
|
|
text = re.sub(r'([a-z])([A-Z])', r'\1\n\2', text) |
|
|
return text.strip() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_critic_training_row(row): |
|
|
"""Standardizes raw rows into the Critic instruction-tuning format.""" |
|
|
prompt = build_critic_prompt(row.tag, row.artist, row.title, row.lyrics) |
|
|
|
|
|
target_output = ( |
|
|
f"Genre Fit: The {row.tag} style is well-maintained.\n" |
|
|
f"Artist Style: Matches the {row.artist} aesthetic.\n" |
|
|
f"Improvements: Consider refining the rhythmic flow in the second verse." |
|
|
) |
|
|
|
|
|
return f"{prompt}{target_output}<eos>" |
|
|
|
|
|
def prepare_lyric_dataset(lyrics_filename, reviews_filename, songs_per_genre=200): |
|
|
"""Loads, cleans, and balances the dataset while exporting EDA plots.""" |
|
|
from .viz import save_figure |
|
|
|
|
|
lyrics_path = os.path.join("data", lyrics_filename) |
|
|
reviews_path = os.path.join("data", reviews_filename) |
|
|
|
|
|
print(f"Loading & Cleaning Raw Data...") |
|
|
|
|
|
lyrics_df = pd.read_csv(lyrics_path, on_bad_lines='skip') |
|
|
reviews_df = pd.read_csv(reviews_path) |
|
|
|
|
|
lyrics_df = lyrics_df.dropna(subset=['lyrics', 'artist', 'tag']) |
|
|
reviews_df = reviews_df.dropna(subset=['genre', 'artist']) |
|
|
|
|
|
lyrics_clean = lyrics_df.drop_duplicates(subset="artist")[["artist", "lyrics", "title", "tag"]] |
|
|
merged_df = reviews_df.merge(lyrics_clean, on="artist", how="left").dropna(subset=["lyrics", "tag"]) |
|
|
|
|
|
|
|
|
plt.figure(figsize=(10, 5)) |
|
|
top_raw = merged_df['tag'].value_counts().nlargest(10) |
|
|
sns.barplot(x=top_raw.values, y=top_raw.index, hue=top_raw.index, palette='viridis', legend=False) |
|
|
plt.title(f"Raw Genre Distribution (n={len(merged_df):,})") |
|
|
save_figure("eda_1_raw_distribution.png") |
|
|
|
|
|
|
|
|
balanced_df = merged_df.groupby("tag", group_keys=False).apply( |
|
|
lambda x: x.sample(min(len(x), songs_per_genre), random_state=RANDOM_STATE) |
|
|
) |
|
|
|
|
|
|
|
|
plt.figure(figsize=(10, 5)) |
|
|
sns.countplot(data=balanced_df, y='tag', hue='tag', palette='magma', legend=False) |
|
|
plt.title(f"Balanced Genre Distribution (n={len(balanced_df):,})") |
|
|
save_figure("eda_2_balanced_distribution.png") |
|
|
|
|
|
return balanced_df |