import torch

from q_star import GPTWithMoE, GPTConfig, mcts_decode_single
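
# This script loads a trained GPTWithMoE checkpoint and generates text from a
# prompt using Monte Carlo tree search (MCTS) decoding via mcts_decode_single.
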
def generate_text_with_mcts(
    model: GPTWithMoE,
    tokenizer,
    prompt: str,
    max_length: int = 50,
    num_simulations: int = 100,
    c_puct: float = 1.0,
    top_k: int = 10,
    device: str = "cuda",
) -> str:
    """
    Generate text using the GPTWithMoE model and MCTS-based decoding.

    Args:
        model (GPTWithMoE): The trained model.
        tokenizer: The tokenizer for text encoding and decoding.
        prompt (str): The initial text prompt.
        max_length (int): Maximum length of the generated text.
        num_simulations (int): Number of MCTS simulations for each decoding step.
        c_puct (float): Exploration parameter for MCTS.
        top_k (int): Top-k tokens to consider during MCTS expansion.
        device (str): Device to use for computation.

    Returns:
        str: The generated text.
    """
    model.eval()
    model.to(device)

    # Encode the prompt; return_tensors="pt" yields a (1, seq_len) tensor.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Decoding is pure inference, so disable autograd to save memory.
    with torch.no_grad():
        generated_ids = mcts_decode_single(
            model=model,
            input_ids=input_ids,
            max_length=max_length,
            num_simulations=num_simulations,
            c_puct=c_puct,
            top_k=top_k,
        )

    # Assumes mcts_decode_single returns a 1-D tensor of token ids; if it
    # returns a (1, seq_len) batch, decode generated_ids[0].tolist() instead.
    generated_text = tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
    return generated_text


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # GPT-2 defines no pad token, so reuse the EOS token for padding.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # These hyperparameters must match the checkpoint being loaded.
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))

    prompt = "Once upon a time in a distant galaxy,"
    generated_text = generate_text_with_mcts(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=100,
        num_simulations=50,
        c_puct=1.5,
        top_k=5,
        device=device,
    )

    print("Generated Text:")
    print(generated_text)