# Hugging Face Space metadata (was raw web-page residue; commented out so the file parses):
# Chris4K's picture
# Update app.py
# 8f1b448 verified
# raw
# history blame
# 4.51 kB
# Install necessary libraries
#!pip install transformers accelerate datasets gradio sympy
# Import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import sympy
# Load Model and Tokenizer
# Generator model used to produce candidate solutions.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
# Process Reward Model used to score candidate solutions (Best-of-N / DVTS).
PRM_NAME = "RLHFlow/Llama3.1-8B-PRM"
# Prefer GPU when available; used to place tokenized inputs before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load LLaMA model
def load_model(model_name):
    """Load a causal-LM checkpoint and its tokenizer.

    Args:
        model_name: Hugging Face Hub model id.

    Returns:
        (model, tokenizer) tuple.

    Note: with ``device_map="auto"`` accelerate dispatches the weights
    across available devices itself; calling ``.to(device)`` on such a
    dispatched model is unsupported (accelerate raises a RuntimeError),
    so we return the model as loaded.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return model, tokenizer
# Generator model + tokenizer (used by all strategies below).
llama_model, llama_tokenizer = load_model(MODEL_NAME)
# Load Process Reward Model (PRM)
# NOTE(review): the PRM is loaded as a plain causal LM and scored via
# logits.mean() below — confirm this matches the RLHFlow PRM's intended usage.
prm_model, prm_tokenizer = load_model(PRM_NAME)
# Strategies
def majority_voting(prompt, num_samples=5):
    """Sample ``num_samples`` completions and return the most frequent one.

    Fix: the original called ``generate`` without ``do_sample=True``;
    greedy decoding is deterministic, so all samples were identical and
    the vote was a no-op. Sampling makes the candidates actually differ.

    Args:
        prompt: problem statement to complete.
        num_samples: number of independent samples to draw.

    Returns:
        The most common decoded completion (ties broken arbitrarily).
    """
    # Tokenization is loop-invariant — do it once.
    input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    outputs = []
    for _ in range(int(num_samples)):
        with torch.no_grad():  # inference only, no autograd graph needed
            output = llama_model.generate(input_ids, max_new_tokens=50, do_sample=True)
        outputs.append(llama_tokenizer.decode(output[0], skip_special_tokens=True))
    # Return the most common result
    return max(set(outputs), key=outputs.count)
def best_of_n(prompt, num_samples=5):
    """Sample ``num_samples`` completions and return the one the PRM scores highest.

    Fix: the original used greedy decoding, so every candidate was
    identical and best-of-N degenerated to best-of-1; ``do_sample=True``
    restores diversity. Generation and scoring run under ``no_grad``.

    Args:
        prompt: problem statement to complete.
        num_samples: number of candidates to generate and score.

    Returns:
        The decoded completion with the highest mean PRM logit.
    """
    # Tokenization is loop-invariant — do it once.
    input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    scored_outputs = []
    for _ in range(int(num_samples)):
        with torch.no_grad():
            output = llama_model.generate(input_ids, max_new_tokens=50, do_sample=True)
            response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
            # Mean logit over the response as a scalar reward proxy.
            score = prm_model(**prm_tokenizer(response, return_tensors="pt").to(device)).logits.mean().item()
        scored_outputs.append((response, score))
    # Return the highest scored response
    return max(scored_outputs, key=lambda x: x[1])[0]
def beam_search(prompt, num_beams=5):
    """Run beam search and return all ``num_beams`` decoded beams.

    Args:
        prompt: problem statement to complete.
        num_beams: beam width; also the number of sequences returned.

    Returns:
        A list of ``num_beams`` decoded completions (note: a list, not a
        single string — Gradio stringifies it for the text output).
    """
    input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    num_beams = int(num_beams)
    with torch.no_grad():  # inference only — avoids building an autograd graph
        outputs = llama_model.generate(
            input_ids,
            max_new_tokens=50,
            num_beams=num_beams,
            num_return_sequences=num_beams,
        )
    return [llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
def dvts(prompt, depth=3, breadth=2):
    """Simplified implementation of DVTS: generates a tree of solutions and
    evaluates branches using the PRM.

    Fix: the original used greedy decoding, so every branch at a level was
    identical and the "tree" collapsed to a single path; ``do_sample=True``
    makes branches diverse. The duplicated generate-and-score code is
    factored into one helper.

    Args:
        prompt: root problem statement.
        depth: number of expansion rounds (>= 1).
        breadth: branches sampled per round / kept per round.

    Returns:
        The best-scoring response found anywhere in the tree.
    """
    def _sample_and_score(text):
        # Generate one sampled continuation of `text` and score it with the PRM.
        input_ids = llama_tokenizer(text, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            output = llama_model.generate(input_ids, max_new_tokens=50, do_sample=True)
            response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
            score = prm_model(**prm_tokenizer(response, return_tensors="pt").to(device)).logits.mean().item()
        return response, score

    depth, breadth = int(depth), int(breadth)
    # Root level: sample `breadth` independent candidates.
    results = [_sample_and_score(prompt) for _ in range(breadth)]
    # Expand the current best candidates for the remaining depth - 1 rounds.
    for _ in range(depth - 1):
        best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
        for response, _ in best_responses:
            results.append(_sample_and_score(response))
    # Return the best overall response
    return max(results, key=lambda x: x[1])[0]
# Gradio Interface
def inference(prompt, strategy, num_samples, depth, breadth):
    """Dispatch a prompt to the selected test-time-compute strategy.

    Fix: Gradio sliders can deliver floats; ``range(num_samples)`` inside
    the strategies would then raise ``TypeError`` — cast the counts to
    ``int`` at this boundary.

    Args:
        prompt: problem statement from the textbox.
        strategy: one of "Majority Voting", "Best-of-N", "Beam Search", "DVTS".
        num_samples: sample count / beam width (non-DVTS strategies).
        depth, breadth: DVTS tree parameters.

    Returns:
        The selected strategy's result (a string, except Beam Search which
        returns a list of strings), or "Invalid Strategy".
    """
    num_samples, depth, breadth = int(num_samples), int(depth), int(breadth)
    if strategy == "Majority Voting":
        return majority_voting(prompt, num_samples)
    elif strategy == "Best-of-N":
        return best_of_n(prompt, num_samples)
    elif strategy == "Beam Search":
        # The samples slider doubles as the beam width here.
        return beam_search(prompt, num_samples)
    elif strategy == "DVTS":
        return dvts(prompt, depth, breadth)
    else:
        return "Invalid Strategy"
# Build the UI: one textbox + strategy radio + three numeric sliders,
# all routed into `inference`, with a plain-text output.
strategy_choices = ["Majority Voting", "Beam Search", "DVTS"]
strategy_choices.insert(1, "Best-of-N")
ui_inputs = [
    gr.Textbox(label="Problem Statement", placeholder="Enter your problem here"),
    gr.Radio(strategy_choices, label="Inference Strategy"),
    gr.Slider(1, 10, step=1, value=5, label="Number of Samples"),
    gr.Slider(1, 5, step=1, value=3, label="Depth (DVTS Only)"),
    gr.Slider(1, 5, step=1, value=2, label="Breadth (DVTS Only)"),
]
demo = gr.Interface(
    fn=inference,
    inputs=ui_inputs,
    outputs="text",
    title="Dynamic Inference Toolkit",
    description="Explore test-time compute scaling strategies with Meta's LLaMA model.",
)
demo.launch()