# enro-rlhf / app.py
# coroianpetruta's picture
# Gradio blocks interface
# 50b88d8
# raw
# history blame
# No virus
# 3.41 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: enro-app.ipynb.
# %% auto 0
__all__ = ['model_name', 'tokenizer', 'model', 'example_translations', 'encoded_key', 'decoded_bytes', 'firebase_creds',
'get_translations', 'save_option_to_repo', 'get_random_sentence', 'update_prompt']
# %% enro-app.ipynb 1
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
# %% enro-app.ipynb 2
# Hugging Face checkpoint identifier shared by the tokenizer and the model.
# NOTE(review): the name suggests an English -> Romanian MarianMT model — confirm.
model_name = "Helsinki-NLP/opus-mt-tc-big-en-ro"
# Load the tokenizer and the translation model for that checkpoint once, at
# import time, so every request reuses the same instances.
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
# %% enro-app.ipynb 3
import random
def get_translations(input_text):
    """Translate *input_text* and return two randomly chosen candidates.

    The module-level ``tokenizer`` and ``model`` generate four candidate
    sequences; two of them are then drawn at random so the caller can show a
    pair for comparison. The two candidates come from different positions in
    the generated batch, but their decoded text may still be identical.

    Args:
        input_text: The text to translate.

    Returns:
        A 2-tuple ``(translation_a, translation_b)`` of decoded strings.
        ``("", "")`` is returned when *input_text* is empty, ``None`` or
        whitespace-only, without calling the model.
    """
    # Nothing to translate: skip the (expensive) model call rather than
    # decoding whatever the model produces for a blank prompt.
    if not input_text or (isinstance(input_text, str) and not input_text.strip()):
        return "", ""

    tokenized_text = tokenizer(input_text, return_tensors="pt")

    # Generate multiple translations; num_return_sequences is the number of
    # candidates to produce.
    # NOTE(review): this presumably relies on the checkpoint's generation
    # config using at least 4 beams — confirm.
    translated = model.generate(**tokenized_text, num_return_sequences=4)

    translations = [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in translated
    ]

    first, second = random.sample(translations, 2)
    return first, second
# %% enro-app.ipynb 5
# Sample English sentences; get_random_sentence() draws one of these at random.
example_translations = ["What is the way to get to the mosque?",
"Practice makes perfect.",
"The cat is eating the small mouse.",
"I'm not good at learning foreign languages.",
"What do you mean you don't know?",
"I should have sent a letter of apology.",
"His handwriting is poor.",
"He bowed his head."]
# %% enro-app.ipynb 6
import os
import firebase_admin
from firebase_admin import credentials
from firebase_admin import db
import json
import base64
# The Firebase credentials are read from the FIREBASE_KEY environment
# variable, which is base64-decoded and then parsed as UTF-8 JSON.
encoded_key = os.getenv('FIREBASE_KEY')
if not encoded_key:
    # Fail with an actionable message instead of the opaque TypeError that
    # base64.b64decode(None) would raise when the variable is missing.
    raise RuntimeError(
        "The FIREBASE_KEY environment variable is not set; it must contain "
        "the base64-encoded Firebase credentials JSON."
    )
decoded_bytes = base64.b64decode(encoded_key)
firebase_creds = json.loads(decoded_bytes.decode('utf-8'))

# Only initialize when no app is registered yet.
# NOTE(review): firebase_admin._apps is a private attribute, presumably the
# registry of initialized apps — confirm.
if not firebase_admin._apps:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://ro-en-llm-default-rtdb.firebaseio.com/'
    })
# %% enro-app.ipynb 7
def save_option_to_repo(trans_prompt, translation1, translation2, button):
    """Store one feedback record under the ``feedback`` database reference.

    Args:
        trans_prompt: The source sentence that was translated.
        translation1: The first translation shown.
        translation2: The second translation shown.
        button: The value of the feedback button that was clicked.
    """
    feedback_entry = {
        'prompt': trans_prompt,
        'translation_1': translation1,
        'translation_2': translation2,
        'feedback': button,
    }
    # push() appends the record as a new child of the 'feedback' reference.
    db.reference('feedback').push(feedback_entry)
def get_random_sentence():
    """Return one sentence drawn at random from ``example_translations``."""
    # A sample of size one is a single-element list; unpack it directly.
    (sentence,) = random.sample(example_translations, 1)
    return sentence
def update_prompt():
    """Return a Gradio update that sets a component to a random example sentence.

    Intended to be used as an event handler whose ``outputs`` is the prompt
    textbox. The previous implementation called ``prompt.change(value=...)``,
    but ``change`` registers an event listener and takes no ``value``
    argument, so it raised instead of updating the text.

    Returns:
        The result of ``gr.update(value=<random example sentence>)``.
    """
    return gr.update(value=get_random_sentence())
# Build the comparison UI: a prompt row, a translate button, two translation
# boxes and three feedback buttons, then wire the events and launch the app.
with gr.Blocks() as demo:
    # Source sentence next to a button that fills it with an example.
    with gr.Row():
        prompt = gr.components.Textbox(scale=4)
        example_sentence = gr.Button("Get Random Sentence", scale=1)

    with gr.Row():
        generate = gr.Button("Translate")

    # The two candidate translations, shown side by side.
    with gr.Row(equal_height=True):
        translations = [
            gr.components.Text(label="Translation 1"),
            gr.components.Text(label="Translation 2"),
        ]

    # One button per possible verdict.
    with gr.Row():
        option_buttons = [
            gr.Button(value=caption)
            for caption in ("Translation 1", "Equal", "Translation 2")
        ]

    generate.click(get_translations, inputs=prompt, outputs=translations)
    example_sentence.click(get_random_sentence, outputs=prompt)

    # Each feedback button is passed as its own input, so the handler
    # receives that button's value alongside the prompt and both translations.
    for button in option_buttons:
        button.click(
            save_option_to_repo,
            inputs=[prompt, translations[0], translations[1], button],
        )

demo.launch()