# NOTE(review): removed non-Python extraction residue ("Spaces: / Sleeping / Sleeping")
from typing import Dict
import torch
from .config_manager import ConfigManager
class Diacritizer:
    """Base class for diacritization models.

    Builds a ConfigManager from ``config_path``/``model_kind``, exposes its
    config and text encoder, resolves the torch device, and optionally loads
    the trained model. Subclasses implement :meth:`diacritize_batch`.
    """

    def __init__(
        self, config_path: str, model_kind: str, load_model: bool = False
    ) -> None:
        """Initialize configuration, text encoder, device, and (optionally) the model.

        Args:
            config_path: path to the model configuration file.
            model_kind: which model architecture the config describes.
            load_model: when True, load model weights and move them to the device.
        """
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.text_encoder = self.config_manager.text_encoder
        # Honor an explicit "device" entry in the config; otherwise prefer CUDA.
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if load_model:
            self.model, self.global_step = self.config_manager.load_model()
            self.model = self.model.to(self.device)
        self.start_symbol_id = self.text_encoder.start_symbol_id

    def set_model(self, model: torch.nn.Module):
        """Attach an externally constructed model (e.g. during training)."""
        self.model = model

    def diacritize_text(self, text: str):
        """Diacritize a single text string.

        Encodes ``text`` into a length-1 batch and delegates to
        :meth:`diacritize_batch`.

        Returns:
            Whatever :meth:`diacritize_batch` returns for the one-item batch.
        """
        seq = self.text_encoder.input_to_sequence(text)
        # BUG FIX: the result was previously assigned to a local and dropped;
        # callers always received None. Return it.
        return self.diacritize_batch(torch.LongTensor([seq]).to(self.device))

    def diacritize_batch(self, batch):
        """Diacritize a batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def diacritize_iterators(self, iterator):
        """Placeholder for iterator-based diacritization (not implemented)."""
        pass
class CBHGDiacritizer(Diacritizer):
    """Diacritizer backed by a CBHG model."""

    def diacritize_batch(self, batch):
        """Run inference on a batch and return diacritized sentences.

        Args:
            batch: mapping with "src" (LongTensor of input ids) and
                "lengths" (per-sequence lengths; the model expects them on CPU).

        Returns:
            list of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # Inference only: no_grad avoids building the autograd graph
        # (saves memory/time; returned values are unchanged).
        with torch.no_grad():
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.argmax(diacritics, dim=2)
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences
class Seq2SeqDiacritizer(Diacritizer):
    """Diacritizer backed by a sequence-to-sequence model."""

    def diacritize_batch(self, batch):
        """Run inference on a batch and return diacritized sentences.

        Args:
            batch: mapping with "src" (LongTensor of input ids) and
                "lengths" (per-sequence lengths; the model expects them on CPU).

        Returns:
            list of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # Inference only: no_grad avoids building the autograd graph
        # (saves memory/time; returned values are unchanged).
        with torch.no_grad():
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.argmax(diacritics, dim=2)
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences
class GPTDiacritizer(Diacritizer):
    """Diacritizer backed by a GPT-style model."""

    def diacritize_batch(self, batch):
        """Run inference on a batch and return diacritized sentences.

        Args:
            batch: mapping with "src" (LongTensor of input ids) and
                "lengths" (per-sequence lengths; the model expects them on CPU).

        Returns:
            list of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # Inference only: no_grad avoids building the autograd graph
        # (saves memory/time; returned values are unchanged).
        with torch.no_grad():
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.argmax(diacritics, dim=2)
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences