import torch

from q_star import GPTWithMoE, GPTConfig, mcts_decode_single


def generate_text_with_mcts(
    model: GPTWithMoE,
    tokenizer,  # Tokenizer used to encode the prompt and decode the output
    prompt: str,
    max_length: int = 50,
    num_simulations: int = 100,
    c_puct: float = 1.0,
    top_k: int = 10,
    device: str = "cuda",
):
    """
    Generate text using the GPTWithMoE model and MCTS-based decoding.

    Args:
        model (GPTWithMoE): The trained model.
        tokenizer: The tokenizer for text encoding and decoding.
        prompt (str): The initial text prompt.
        max_length (int): Maximum length of the generated sequence.
        num_simulations (int): Number of MCTS simulations per decoding step.
        c_puct (float): Exploration constant for MCTS.
        top_k (int): Number of top tokens to consider during MCTS expansion.
        device (str): Device to use for computation.

    Returns:
        str: The generated text.
    """
    model.eval()
    model.to(device)

    # Encode the prompt into input_ids (shape: (1, prompt_length))
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Use MCTS to decode the sequence; no gradients are needed for generation
    with torch.no_grad():
        generated_ids = mcts_decode_single(
            model=model,
            input_ids=input_ids,
            max_length=max_length,
            num_simulations=num_simulations,
            c_puct=c_puct,
            top_k=top_k,
        )

    # Decode the generated IDs back to text; squeeze() drops a leading batch
    # dimension in case mcts_decode_single returns a (1, seq_len) tensor
    generated_text = tokenizer.decode(generated_ids.squeeze().tolist(), skip_special_tokens=True)
    return generated_text


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the tokenizer (adapt as needed for your model's tokenizer)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Load the trained model
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
    model = GPTWithMoE(
        config,
        num_experts=3,
        expert_layers=3,
        block_size_q=32,
        block_size_kv=32,
        num_blocks_kv=4,
        device=device,
    )
    model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))

    # Generate text from a prompt
    prompt = "Once upon a time in a distant galaxy,"
    generated_text = generate_text_with_mcts(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=100,
        num_simulations=50,
        c_puct=1.5,
        top_k=5,
        device=device,
    )

    print("Generated Text:")
    print(generated_text)
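

# --- Reference sketch (assumption, not the q_star implementation) ------------
# `mcts_decode_single` lives in the `q_star` module and is not shown here. As a
# minimal illustration of what the `c_puct` parameter controls, the hypothetical
# helper below computes the PUCT selection score commonly used in MCTS: the
# exploitation term Q plus a prior-weighted exploration bonus scaled by c_puct.
# Names and shapes are assumptions for illustration only.

import math


def puct_score(q_value: float, prior: float, parent_visits: int, child_visits: int, c_puct: float) -> float:
    """PUCT score: Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))."""
    exploration = c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)
    return q_value + exploration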
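

# --- Optional greedy fallback (assumption) ------------------------------------
# A hypothetical greedy decoder for sanity-checking the checkpoint without the
# MCTS machinery. It assumes GPTWithMoE's forward pass returns logits of shape
# (batch, time, vocab) as the first element of its output, as in typical GPT
# implementations; adjust the indexing if the model's signature differs.


@torch.no_grad()
def greedy_decode(model: GPTWithMoE, input_ids: torch.Tensor, max_new_tokens: int = 50) -> torch.Tensor:
    """Append the argmax token at each step; returns the extended id tensor."""
    for _ in range(max_new_tokens):
        logits = model(input_ids)[0]  # assumed (B, T, vocab) logits
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return input_ids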