Spaces:
Sleeping
Sleeping
# AUTOGENERATED! DO NOT EDIT! File to edit: enro-app.ipynb. | |
# %% auto 0 | |
# Names made available by a star-import of this module.
__all__ = [
    'model_name',
    'tokenizer',
    'model',
    'dataset',
    'dialogs',
    'encoded_key',
    'decoded_bytes',
    'firebase_creds',
    'get_translations',
    'clean_sentence',
    'get_random_sentence',
    'save_option_to_repo',
    'update_prompt',
]
# %% enro-app.ipynb 1 | |
from transformers import MarianMTModel, MarianTokenizer | |
import gradio as gr | |
# %% enro-app.ipynb 2 | |
# Hub identifier of the pretrained MarianMT checkpoint ("en-ro" per its name).
model_name = "Helsinki-NLP/opus-mt-tc-big-en-ro"

# Tokenizer and model are both loaded once, at import time, from the same
# checkpoint so their vocabularies match.
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
# %% enro-app.ipynb 3 | |
import random | |
def get_translations(input_text):
    """Translate ``input_text`` and return two randomly picked candidates.

    Four candidate sequences are requested from the model, each one is
    decoded to text, and two of them (at different positions in the
    candidate list) are drawn at random.

    Returns:
        A 2-tuple ``(first, second)`` of decoded translation strings.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    # num_return_sequences sets how many candidates generate() hands back.
    candidates = model.generate(**encoded, num_return_sequences=4)
    decoded = [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in candidates
    ]
    first, second = random.sample(decoded, 2)
    return first, second
def get_top_translation(input_text):
    """Translate ``input_text`` and return the first generated translation.

    The text is tokenized, passed through ``model.generate`` with the
    model's default generation settings, and the first returned sequence
    is decoded to a string with special tokens removed.

    Returns:
        The decoded translation as a ``str``.
    """
    tokenized_text = tokenizer(input_text, return_tensors="pt")
    translated = model.generate(**tokenized_text)
    # generate() returns a batch of sequences, while decode() expects a
    # single sequence of ids, so decode the first row rather than the
    # whole batch tensor.
    return tokenizer.decode(translated[0], skip_special_tokens=True)
# %% enro-app.ipynb 5 | |
from datasets import load_dataset | |
# Load the "daily_dialog" dataset; trust_remote_code=True permits the
# dataset's own loading script to run.
dataset = load_dataset("daily_dialog", trust_remote_code=True)
# %% enro-app.ipynb 6 | |
import re | |
# Only the training split is used as the source of example sentences.
dialogs = dataset["train"]
def flatten(xss):
    """Return a single list holding the items of every iterable in ``xss``.

    Only one level of nesting is removed; the items keep their order.
    """
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat
def split_keep_delimiters(s):
    """Tidy the punctuation spacing in ``s`` and split it into sentences.

    Whitespace in front of ``. , ! ?`` is removed, whitespace around
    apostrophes is removed (and ``’`` is normalised to ``'``), and the
    text is then cut after every ``.``, ``!`` or ``?``. Trailing text
    that has no terminating punctuation is not returned.

    Args:
        s: The raw text to clean and split.

    Returns:
        A list of sentences, each ending in its terminator and starting
        with an upper-case character.
    """
    # Remove spaces before dots, exclamation points, commas, and question marks
    s = re.sub(r'\s+([.,!?])', r'\1', s)
    # Remove spaces before and after apostrophes
    s = re.sub(r"\s*['’]\s*", r"'", s)
    # Each match starts at a non-space, non-terminator character and runs
    # through the next terminator, so the delimiter stays with its sentence.
    parts = re.findall(r'[^.!?\s][^.!?]*[.!?]', s)
    # Upper-case only the first character. str.capitalize() would also
    # lower-case the remainder, turning "I" into "i" and mangling names.
    return [part[:1].upper() + part[1:] for part in parts]
# Flatten the "dialog" column into one list of lines, then split every line
# into cleaned sentences to build the pool that random prompts are drawn from.
random_sentences = flatten(dialogs["dialog"])
random_sentences_stripped = []
for utterance in random_sentences:
    random_sentences_stripped.extend(split_keep_delimiters(utterance))
def get_random_sentence():
    """Return one sentence chosen at random from the prepared sentence pool."""
    pool = random_sentences_stripped
    return random.choice(pool)
# %% enro-app.ipynb 7 | |
import os | |
import firebase_admin | |
from firebase_admin import credentials | |
from firebase_admin import db | |
import json | |
import base64 | |
# The Firebase service-account JSON is read, base64-encoded, from the
# FIREBASE_KEY environment variable.
encoded_key = os.getenv('FIREBASE_KEY')
if not encoded_key:
    # Fail early with an actionable message instead of the opaque TypeError
    # (unset variable) or JSONDecodeError (empty variable) raised further down.
    raise RuntimeError(
        "The FIREBASE_KEY environment variable is not set; it must contain "
        "the base64-encoded Firebase service-account JSON."
    )
decoded_bytes = base64.b64decode(encoded_key)
firebase_creds = json.loads(decoded_bytes.decode('utf-8'))

# Initialise the app only when none is registered yet, so running this
# module again in the same process does not initialise it a second time.
if not firebase_admin._apps:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://ro-en-llm-default-rtdb.firebaseio.com/'
    })
# %% enro-app.ipynb 8 | |
def save_option_to_repo(trans_prompt, translation1, translation2, button):
    """Record which translation was preferred under ``dpo_feedback``.

    Args:
        trans_prompt: The source sentence that was translated.
        translation1: Text of the first candidate translation.
        translation2: Text of the second candidate translation.
        button: Label of the pressed button. ``"Translation 1"`` or
            ``"Translation 2"`` selects the chosen candidate; any other
            value (such as ``"Equal"``) stores nothing.
    """
    ref = db.reference('dpo_feedback')
    if button == "Translation 1":
        chosen, rejected = translation1, translation2
    elif button == "Translation 2":
        chosen, rejected = translation2, translation1
    else:
        # No preference expressed: nothing is pushed to the database.
        return
    # Push new data to the database
    ref.push({
        'prompt': trans_prompt,
        'chosen': chosen,
        'rejected': rejected,
    })
def update_prompt():
    """Return a fresh random sentence to use as the prompt text."""
    sentence = get_random_sentence()
    return sentence
with gr.Blocks() as demo:
    # Top row: the source sentence and a button that fills it with an example.
    with gr.Row():
        prompt = gr.components.Textbox(scale = 4)
        example_sentence = gr.Button("Get Random Sentence", scale = 1)
    with gr.Row():
        generate = gr.Button("Translate")
    # The two candidate translations, shown side by side.
    with gr.Row(equal_height=True):
        translations = [
            gr.components.Text(label="Translation 1"),
            gr.components.Text(label="Translation 2"),
        ]
    # Preference buttons, in display order.
    with gr.Row():
        option_buttons = [
            gr.Button(value="Translation 1"),
            gr.Button(value="Equal"),
            gr.Button(value="Translation 2"),
        ]

    generate.click(get_translations, inputs=prompt, outputs=translations)
    example_sentence.click(update_prompt, outputs=prompt)
    # Each button is also passed as an input so the handler can tell which
    # one was pressed.
    for option_button in option_buttons:
        option_button.click(
            save_option_to_repo,
            inputs=[prompt, translations[0], translations[1], option_button],
        )

demo.launch()