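"""Caption-quality evaluation for a multi_token sound/music captioning model.

Generates captions for the first 100 examples of the MusicBench validation
split and scores them against the reference captions with BLEU, METEOR,
ROUGE-L, and BERTScore.
"""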
from dataclasses import dataclass, field
import logging
import random

import bert_score
import evaluate
import torch
import transformers
from datasets import load_from_disk

from multi_token.training import ModelArguments
from multi_token.inference import load_trained_lora_model
from multi_token.data_tools import encode_chat
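
# Prompt templates for captioning; the "<sound>" placeholder is expanded into
# audio modality tokens when the chat is encoded.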
PRETRAIN_PHRASES = [
    "What is happening in the given music <sound>?",
    "Describe the sound. <sound>",
    "Describe the music. <sound>",
    "<sound> Provide a description of the music.",
    "<sound> Provide a description of the sound.",
    "Can you interpret <sound>?",
    "Please explain what's happening in <sound>",
    "What does <sound> represent?",
    "Could you describe <sound> for me?",
    "What's the content of <sound>?",
    "Can you depict <sound>?",
    "What is <sound>?",
    "In the music clip, <sound>, what is happening?",
    "Provide a description of the music. <sound>",
    "Provide a description of the sound. <sound>",
    "Provide a caption for the sound. <sound>",
    "Provide a caption for the music. <sound>",
]
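
# Command-line options for generation, layered on top of multi_token's
# ModelArguments (assumed to provide model_name_or_path and model_lora_path).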
@dataclass
class ServeArguments(ModelArguments):
    port: int = field(default=8080)
    host: str = field(default="0.0.0.0")
    load_bits: int = field(default=16)
    max_new_tokens: int = field(default=128)
    temperature: float = field(default=0.01)
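
# Encode a chat-format request (a user message containing "<sound>" plus the
# raw "sounds" entries), run generation with the sound features passed as
# modality inputs, and decode only the newly generated tokens. Relies on the
# module-level model, tokenizer, and serve_args set up under __main__.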
def generate(input_json):
    encoded_dict = encode_chat(input_json, tokenizer, model.modalities)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
            max_new_tokens=serve_args.max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=serve_args.temperature,
            modality_inputs={
                m.name: [encoded_dict[m.name]] for m in model.modalities
            },
        )

    outputs = tokenizer.decode(
        output_ids[0, encoded_dict["input_ids"].shape[0]:],
        skip_special_tokens=True,
    ).strip()

    return {"output": outputs}
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)

    parser = transformers.HfArgumentParser((ServeArguments,))
    serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
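
    # MusicBench validation split, previously saved with datasets.save_to_disk().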
    dataset_path = "/data/musicbench_multitoken_official_split/val"
    ds = load_from_disk(dataset_path)
    # Load the trained multi_token LoRA model and tokenizer (base checkpoint,
    # e.g. "mu-llama/MU-LLaMA", via --model_name_or_path); a plain
    # LlamaForCausalLM lacks the model.modalities that generate() needs.
    model, tokenizer = load_trained_lora_model(
        model_name_or_path=serve_args.model_name_or_path,
        model_lora_path=serve_args.model_lora_path,
        load_bits=serve_args.load_bits,
    )
    predictions = []
    references = []

    content_phrase = random.choice(PRETRAIN_PHRASES)
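
    # Caption the first 100 validation examples with the same prompt and
    # collect prediction/reference pairs for corpus-level metrics.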
    for data_point_id in range(100):
        data_point = ds[data_point_id]
        input_json = {
            "messages": [{"role": "user", "content": content_phrase}],
            "sounds": data_point["sounds"],
        }
        output_json = generate(input_json)

        print("Prediction ", output_json["output"])
        print("Reference ", data_point["messages"][1]["content"])
        print()
        print()

        predictions.append(output_json["output"])
        references.append(data_point["messages"][1]["content"])
    # Load evaluation metrics
    bleu = evaluate.load("bleu")
    meteor = evaluate.load("meteor")
    rouge = evaluate.load("rouge")
    # Compute corpus-level BLEU (up to 4-grams); the result dict also contains
    # the per-n-gram precisions.
    bleu_results = bleu.compute(predictions=predictions, references=references, max_order=4)
    # bleu_score = sum(bleu_results[f"bleu{i}"] for i in range(1, 5)) / 4
    print(bleu_results)
    # Compute METEOR score
    meteor_results = meteor.compute(predictions=predictions, references=references)
    meteor_score = meteor_results["meteor"]
    # Compute ROUGE-L score
    rouge_results = rouge.compute(predictions=predictions, references=references, rouge_types=["rougeL"])
    # rouge_l_score = rouge_results["rougeL"].mid.fmeasure  # older `evaluate` versions returned AggregateScore
    print(rouge_results)
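
    # Note: bert_score downloads a pretrained scorer on first use
    # (roberta-large for lang="en"); rescale_with_baseline=True rescales the
    # raw similarity scores against a baseline.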
    # Compute BERT-Score
    P, R, F1 = bert_score.score(predictions, references, lang="en", rescale_with_baseline=True)
    bert_score_f1 = F1.mean().item()
    # Print results
    # print(f"BLEU Score: {bleu_score}")
    print(f"METEOR Score: {meteor_score}")
    # print(f"ROUGE-L Score: {rouge_l_score}")
    print(f"BERT-Score F1: {bert_score_f1}")