```python
# Install necessary libraries
#!pip install transformers accelerate datasets gradio sympy

# Import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# Model and tokenizer names
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"  # generator
PRM_NAME = "RLHFlow/Llama3.1-8B-PRM"             # Process Reward Model (scorer)

device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_name):
    """Load a causal LM and its tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" already places the weights; calling .to(device) on a
    # dispatched model raises an error in recent accelerate versions.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return model, tokenizer

# Load the LLaMA generator and the Process Reward Model (PRM)
llama_model, llama_tokenizer = load_model(MODEL_NAME)
prm_model, prm_tokenizer = load_model(PRM_NAME)
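# Note: the 8B PRM dwarfs the 1B generator. If GPU memory is tight, a
# half-precision load may help, e.g. (optional, a suggestion only):
# prm_model = AutoModelForCausalLM.from_pretrained(
#     PRM_NAME, device_map="auto", torch_dtype=torch.bfloat16)

# Shared helpers for generation and PRM scoring, used by the strategies below
def sample_once(prompt):
    """Draw one sampled continuation of `prompt`. Sampling (do_sample=True)
    provides the diversity that majority voting, best-of-N, and DVTS rely on."""
    input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = llama_model.generate(
        input_ids,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.7,
        pad_token_id=llama_tokenizer.eos_token_id,
    )
    return llama_tokenizer.decode(output[0], skip_special_tokens=True)

def score_with_prm(text):
    """Score `text` with the PRM; the mean logit is a crude scalar proxy for quality."""
    inputs = prm_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = prm_model(**inputs).logits
    return logits.mean().item()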
# Strategies
def majority_voting(prompt, num_samples=5):
    # Sampling (rather than greedy decoding) is what makes the candidates
    # differ; greedy decoding would return num_samples identical strings.
    outputs = [sample_once(prompt) for _ in range(num_samples)]
    # Return the most common result
    return max(set(outputs), key=outputs.count)
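# Example call (hypothetical prompt):
# majority_voting("What is 17 * 23? Answer with just the number.", num_samples=5)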
def best_of_n(prompt, num_samples=5):
    scored_outputs = []
    for _ in range(num_samples):
        response = sample_once(prompt)
        scored_outputs.append((response, score_with_prm(response)))
    # Return the highest-scored response
    return max(scored_outputs, key=lambda x: x[1])[0]
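# Unlike majority voting, best-of-N keeps the reward model in the loop at
# inference time, trading extra memory for a learned notion of answer quality.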
def beam_search(prompt, num_beams=5):
    input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    outputs = llama_model.generate(input_ids, max_new_tokens=50,
                                   num_beams=num_beams, num_return_sequences=num_beams)
    # Join the beams so the result renders as a single text output in Gradio
    return "\n\n".join(llama_tokenizer.decode(o, skip_special_tokens=True) for o in outputs)
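# Beam search is deterministic here: the returned sequences are simply the top
# num_beams beams, so no sampling is needed to obtain distinct candidates.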
def dvts(prompt, depth=3, breadth=2):
    """
    Simplified implementation of DVTS: generates a tree of solutions and
    evaluates branches using the PRM.
    """
    # Root level: sample `breadth` initial candidates and score them
    results = []
    for _ in range(breadth):
        response = sample_once(prompt)
        results.append((response, score_with_prm(response)))
    # Expand the best-scoring branches for each remaining level of depth
    for _ in range(depth - 1):
        best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
        for response, _ in best_responses:
            # Feed the candidate back in as the prompt to extend the branch
            extended_response = sample_once(response)
            results.append((extended_response, score_with_prm(extended_response)))
    # Return the best overall response
    return max(results, key=lambda x: x[1])[0]
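# Cost: roughly depth * breadth generations, each followed by a PRM forward
# pass, so runtime grows quickly with both sliders.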
# Gradio Interface
def inference(prompt, strategy, num_samples, depth, breadth):
    # Gradio sliders may deliver floats; generation kwargs expect ints
    num_samples, depth, breadth = int(num_samples), int(depth), int(breadth)
    if strategy == "Majority Voting":
        return majority_voting(prompt, num_samples)
    elif strategy == "Best-of-N":
        return best_of_n(prompt, num_samples)
    elif strategy == "Beam Search":
        return beam_search(prompt, num_samples)
    elif strategy == "DVTS":
        return dvts(prompt, depth, breadth)
    else:
        return "Invalid Strategy"
gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Problem Statement", placeholder="Enter your problem here"),
        gr.Radio(
            ["Majority Voting", "Best-of-N", "Beam Search", "DVTS"],
            label="Inference Strategy",
        ),
        gr.Slider(1, 10, step=1, value=5, label="Number of Samples"),
        gr.Slider(1, 5, step=1, value=3, label="Depth (DVTS Only)"),
        gr.Slider(1, 5, step=1, value=2, label="Breadth (DVTS Only)"),
    ],
    outputs="text",
    title="Dynamic Inference Toolkit",
    description="Explore test-time compute scaling strategies with Meta's LLaMA model.",
).launch()
```