rl2 / app.py
klasser's picture
Update app.py
ed83615 verified
import streamlit as st
import torch
import transformers
from trl import AutoModelForCausalLMWithValueHead
import math
import time
st.set_page_config(page_title="RLHF Magic | Movie Reviews", page_icon="🍿", layout="wide")
st.markdown("""
<style>
.big-font { font-size:22px !important; font-weight: 500; }
.stProgress .st-bo { transition: background-color 0.5s ease; }
</style>
""", unsafe_allow_html=True)
st.title("🍿 Нейросеть-Кинокритик: До и После RLHF")
st.markdown("""
<div class="big-font">
Посмотрите, как работает магия обучения с подкреплением (RLHF). <br>
Слева — базовая модель GPT-2, которая пишет что вздумается. Справа — та же модель, но <b>натренированная всегда писать позитивные отзывы</b>, даже если вы начинаете текст с ужасных слов!
</div>
<br>
""", unsafe_allow_html=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@st.cache_resource
def load_models():
reward_path = "reward_model_trained"
ppo_path = "ppo_model_trained"
orig_model_name = "lvwerra/gpt2-imdb"
# 1. Reward Model
reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)
reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval()
# 2. Original GPT-2
orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name)
if orig_tokenizer.pad_token is None:
orig_tokenizer.pad_token = orig_tokenizer.eos_token
orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval()
# 3. RLHF Model
rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval()
rlhf_model = rlhf_model_full.pretrained_model
return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model
with st.spinner("⏳ Подготовка нейросетей... (занимает около минуты при первом старте)"):
reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models()
def compute_reward(text):
inputs = reward_tokenizer(text, truncation=True, max_length=512, padding=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
score = reward_model(**inputs).logits[0, 0].item()
return score
def get_positivity_percent(score):
return int((1 / (1 + math.exp(-score))) * 100)
def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p):
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True,
temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
def stream_text(text, delay=0.03):
for word in text.split(" "):
yield word + " "
time.sleep(delay)
st.sidebar.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=50)
st.sidebar.header("🎛 Настройки генерации")
max_tokens = st.sidebar.slider("Длина продолжения (токенов)", 20, 150, 70)
temp = st.sidebar.slider("Креативность (Temperature)", 0.1, 1.5, 0.8)
st.sidebar.info("💡 **Попробуйте начать так:**\n\n- *I hate this movie because*\n- *The acting was terrible and*\n- *To be honest, the plot was*")
# Главное поле ввода
user_prompt = st.text_input("✍️ Напишите начало отзыва (на англ.) и нажмите Enter:",
value="The director tried to make a good movie and",
max_chars=100)
if st.button("Мне повезет!", type="primary", use_container_width=True):
# Сначала генерируем всё за кулисами
with st.spinner("GPT goes brrr..."):
orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
orig_reward = compute_reward(orig_text)
orig_percent = get_positivity_percent(orig_reward)
rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
rlhf_reward = compute_reward(rlhf_text)
rlhf_percent = get_positivity_percent(rlhf_reward)
st.markdown("---")
# Создаем две колонки
col1, col2 = st.columns(2)
# КОЛОНКА 1: Оригинальная модель
with col1:
with st.container(border=True):
st.subheader("До RLHF (Свободная GPT-2)")
st.caption("Пишет как попало (может быть негативной)")
# Уровень позитивности с цветным баром
st.progress(orig_percent / 100, text=f"Уровень позитивности: {orig_percent}%")
# Эффект печатной машинки
st.write_stream(stream_text(orig_text))
# КОЛОНКА 2: Обученная модель
with col2:
with st.container(border=True):
st.subheader("После RLHF (Good Boy Model)")
st.caption("Старается вырулить любой текст в позитив")
# Уровень позитивности с цветным баром
st.progress(rlhf_percent / 100, text=f"Уровень позитивности: {rlhf_percent}%")
# Спит чуть-чуть, чтобы эффект был последовательным
time.sleep(1)
st.write_stream(stream_text(rlhf_text, delay=0.04))
# Добавляем эмоций в конце
if rlhf_percent > orig_percent + 20 and rlhf_percent > 70:
st.balloons()
st.toast('🎉 RLHF модель блестяще спасла ситуацию!', icon='😍')
elif rlhf_percent < 50:
st.toast('Начало было настолько суровым, что даже RLHF сдалась.', icon='💀')