import argparse
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# Load model and tokenizer
def load_model_and_tokenizer(model_name: str) -> tuple:
    """
    Load the pre-trained model and tokenizer.

    Args:
        model_name (str): Name or path of the pre-trained model.

    Returns:
        tuple: (model, tokenizer)
    """
    logger.info(f"Loading model: {model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # Use half precision on GPU to reduce memory use; fall back to fp32 on CPU.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        logger.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


# Generate text
def generate_text(
    model,
    tokenizer,
    prompt: str,
    max_length: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
) -> str:
    """
    Generate text based on the given prompt.

    Args:
        model: Pre-trained language model.
        tokenizer: Tokenizer for the model.
        prompt (str): Input prompt for text generation.
        max_length (int): Maximum total length (prompt + generated tokens).
        temperature (float): Sampling temperature (higher = more random).
        top_k (int): Top-k sampling (0 disables top-k filtering).
        top_p (float): Top-p (nucleus) sampling (1.0 disables nucleus filtering).

    Returns:
        str: Generated text.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            model.to("cuda")
            inputs = {key: value.to("cuda") for key, value in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                # Unpack input_ids and attention_mask together; indexing
                # `inputs.input_ids` would fail once `inputs` is a plain dict.
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                # Avoid the missing-pad-token warning for decoder-only models.
                pad_token_id=tokenizer.eos_token_id,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info("Text generation completed successfully.")
        return generated_text
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise


# Save generated text to a file
def save_to_file(text: str, filename: str) -> None:
    """
    Save the generated text to a file.

    Args:
        text (str): Generated text.
        filename (str): Name of the output file.
""" try: with open(filename, "w") as file: file.write(text) logger.info(f"Generated text saved to {filename}.") except Exception as e: logger.error(f"Error saving to file: {e}") raise # Main function def main(): # Parse command-line arguments parser = argparse.ArgumentParser( description="Generate text using a pre-trained language model.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--model", type=str, default="mistralai/Mistral-8x7B", help="Name or path of the pre-trained model.", ) parser.add_argument( "--prompt", type=str, required=True, help="Input prompt for text generation.", ) parser.add_argument( "--max_length", type=int, default=100, help="Maximum length of the generated text.", ) parser.add_argument( "--temperature", type=float, default=1.0, help="Sampling temperature (higher = more random).", ) parser.add_argument( "--top_k", type=int, default=50, help="Top-k sampling (0 = no sampling).", ) parser.add_argument( "--top_p", type=float, default=0.95, help="Top-p (nucleus) sampling (1.0 = no sampling).", ) parser.add_argument( "--output_file", type=str, help="File to save the generated text.", ) args = parser.parse_args() # Load model and tokenizer try: model, tokenizer = load_model_and_tokenizer(args.model) except Exception as e: logger.error(f"Failed to load model: {e}") return # Generate text try: logger.info("Generating text...") generated_text = generate_text( model, tokenizer, args.prompt, max_length=args.max_length, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, ) # Print the generated text print("\nGenerated Text:") print(generated_text) # Save to file if specified if args.output_file: save_to_file(generated_text, args.output_file) except Exception as e: logger.error(f"Failed to generate text: {e}") if __name__ == "__main__": main()