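"""TorchServe handler for a Coptic-to-English seq2seq translation model.

Request bodies are transliterated from Coptic to Greek script, batch-translated
with a Hugging Face AutoModelForSeq2SeqLM, and returned as JSON; inputs longer
than the tokenizer's limit get a 422-style error entry instead of a translation.
"""
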
from dataclasses import dataclass
import logging
import os
from abc import ABC
from typing import Optional
import torch
import json
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
MAX_TOKEN_LENGTH_ERR = {
"code": 422,
"type" : "MaxTokenLengthError",
"message": "Max token length exceeded",
}
class CopEngHandler(BaseHandler, ABC):
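    # Generation parameters forwarded to model.generate() by batch_translate();
    # initialize() overrides a few of the defaults below.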
@dataclass
class GenerationConfig:
max_length: int = 20
max_new_tokens: Optional[int] = None
min_length: int = 0
min_new_tokens: Optional[int] = None
early_stopping: bool = True
do_sample: bool = False
num_beams: int = 1
num_beam_groups: int = 1
top_k: int = 50
top_p: float = 0.95
temperature: float = 1.0
diversity_penalty: float = 0.0
    def __init__(self):
        super().__init__()
        self.initialized = False
def initialize(self, ctx):
"""In this initialize function, the HF large model is loaded and
partitioned using DeepSpeed.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
"""
logger.info("Start initialize")
self.manifest = ctx.manifest
properties = ctx.system_properties
model_dir = properties.get("model_dir")
serialized_file = self.manifest["model"]["serializedFile"]
model_pt_path = os.path.join(model_dir, serialized_file)
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        seed = 42
torch.manual_seed(seed)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Device: %s", self.device)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
self.model.to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.config = CopEngHandler.GenerationConfig(
max_new_tokens=128,
min_new_tokens=1,
num_beams=5,
)
self.initialized = True
logger.info("Init done")
def preprocess(self, requests):
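        """Decode each request body to text and transliterate Coptic
        characters to Greek script via greekify()."""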
preprocessed_data = []
for data in requests:
data_item = data.get("data")
if data_item is None:
data_item = data.get("body")
if isinstance(data_item, (bytes, bytearray)):
data_item = data_item.decode("utf-8")
preprocessed_data.append(greekify(data_item))
logger.info("preprocessed_data %s: ", preprocessed_data)
return preprocessed_data
def inference(self, data):
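        """Translate the preprocessed batch; inputs longer than the
        tokenizer's model_max_length are skipped and returned as None."""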
indices = {}
batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            # Skip over-length inputs; their slot in the output is None below.
            if tokens.input_ids.shape[1] > self.tokenizer.model_max_length:
                logger.info("Skipping input at index %s: max token length exceeded", i)
                continue
            indices[i] = len(batch)
            batch.append(data[i])
logger.info("inference batch: %s", batch)
result = self.batch_translate(batch)
return [result[indices[i]] if i in indices else None for i in range(len(data))]
def postprocess(self, output):
return output
def handle(self, requests, context):
preprocessed = self.preprocess(requests)
inference_data = self.inference(preprocessed)
postprocessed = self.postprocess(inference_data)
logger.info("inference result: %s", postprocessed)
        responses = [
            {"code": 200, "translation": translation}
            if translation is not None
            else MAX_TOKEN_LENGTH_ERR
            for translation in postprocessed
        ]
return responses
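    # Illustrative request/response shape (an assumption based on preprocess():
    # a POST body with a "data" or "body" field, e.g. {"data": "ⲡⲛⲟⲩⲧⲉ"}, yields
    # {"code": 200, "translation": "..."} per input, or MAX_TOKEN_LENGTH_ERR
    # for over-length inputs).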
def batch_translate(self, input_sentences, output_confidence=False):
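        """Tokenize, generate, and decode a batch of sentences. Returns an
        empty list when the batch is empty."""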
if len(input_sentences) == 0:
return []
inputs = self.tokenizer(input_sentences, return_tensors="pt", padding=True).to(
self.device
)
        # return_dict_in_generate must stay True because batch_decode below
        # reads outputs.sequences. Note that when max_new_tokens is set it
        # takes precedence over max_length in transformers.
        outputs = self.model.generate(
            **inputs,
            max_length=self.config.max_length,
            max_new_tokens=self.config.max_new_tokens,
            min_length=self.config.min_length,
            min_new_tokens=self.config.min_new_tokens,
            early_stopping=self.config.early_stopping,
            do_sample=self.config.do_sample,
            num_beams=self.config.num_beams,
            num_beam_groups=self.config.num_beam_groups,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            diversity_penalty=self.config.diversity_penalty,
            output_scores=output_confidence,
            return_dict_in_generate=True,
        )
translated_text = self.tokenizer.batch_decode(
outputs.sequences, skip_special_tokens=True
)
return translated_text
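# Transliteration table used by greekify(). Coptic letters with direct Greek
# counterparts map to those; the Coptic-specific letters at the end
# (ϣ, ϥ, ϧ, ϩ, ϫ, ϭ, ϯ) have no Greek equivalent and map to rough Latin
# approximations instead.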
COPTIC_TO_GREEK = {
"ⲁ": "α",
"ⲃ": "β",
"ⲅ": "γ",
"ⲇ": "δ",
"ⲉ": "ε",
"ⲋ": "ϛ",
"ⲍ": "ζ",
"ⲏ": "η",
"ⲑ": "θ",
"ⲓ": "ι",
"ⲕ": "κ",
"ⲗ": "λ",
"ⲙ": "μ",
"ⲛ": "ν",
"ⲝ": "ξ",
"ⲟ": "ο",
"ⲡ": "π",
"ⲣ": "ρ",
"ⲥ": "σ",
"ⲧ": "τ",
"ⲩ": "υ",
"ⲫ": "φ",
"ⲭ": "χ",
"ⲯ": "ψ",
"ⲱ": "ω",
"ϣ": "s",
"ϥ": "f",
"ϧ": "k",
"ϩ": "h",
"ϫ": "j",
"ϭ": "c",
"ϯ": "t",
}
def greekify(coptic_text):
    """Lowercase the text and replace each Coptic character with its mapped
    Greek or Latin counterpart; unmapped characters pass through unchanged."""
    chars = []
    for c in coptic_text:
        l_c = c.lower()
        chars.append(COPTIC_TO_GREEK.get(l_c, l_c))
    return "".join(chars)