# enro-rlhf / app.py
# Last commit by coroianpetruta: 64e12c3 ("App.py")
# AUTOGENERATED! DO NOT EDIT! File to edit: enro-app.ipynb.
# %% auto 0
# Names exported by a star import of this module.
__all__ = [
    'model_name',
    'tokenizer',
    'model',
    'dataset',
    'dialogs',
    'encoded_key',
    'decoded_bytes',
    'firebase_creds',
    'get_translations',
    'clean_sentence',
    'get_random_sentence',
    'save_option_to_repo',
    'update_prompt',
]
# %% enro-app.ipynb 1
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
# %% enro-app.ipynb 2
# English -> Romanian MarianMT checkpoint on the Hugging Face Hub.
model_name = "Helsinki-NLP/opus-mt-tc-big-en-ro"
# Tokenizer and seq2seq model are loaded from the same checkpoint (tokenizer
# first, then model) so their vocabularies agree.
tokenizer, model = (
    MarianTokenizer.from_pretrained(model_name),
    MarianMTModel.from_pretrained(model_name),
)
# %% enro-app.ipynb 3
import random
def get_translations(input_text, num_candidates=4):
    """Translate *input_text* and return two randomly chosen candidates.

    The model is asked for ``num_candidates`` hypotheses (via
    ``num_return_sequences``); two distinct ones are picked at random so the
    user can state a preference between them.

    Args:
        input_text: Source sentence to translate.
        num_candidates: Number of hypotheses requested from ``model.generate``;
            should be at least 2.

    Returns:
        A ``(translation_a, translation_b)`` tuple of strings. Blank or missing
        input yields ``("", "")`` without running the model. If every
        hypothesis decodes to the same text, both elements are that text.
    """
    if not input_text or not input_text.strip():
        return "", ""
    # truncation=True clips over-long inputs to the tokenizer's maximum length
    # instead of letting generate() fail on them.
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True)
    # NOTE(review): with beam search, generate() requires
    # num_return_sequences <= num_beams; the checkpoint's generation config is
    # presumably providing enough beams for the default of 4 — confirm.
    generated = model.generate(**encoded, num_return_sequences=num_candidates)
    decoded = [tokenizer.decode(seq, skip_special_tokens=True) for seq in generated]
    # Distinct token sequences can decode to the same text; dedupe (keeping
    # order) so the user is not asked to compare two identical translations.
    unique = list(dict.fromkeys(decoded))
    if len(unique) < 2:
        return unique[0], unique[0]
    first, second = random.sample(unique, 2)
    return first, second
# %% enro-app.ipynb 5
from datasets import load_dataset
# DailyDialog corpus from the Hugging Face Hub; its utterances are the source
# sentences offered for translation.
# NOTE(review): loaded by the short id "daily_dialog"; newer `datasets`
# releases may need the canonical repo id or trust_remote_code — confirm.
dataset = load_dataset(path="daily_dialog")
# %% enro-app.ipynb 6
import re
# Training split only; each row has a 'dialog' field holding the utterances
# of one dialogue, from which random prompt sentences are drawn.
dialogs = dataset["train"]
def clean_sentence(sentence):
    """Undo tokenizer-style spacing around punctuation in *sentence*.

    Steps, in order:

    - Rejoin contractions split around an apostrophe (straight ``'`` or
      typographic ``’``), e.g. ``"don ’ t"`` -> ``"don’t"``.
    - Remove whitespace before ``? . ! , " ' -``.
    - Collapse whitespace after those characters to a single space.
    - Strip leading/trailing whitespace.

    NOTE(review): double quotes and spaced dashes are still treated like any
    other punctuation, so ``'said " hi "'`` becomes ``'said" hi"'``; pairing
    quotes would need a smarter pass.
    """
    # Contractions first: the generic rules below alone would turn
    # "don ' t" into "don' t" (space removed before the apostrophe, one kept
    # after it) and would not touch the typographic apostrophe at all.
    sentence = re.sub(
        r"(?<=\w)\s*(['’])\s*(?=(?:s|t|m|d|ll|re|ve)\b)",
        r"\1",
        sentence,
        flags=re.IGNORECASE,
    )
    # Remove space before punctuation
    sentence = re.sub(r'\s+([?.!,"\'-])', r'\1', sentence)
    # Normalize any run of whitespace after punctuation to exactly one space
    sentence = re.sub(r'([?.!,"\'-])\s+', r'\1 ', sentence)
    return sentence.strip()
def get_random_sentence():
    """Return one cleaned utterance drawn at random from ``dialogs``.

    A dialogue row is chosen uniformly at random, then one utterance from that
    row's ``'dialog'`` field; the utterance is passed through
    ``clean_sentence`` before being returned.
    """
    # Index a single row rather than reading dialogs['dialog']: with
    # `datasets` 2.x the column access converts the entire split into a Python
    # list on every call, while row indexing only touches one dialogue.
    row = dialogs[random.randrange(len(dialogs))]
    utterance = random.choice(row['dialog'])
    return clean_sentence(utterance)
# %% enro-app.ipynb 7
import os
import firebase_admin
from firebase_admin import credentials
from firebase_admin import db
import json
import base64
# Firebase service-account credentials, supplied as base64-encoded JSON in the
# FIREBASE_KEY environment variable.
encoded_key = os.getenv('FIREBASE_KEY')
if not encoded_key:
    # Fail with an actionable message instead of the opaque TypeError that
    # base64.b64decode(None) would raise.
    raise RuntimeError(
        "FIREBASE_KEY environment variable is not set; it must contain the "
        "base64-encoded Firebase service-account JSON."
    )
decoded_bytes = base64.b64decode(encoded_key)
firebase_creds = json.loads(decoded_bytes.decode('utf-8'))
# Initialize the default app only once (e.g. if this module is re-executed);
# the public get_app() raises ValueError when no default app exists yet.
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://ro-en-llm-default-rtdb.firebaseio.com/'
    })
# %% enro-app.ipynb 8
def save_option_to_repo(trans_prompt, translation1, translation2, button):
    """Record the user's preference between two translations in Firebase.

    Pushes a ``{'prompt', 'chosen', 'rejected'}`` record under the
    ``dpo_feedback`` node of the Realtime Database.

    Args:
        trans_prompt: Source sentence that was translated.
        translation1: Text shown in the "Translation 1" box.
        translation2: Text shown in the "Translation 2" box.
        button: Label of the clicked button: "Translation 1",
            "Translation 2" or "Equal".

    Nothing is written when ``button`` is "Equal" (or unrecognized), when any
    field is empty (e.g. a vote cast before translating), or when both
    translations are identical, since such a pair carries no preference.
    """
    if button == "Translation 1":
        chosen, rejected = translation1, translation2
    elif button == "Translation 2":
        chosen, rejected = translation2, translation1
    else:
        return
    if not (trans_prompt and chosen and rejected) or chosen == rejected:
        return
    # Push new data to the database
    db.reference('dpo_feedback').push({
        'prompt': trans_prompt,
        'chosen': chosen,
        'rejected': rejected,
    })
def update_prompt():
    """Gradio callback: supply a fresh random source sentence for the prompt box."""
    sentence = get_random_sentence()
    return sentence
with gr.Blocks() as demo:
    # Output boxes for the two candidate translations, in display order.
    translations = []
    # Preference buttons; each passes its own label to save_option_to_repo.
    option_buttons = []
    with gr.Row():
        prompt = gr.components.Textbox(scale=4)
        example_sentence = gr.Button("Get Random Sentence", scale=1)
    with gr.Row():
        generate = gr.Button("Translate")
    with gr.Row(equal_height=True):
        translations.append(gr.components.Text(label="Translation 1"))
        translations.append(gr.components.Text(label="Translation 2"))
    with gr.Row():
        for label in ("Translation 1", "Equal", "Translation 2"):
            option_buttons.append(gr.Button(value=label))

    generate.click(get_translations, inputs=prompt, outputs=translations)
    example_sentence.click(update_prompt, outputs=prompt)
    for option_button in option_buttons:
        # The button itself is an input so the handler receives its label.
        option_button.click(
            save_option_to_repo,
            inputs=[prompt, translations[0], translations[1], option_button],
        )

# Launch only when run as a script, so importing this module (e.g. Gradio
# reload mode or tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()