import argparse
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# from sentence_transformers import SentenceTransformer


class LlamaModel:
    """
    A class for Llama 3.2 3B Instruct model handling.
    """

    def __init__(self, model_name_or_path="meta-llama/Llama-3.2-3B-Instruct",
                 max_new_tokens=1000, do_sample=True):
        """
        Initialize the Llama model and tokenizer (gated repo; requires an HF token).
        """
        token = os.getenv("HF_AUTH_TOKEN")
        if token is None:
            raise ValueError("HF_AUTH_TOKEN environment variable is not set.")
        # from_pretrained takes the auth token via the `token` keyword.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=token,
        )
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample

    def generate(self, prompt: str) -> str:
        # Tokenize the input prompt
        tokenized_input = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        # Generate output using the model
        with torch.no_grad():
            output = self.model.generate(
                **tokenized_input,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Drop the prompt tokens so only the newly generated text remains
        output = output[:, tokenized_input["input_ids"].shape[-1]:].cpu()

        # Decode the generated output
        decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return decoded_output[0].strip()


class GPT2Model:
    """
    A class for GPT-2 model handling.
    """

    def __init__(self, model_name="gpt2"):
        """
        Initialize the GPT-2 model and tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate(self, input_text, max_length=200, temperature=0.7, top_p=0.9, top_k=50):
        """
        Generate a response using the GPT-2 model.

        Note: max_length counts the prompt tokens as well as the newly generated ones.
        """
        inputs = self.tokenizer.encode(input_text, return_tensors="pt")
        outputs = self.model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            # GPT-2 has no pad token; reuse EOS to avoid a warning from generate().
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class GPTNeoXModel:
    """
    A class for GPT-NeoX model handling.
    """

    def __init__(self, model_name="EleutherAI/gpt-neox-20b"):
        """
        Initialize the GPT-NeoX model and tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate(self, input_text, max_length=50, temperature=0.7, top_p=0.9, top_k=50):
        """
        Generate a response using the GPT-NeoX model.
        """
        inputs = self.tokenizer.encode(input_text, return_tensors="pt")
        outputs = self.model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class DistilGPT2Model:
    """
    A class for DistilGPT-2 model handling.
    """

    def __init__(self, model_name="distilgpt2"):
        """
        Initialize the DistilGPT-2 model and tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate(self, input_text, max_length=200, temperature=0.7, top_p=0.9, top_k=50):
        """
        Generate a response using the DistilGPT-2 model.
""" inputs = self.tokenizer.encode(input_text, return_tensors="pt") outputs = self.model.generate( inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k ) return self.tokenizer.decode(outputs[0], skip_special_tokens=True) class LLaMA2Model: """ A class for LLaMA-2 7B Chat model handling. """ def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"): """ Initialize the LLaMA-2 model and tokenizer. """ self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained(model_name) def generate(self, input_text, max_length=50, temperature=0.7, top_p=0.9, top_k=50): """ Generate a response using the LLaMA-2 model. """ inputs = self.tokenizer.encode(input_text, return_tensors="pt") outputs = self.model.generate( inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k ) return self.tokenizer.decode(outputs[0], skip_special_tokens=True)