import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Any


class EndpointHandler:
    # def __init__(self, path="decapoda-research/llama-65b-hf"):
    def __init__(self, path="TangrisJones/vicuna-13b-GPTQ-4bit-128g"):
        # Note: loading a GPTQ-quantized checkpoint through
        # AutoModelForCausalLM generally requires the optimum/auto-gptq
        # integration to be installed; otherwise from_pretrained will fail.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        input_text = data["inputs"]
        kwargs = data.get("kwargs", {})

        # Tokenize the input text and move it to the model's device
        input_tokens = self.tokenizer.encode(
            input_text, return_tensors="pt"
        ).to(self.device)

        # Generate output tokens without tracking gradients
        with torch.no_grad():
            output_tokens = self.model.generate(input_tokens, **kwargs)

        # Decode output tokens, dropping special tokens such as </s>
        output_text = self.tokenizer.decode(
            output_tokens[0], skip_special_tokens=True
        )

        return [{"output": output_text}]


# Example usage
if __name__ == "__main__":
    handler = EndpointHandler()
    input_data = {"inputs": "Once upon a time in a small village, "}
    output_data = handler(input_data)
    print(output_data)
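
    # A minimal sketch of passing generation parameters through the optional
    # "kwargs" field of the request payload. The keys shown (max_new_tokens,
    # do_sample, temperature) are standard transformers generate() arguments;
    # the values and the prompt are illustrative, not from the original.
    sampled = handler({
        "inputs": "Once upon a time in a small village, ",
        "kwargs": {"max_new_tokens": 64, "do_sample": True, "temperature": 0.7},
    })
    print(sampled)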