from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from miditok import REMI
from symusic import Score
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

class Processor:
    def __init__(self, model_location_repo, model_tokenizer_file) -> None:
        self.config = AutoConfig.from_pretrained(model_location_repo)
        self.model = AutoModelForCausalLM.from_pretrained(model_location_repo, config=self.config)
        tokenizer_file_location = hf_hub_download(repo_id=model_location_repo, filename=model_tokenizer_file)
        self.tokenizer = REMI(params=Path(tokenizer_file_location))
        self.generation_config = GenerationConfig(
            max_new_tokens=2000,
            num_beams=1,
            do_sample=True,
            temperature=0.9,
            top_k=15,
            top_p=0.95,
            epsilon_cutoff=3e-4,
            eta_cutoff=1e-3,
            pad_token_id=self.tokenizer["PAD_None"],
            bos_token_id=self.tokenizer["BOS_None"],
            eos_token_id=self.tokenizer["EOS_None"],
        )

    def transpose_midi(
        self,
        midi_bytes: bytes | None,
        max_new_tokens: int = 2000,
        temperature: float = 0.9,
        top_p: float = 0.95,
        do_sample: bool = True,
    ) -> bytes | None:
        """Process the MIDI file with a transformer model, generating new MIDI content from the input.

        Args:
            midi_bytes: Raw MIDI file bytes from the frontend.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            do_sample: Whether to sample (True) or decode greedily (False).

        Returns:
            Generated MIDI file bytes, or the original bytes if processing fails.
        """
        if midi_bytes is None:
            return None
        try:
            score = Score.from_midi(midi_bytes)
            tokenized_input = self.tokenizer(score)
            max_len = self.model.config.max_position_embeddings
            print(f"Max position embeddings: {max_len}")
            max_len = 1024  # TODO: hardcoded for now while we use a smaller model
            # Truncate the input if it exceeds the model's maximum context length,
            # keeping the most recent tokens so generation continues from the end
            input_ids = tokenized_input[0].ids
            if len(input_ids) >= max_len:
                print(f"Warning: input sequence ({len(input_ids)} tokens) exceeds the {max_len}-token context. Truncating.")
                input_ids = input_ids[-max_len:]
            tensor_sequence = torch.tensor([input_ids], dtype=torch.long)
            print(f"Input tensor shape: {tensor_sequence.shape}")
            input_token_length = tensor_sequence.shape[1]
            # Generate the new token sequence, overriding the per-request sampling parameters
            gen_config = GenerationConfig(
                max_new_tokens=int(max_new_tokens),
                num_beams=self.generation_config.num_beams,
                do_sample=do_sample,
                temperature=temperature,
                top_k=self.generation_config.top_k,
                top_p=top_p,
                epsilon_cutoff=self.generation_config.epsilon_cutoff,
                eta_cutoff=self.generation_config.eta_cutoff,
                pad_token_id=self.generation_config.pad_token_id,
                bos_token_id=self.generation_config.bos_token_id,
                eos_token_id=self.generation_config.eos_token_id,
            )
            res = self.model.generate(
                inputs=tensor_sequence,
                generation_config=gen_config,
            )
            print("Generated output shape:", res.shape)
            print(f"New tokens generated: {res.shape[1] - input_token_length}")
            # Decode only the generated tokens (excluding the input prompt)
            decoded = self.tokenizer.decode([res[0][input_token_length:]])
            return decoded.dumps_midi()
        except Exception as e:
            print(f"Error processing MIDI: {e}")
            return midi_bytes  # Return the original bytes on error