enro-rlhf / app.py
coroianpetruta's picture
Improved random sentence generator
2f1571b
raw
history blame contribute delete
No virus
4.61 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: enro-app.ipynb.
# %% auto 0
# Names exported by a star-import of this module.
# NOTE(review): 'clean_sentence' is listed below, but no definition with that
# name is visible in this file; a star-import fails on a listed name that does
# not exist. Confirm against the source notebook.
__all__ = ['model_name', 'tokenizer', 'model', 'dataset', 'dialogs', 'encoded_key', 'decoded_bytes', 'firebase_creds',
'get_translations', 'clean_sentence', 'get_random_sentence', 'save_option_to_repo', 'update_prompt']
# %% enro-app.ipynb 1
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
# %% enro-app.ipynb 2
# Checkpoint identifier handed to both `from_pretrained` calls below
# (presumably an English -> Romanian model, judging by the name).
model_name = "Helsinki-NLP/opus-mt-tc-big-en-ro"
# Tokenizer and model are created once at import time and shared by the
# translation functions defined in this file.
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
# %% enro-app.ipynb 3
import random
def get_translations(input_text):
    """Return two candidate translations of ``input_text``.

    The model is asked for four sequences; two of them are then picked at
    random (without replacement) and returned as a ``(first, second)`` tuple.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    # num_return_sequences controls how many candidates generate() hands back.
    generated = model.generate(**encoded, num_return_sequences=4)
    candidates = [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in generated
    ]
    first, second = random.sample(candidates, 2)
    return first, second
def get_top_translation(input_text):
    """Return the single translation the model produces for ``input_text``.

    Unlike ``get_translations``, no ``num_return_sequences`` is requested, so
    ``generate`` runs with its default settings and only the first returned
    sequence is decoded.
    """
    tokenized_text = tokenizer(input_text, return_tensors="pt")
    translated = model.generate(**tokenized_text)
    # generate() returns a batch of sequences, while decode() expects the
    # token ids of one sequence, so decode the first row instead of the batch.
    return tokenizer.decode(translated[0], skip_special_tokens=True)
# %% enro-app.ipynb 5
from datasets import load_dataset
# Load the "daily_dialog" dataset at import time. `trust_remote_code=True` is
# forwarded to `load_dataset` (presumably because the dataset relies on a
# loading script; confirm against the `datasets` version in use).
dataset = load_dataset("daily_dialog", trust_remote_code=True)
# %% enro-app.ipynb 6
import re
# Training split; its "dialog" column is flattened further down to build the
# pool of example sentences.
dialogs = dataset["train"]
def flatten(xss):
    """Concatenate the items of every iterable in ``xss`` into one new list.

    Only one level of nesting is removed; inner items are kept as they are.
    """
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat
def split_keep_delimiters(s):
    """Split ``s`` into sentences, keeping each sentence's end punctuation.

    Whitespace before ``.``, ``,``, ``!`` and ``?`` and around apostrophes is
    removed first. The text is then cut into pieces that each end with ``.``,
    ``!`` or ``?``; trailing text with no such terminator is not returned.
    The first character of every piece is upper-cased and the rest of the
    piece is left untouched.

    Returns a list of sentence strings (empty when nothing matches).
    """
    # Remove spaces before dots, exclamation points, commas, and question marks
    s = re.sub(r'\s+([.,!?])', r'\1', s)
    # Remove spaces before and after apostrophes
    s = re.sub(r"\s*['’]\s*", r"'", s)
    # Each match starts at a character that is neither whitespace nor a
    # terminator and runs up to and including the next terminator.
    parts = re.findall(r'[^.!?\s][^.!?]*[.!?]', s)
    # Upper-case only the first character: str.capitalize() would also
    # lower-case the remainder, mangling "I", names and acronyms.
    return [part[0].upper() + part[1:] for part in parts]
# Pool of example prompts: every entry of the flattened "dialog" column is
# split into individual sentences, and all of those sentences are collected
# into one list.
random_sentences = flatten(dialogs["dialog"])
random_sentences_stripped = [
    sentence
    for entry in random_sentences
    for sentence in split_keep_delimiters(entry)
]
def get_random_sentence():
    """Return one sentence picked at random from ``random_sentences_stripped``."""
    return random.choice(random_sentences_stripped)
# %% enro-app.ipynb 7
import os
import firebase_admin
from firebase_admin import credentials
from firebase_admin import db
import json
import base64
# The service-account JSON is read from the FIREBASE_KEY environment variable,
# where it is stored base64-encoded.
encoded_key = os.getenv('FIREBASE_KEY')
if not encoded_key:
    # Fail with an explicit message instead of the opaque error that decoding
    # a missing (None) or empty value would otherwise produce further down.
    raise RuntimeError(
        "FIREBASE_KEY environment variable is not set; it must contain the "
        "base64-encoded Firebase service-account JSON."
    )
decoded_bytes = base64.b64decode(encoded_key)
firebase_creds = json.loads(decoded_bytes.decode('utf-8'))
# Only initialise the Firebase app when none is registered yet, so that
# re-running this code in the same process does not initialise it twice.
if not firebase_admin._apps:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://ro-en-llm-default-rtdb.firebaseio.com/'
    })
# %% enro-app.ipynb 8
def save_option_to_repo(trans_prompt, translation1, translation2, button):
    """Store the user's preference between two translations.

    ``button`` is the label of the button that was pressed. For
    "Translation 1" or "Translation 2" a record with the keys ``prompt``,
    ``chosen`` and ``rejected`` is pushed under the ``dpo_feedback``
    database reference; any other label (e.g. "Equal") writes nothing.
    """
    feedback_ref = db.reference('dpo_feedback')
    if button == "Translation 1":
        chosen, rejected = translation1, translation2
    elif button == "Translation 2":
        chosen, rejected = translation2, translation1
    else:
        # No preference expressed: nothing is recorded.
        return
    # Push new data to the database
    feedback_ref.push({
        'prompt': trans_prompt,
        'chosen': chosen,
        'rejected': rejected,
    })
def update_prompt():
    """Return a random sentence; wired to the "Get Random Sentence" button to fill the prompt box."""
    return get_random_sentence()
# Feedback UI: a prompt row, a translate button, two translation boxes and
# three preference buttons. Components are created in the same order as they
# appear on the page.
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.components.Textbox(scale=4)
        example_sentence = gr.Button("Get Random Sentence", scale=1)
    with gr.Row():
        generate = gr.Button("Translate")
    with gr.Row(equal_height=True):
        translations = [
            gr.components.Text(label=label)
            for label in ("Translation 1", "Translation 2")
        ]
    with gr.Row():
        option_buttons = [
            gr.Button(value=choice)
            for choice in ("Translation 1", "Equal", "Translation 2")
        ]

    # "Translate" fills both translation boxes from the current prompt.
    generate.click(get_translations, inputs=prompt, outputs=translations)
    # "Get Random Sentence" replaces the prompt with a sentence from the pool.
    example_sentence.click(update_prompt, outputs=prompt)
    # Every preference button passes itself as the last input, so the handler
    # receives the label of whichever button was pressed.
    for option_button in option_buttons:
        option_button.click(
            save_option_to_repo,
            inputs=[prompt, translations[0], translations[1], option_button],
        )
demo.launch()