from typing import Optional
import concurrent.futures
import logging
import os
from functools import lru_cache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model_name = "meta-llama/Llama-3.2-1B-Instruct"


@lru_cache(maxsize=1)
def get_hf_token() -> str:
    """Get the Hugging Face access token from environment variables."""
    token = os.getenv("HF_TOKEN")
    if not token:
        raise EnvironmentError(
            "HF_TOKEN environment variable not found. "
            "Please set your Hugging Face access token."
        )
    return token


class ModelManager:
    """Singleton that loads the tokenizer and model exactly once."""

    _instance = None
    _initialized = False
    _model_name = model_name

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call, even when __new__
        # returns the cached instance, so guard against re-loading.
        if ModelManager._initialized:
            return
        # Initialize tokenizer and model; Llama weights are gated, so pass
        # the Hugging Face token explicitly.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, token=get_hf_token()
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self._model_name,
            torch_dtype=torch.float16,
            device_map="auto",  # device_map already places weights; no .to() needed
            token=get_hf_token(),
        )
        # Set model to evaluation mode
        self.model.eval()
        ModelManager._initialized = True

    def __del__(self):
        try:
            del self.model
            del self.tokenizer
            torch.cuda.empty_cache()
        except Exception:
            pass


class PoetryGenerationService:
    def __init__(self):
        # Get the shared model manager instance
        model_manager = ModelManager()
        self.model = model_manager.model
        self.tokenizer = model_manager.tokenizer
        self.cache = {}

    def preload_models(self):
        """Preload the model and tokenizer during application startup."""
        try:
            # Reuse the singleton rather than loading a second copy of the
            # weights; startup and lazy paths then share one model in memory.
            model_manager = ModelManager()
            self.model = model_manager.model
            self.tokenizer = model_manager.tokenizer
            logger.info("Models preloaded successfully")
        except Exception as e:
            logger.error(f"Error preloading models: {e}")
            raise

    def generate_poem(
        self,
        prompt: str,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.9,
        top_k: Optional[int] = 50,
        max_new_tokens: Optional[int] = 100,
        repetition_penalty: Optional[float] = 1.1,
    ) -> str:
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128,
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    # max_new_tokens budgets only generated tokens; max_length
                    # counts prompt + output, so a long prompt could leave no
                    # room for the poem.
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            return self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        except Exception as e:
            raise RuntimeError(f"Error generating poem: {e}") from e

    def generate_poems(self, prompts: list[str]) -> list[str]:
        # All threads share the one model; generation calls are serialized on
        # the GPU, so the pool mainly overlaps tokenization and decoding.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            poems = list(executor.map(self.generate_poem, prompts))
        return poems
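

# A minimal usage sketch, assuming HF_TOKEN is set in the environment and the
# gated meta-llama/Llama-3.2-1B-Instruct weights are accessible. The prompts
# below are illustrative placeholders, not fixtures from this project.
if __name__ == "__main__":
    service = PoetryGenerationService()
    service.preload_models()

    # Single poem
    print(service.generate_poem("Write a haiku about autumn rain."))

    # Batch of poems dispatched through the thread pool
    for poem in service.generate_poems(
        [
            "Write a limerick about a curious cat.",
            "Write a couplet about the sea.",
        ]
    ):
        print(poem)
        print("---")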