from dataclasses import dataclass
import json
import logging
import os
from abc import ABC
from typing import Optional

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

MAX_TOKEN_LENGTH_ERR = {
    "code": 422,
    "type": "MaxTokenLengthError",
    "message": "Max token length exceeded",
}


class EngCopHandler(BaseHandler, ABC):
    @dataclass
    class GenerationConfig:
        """Decoding parameters passed through to model.generate()."""

        max_length: int = 20
        max_new_tokens: Optional[int] = None
        min_length: int = 0
        min_new_tokens: Optional[int] = None
        early_stopping: bool = True
        do_sample: bool = False
        num_beams: int = 1
        num_beam_groups: int = 1
        top_k: int = 50
        top_p: float = 0.95
        temperature: float = 1.0
        diversity_penalty: float = 0.0

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the Hugging Face seq2seq model and tokenizer from the
        model archive and set up the generation configuration.

        Args:
            ctx (context): A JSON object containing information
                pertaining to the model artifact parameters.
        """
        logger.info("Start initialize")
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        torch.manual_seed(42)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Device: %s", self.device)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.config = EngCopHandler.GenerationConfig(
            max_new_tokens=128,
            min_new_tokens=1,
            num_beams=5,
        )
        self.initialized = True
        logger.info("Init done")

    def preprocess(self, requests):
        """Extract the raw input text from each request in the batch."""
        preprocessed_data = []
        for data in requests:
            data_item = data.get("data")
            if data_item is None:
                data_item = data.get("body")
            if isinstance(data_item, (bytes, bytearray)):
                data_item = data_item.decode("utf-8")
            preprocessed_data.append(data_item)
        logger.info("preprocessed_data: %s", preprocessed_data)
        return preprocessed_data

    def inference(self, data):
        """Translate every input that fits within the model's token limit.

        Over-length inputs are skipped and map to None in the output list.
        """
        indices = {}
        batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            if tokens.input_ids.shape[-1] > self.tokenizer.model_max_length:
                logger.info("Skipping input at index %s: exceeds max token length", i)
                continue
            indices[i] = len(batch)
            batch.append(data[i])
        logger.info("inference batch: %s", batch)
        result = self.batch_translate(batch)
        return [
            degreekify(result[indices[i]]) if i in indices else None
            for i in range(len(data))
        ]

    def postprocess(self, output):
        return output

    def handle(self, requests, context):
        logger.info("requests: %s", requests)
        preprocessed = self.preprocess(requests)
        inference_data = self.inference(preprocessed)
        postprocessed = self.postprocess(inference_data)
        logger.info("inference result: %s", postprocessed)
        # A None entry marks an input skipped for exceeding the token
        # limit, so report the 422 error payload for it.
        responses = [
            {"code": 200, "translation": translation}
            if translation is not None
            else MAX_TOKEN_LENGTH_ERR
            for translation in postprocessed
        ]
        return responses
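
    # For reference, a minimal sketch of querying this handler once the
    # archive is served with TorchServe; the model name is whatever the
    # archive was registered under (placeholder below, not confirmed by
    # this repo):
    #
    #   curl -X POST http://localhost:8080/predictions/<model_name> \
    #        -d "the light shines in the darkness"
    #
    # A successful response looks like {"code": 200, "translation": "..."},
    # while an over-length input yields the MAX_TOKEN_LENGTH_ERR payload.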

    def batch_translate(self, input_sentences, output_confidence=False):
        if len(input_sentences) == 0:
            return []
        inputs = self.tokenizer(input_sentences, return_tensors="pt", padding=True).to(
            self.device
        )
        outputs = self.model.generate(
            **inputs,
            # max_new_tokens takes precedence over max_length when both are set.
            max_length=self.config.max_length,
            max_new_tokens=self.config.max_new_tokens,
            min_length=self.config.min_length,
            min_new_tokens=self.config.min_new_tokens,
            early_stopping=self.config.early_stopping,
            do_sample=self.config.do_sample,
            num_beams=self.config.num_beams,
            num_beam_groups=self.config.num_beam_groups,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            diversity_penalty=self.config.diversity_penalty,
            output_scores=output_confidence,
            # Decoding below reads outputs.sequences, which requires the
            # dict return type, so this stays True regardless of
            # output_confidence.
            return_dict_in_generate=True,
        )
        translated_text = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        return translated_text


# Mapping from Greek letters (plus a few Latin stand-ins for the
# Demotic-derived Coptic letters) to their Coptic counterparts.
GREEK_TO_COPTIC = {
    "α": "ⲁ",
    "β": "ⲃ",
    "γ": "ⲅ",
    "δ": "ⲇ",
    "ε": "ⲉ",
    "ϛ": "ⲋ",
    "ζ": "ⲍ",
    "η": "ⲏ",
    "θ": "ⲑ",
    "ι": "ⲓ",
    "κ": "ⲕ",
    "λ": "ⲗ",
    "μ": "ⲙ",
    "ν": "ⲛ",
    "ξ": "ⲝ",
    "ο": "ⲟ",
    "π": "ⲡ",
    "ρ": "ⲣ",
    "σ": "ⲥ",
    "τ": "ⲧ",
    "υ": "ⲩ",
    "φ": "ⲫ",
    "χ": "ⲭ",
    "ψ": "ⲯ",
    "ω": "ⲱ",
    "s": "ϣ",
    "f": "ϥ",
    "k": "ϧ",
    "h": "ϩ",
    "j": "ϫ",
    "c": "ϭ",
    "t": "ϯ",
}


def degreekify(greek_text):
    """Convert Greek-alphabet model output to Coptic script, lowercasing
    as it goes and passing unmapped characters through unchanged."""
    chars = []
    for c in greek_text:
        l_c = c.lower()
        chars.append(GREEK_TO_COPTIC.get(l_c, l_c))
    return "".join(chars)
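

# A minimal sanity check for degreekify, runnable outside TorchServe.
# The sample strings are illustrative, not real model output.
if __name__ == "__main__":
    assert degreekify("ΑΒΓ") == "ⲁⲃⲅ"  # Greek letters map to Coptic
    assert degreekify("sft") == "ϣϥϯ"  # Latin stand-ins map as well
    assert degreekify("a b") == "a b"  # unmapped characters pass through
    print(degreekify("νουτε"))  # -> ⲛⲟⲩⲧⲉ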