# Hugging Face Inference Endpoints custom handler for a GECToR
# grammatical-error-correction model (upstream commit 92f0779,
# repo by andrewrreed).
import os
import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict
class EndpointHandler:
    """Inference Endpoints handler around a GECToR grammar-correction model.

    Loads the model, tokenizer, and verb-form vocabulary once at startup,
    then serves batched correction requests through ``__call__``.
    """

    def __init__(self, path=""):
        """Load model artifacts from *path* and move the model to GPU if available.

        Args:
            path: Directory containing the pretrained GECToR weights, the
                tokenizer files, and ``data/verb-form-vocab.txt``.
        """
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.encode, self.decode = load_verb_dict(
            os.path.join(path, "data/verb-form-vocab.txt")
        )
        # Prefer a CUDA device when one is visible to this process.
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(torch.device(target))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run grammatical-error correction over a batch of input strings.

        Args:
            data: Request payload with the keys:
                - "inputs" (List[str]): sentences to correct (required).
                - "n_iterations" (int, optional): correction passes, default 5.
                - "batch_size" (int, optional): prediction batch size, default 2.
                - "keep_confidence" (float, optional): confidence bias for
                  keeping tokens unchanged, default 0.0.
                - "min_error_prob" (float, optional): minimum sentence-level
                  error probability to apply edits, default 0.0.

        Returns:
            The corrected outputs produced by ``gector.predict`` for each input.
        """
        sources = data["inputs"]
        # Optional tuning knobs; note predict() spells the iteration
        # parameter "n_iteration" (singular) per the gector API.
        options = {
            "n_iteration": data.get("n_iterations", 5),
            "batch_size": data.get("batch_size", 2),
            "keep_confidence": data.get("keep_confidence", 0.0),
            "min_error_prob": data.get("min_error_prob", 0.0),
        }
        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=sources,
            encode=self.encode,
            decode=self.decode,
            **options,
        )