gemma-med3-9b / handler.py
krplatz's picture
Update handler.py
c8186c8 verified
raw
history blame
1.02 kB
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
class EndpointHandler:
    """Text-classification request handler built on a ``transformers`` pipeline.

    The constructor loads the tokenizer and the sequence-classification model
    from ``model_dir`` and wires them into a ``"text-classification"``
    pipeline.  Calling the instance runs that pipeline on a request.
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and build the pipeline.

        Args:
            model_dir: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load the model with the `ignore_mismatched_sizes` flag.
        # NOTE(review): with this flag, checkpoint weights whose shape does
        # not match the model are skipped instead of raising -- confirm that
        # the classification head is really meant to load this way.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_dir,
            ignore_mismatched_sizes=True,
        )

        # Initialize the pipeline; device 0 is used when CUDA is available,
        # otherwise -1 (CPU).
        self.pipeline = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )

    def __call__(self, inputs):
        """Run the pipeline on a request and return its predictions.

        Args:
            inputs: Either the text payload itself (a string, a list of
                strings, or a ``{"text": ..., "text_pair": ...}`` dict), or a
                request envelope of the form
                ``{"inputs": <payload>, "parameters": {...}}``.  When the
                envelope form is used, ``"parameters"`` (optional) is
                forwarded to the pipeline as keyword arguments.

        Returns:
            Whatever ``self.pipeline`` returns for the payload.
        """
        payload = inputs
        parameters = {}

        # Unwrap the request envelope.  Passing the envelope dict straight to
        # the pipeline would treat its keys as tokenizer arguments.  The
        # caller's dict is read, never mutated.
        if isinstance(inputs, dict) and "inputs" in inputs:
            payload = inputs["inputs"]
            parameters = inputs.get("parameters") or {}

        return self.pipeline(payload, **parameters)
# Factory entry point; per the original author's note, this is the function
# to be called by the Hugging Face Inference Toolkit.
def get_pipeline(model_dir):
    """Construct and return an ``EndpointHandler`` for *model_dir*."""
    handler = EndpointHandler(model_dir)
    return handler