# gpt-moe-mcts/mcts_text_gen.py
import torch
from q_star import GPTWithMoE, GPTConfig, mcts_decode_single


def generate_text_with_mcts(
    model: GPTWithMoE,
    tokenizer,  # Tokenizer used to encode the prompt and decode the output
    prompt: str,
    max_length: int = 50,
    num_simulations: int = 100,
    c_puct: float = 1.0,
    top_k: int = 10,
    device: str = "cuda",
):
"""
Generate text using the GPTWithMoE model and MCTS-based decoding.
Args:
model (GPTWithMoE): The trained model.
tokenizer: The tokenizer for text encoding and decoding.
prompt (str): The initial text prompt.
max_length (int): Maximum length of the generated text.
num_simulations (int): Number of MCTS simulations for each decoding step.
c_puct (float): Exploration parameter for MCTS.
top_k (int): Top-k tokens to consider during MCTS expansion.
device (str): Device to use for computation.
Returns:
str: The generated text.
"""
    model.eval()
    model.to(device)

    # Encode the prompt into input_ids (shape: [1, prompt_len])
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
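    # Hedged guard (assumption, not in the original): a nanoGPT-style model keeps
    # its GPTConfig on model.config, and the example below uses block_size=512.
    # Clip over-long prompts so MCTS expansion cannot overflow the context
    # window; drop this if mcts_decode_single already handles long inputs.
    if hasattr(model, "config") and input_ids.size(1) > model.config.block_size:
        input_ids = input_ids[:, -model.config.block_size:]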
    # Use MCTS to decode the sequence
    generated_ids = mcts_decode_single(
        model=model,
        input_ids=input_ids,
        max_length=max_length,
        num_simulations=num_simulations,
        c_puct=c_puct,
        top_k=top_k,
    )
    # Decode the generated IDs back to text. mcts_decode_single's return shape
    # is not shown here, so defensively drop a leading batch dimension if present.
    if generated_ids.dim() > 1:
        generated_ids = generated_ids.squeeze(0)
    generated_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
    return generated_text
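

# --- Illustrative sketch, not the q_star implementation ----------------------
# mcts_decode_single is imported from q_star, so its internals are not visible
# in this file. The hypothetical helper below only documents the standard PUCT
# selection rule that a c_puct parameter conventionally controls:
#     score(a) = Q(a) + c_puct * P(a) * sqrt(N_parent) / (1 + N(a))
# where Q(a) is the mean value of choosing token a, P(a) its prior from the
# model's softmax, and N(.) are visit counts. Larger c_puct weights the
# exploration term more heavily relative to the learned value estimate.
import math


def puct_score(
    q_value: float,      # mean value Q(a) of the child node
    prior: float,        # model prior P(a) for the candidate token
    parent_visits: int,  # visit count N_parent of the parent node
    child_visits: int,   # visit count N(a) of the child node
    c_puct: float,       # exploration constant
) -> float:
    """Hypothetical PUCT score: exploitation plus prior-weighted exploration bonus."""
    return q_value + c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)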


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the tokenizer (adapt as needed for your model's tokenizer)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Load the trained model; hyperparameters must match those used at training time
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
    model = GPTWithMoE(
        config,
        num_experts=3,
        expert_layers=3,
        block_size_q=32,
        block_size_kv=32,
        num_blocks_kv=4,
        device=device,
    )
    model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))

    # Generate text using a prompt
    prompt = "Once upon a time in a distant galaxy,"
    generated_text = generate_text_with_mcts(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=100,
        num_simulations=50,
        c_puct=1.5,
        top_k=5,
        device=device,
    )

    print("Generated Text:")
    print(generated_text)