File size: 4,693 Bytes
821b37c
 
 
44f2a0a
 
821b37c
 
 
2f5766d
 
821b37c
 
 
 
 
 
50b88d8
 
821b37c
 
 
 
 
 
 
 
50b88d8
 
821b37c
c9cd7db
 
 
 
 
 
 
 
 
821b37c
44f2a0a
c9cd7db
50b88d8
 
44f2a0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d955ea
44f2a0a
 
 
 
 
 
 
 
 
 
 
 
50b88d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
821b37c
44f2a0a
50b88d8
9bb6eef
50b88d8
9bb6eef
 
 
 
 
 
 
 
 
 
 
 
50b88d8
44f2a0a
50b88d8
 
9bcd4c6
50b88d8
 
 
 
 
821b37c
50b88d8
 
 
 
 
 
821b37c
50b88d8
 
 
 
 
 
 
 
2f5766d
50b88d8
64e12c3
50b88d8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# AUTOGENERATED! DO NOT EDIT! File to edit: enro-app.ipynb.

# %% auto 0
__all__ = [
    # Translation model
    'model_name', 'tokenizer', 'model',
    # Example-sentence data
    'dataset', 'dialogs',
    # Firebase configuration
    # NOTE(review): these names hold the decoded service-account credentials;
    # exporting them through ``import *`` is probably unintended.
    'encoded_key', 'decoded_bytes', 'firebase_creds',
    # Public functions
    'get_translations', 'clean_sentence', 'get_random_sentence',
    'save_option_to_repo', 'update_prompt',
]

# %% enro-app.ipynb 1
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr

# %% enro-app.ipynb 2
model_name = "Helsinki-NLP/opus-mt-tc-big-en-ro"
# English -> Romanian Marian checkpoint: tokenizer and weights come from the
# same hub id and are loaded once at import time (tokenizer first).
tokenizer, model = (
    MarianTokenizer.from_pretrained(model_name),
    MarianMTModel.from_pretrained(model_name),
)

# %% enro-app.ipynb 3
import random

def get_translations(input_text):
    """Translate *input_text* and return two candidate translations.

    Runs beam search with 4 beams, keeps all 4 finished hypotheses, and
    returns two of them picked at random so the UI can ask the user which
    one is better.

    Args:
        input_text: Source (English) sentence typed or sampled in the UI.

    Returns:
        A ``(translation_a, translation_b)`` tuple of decoded strings. The two
        strings are distinct whenever the beams decode to at least two
        different texts; otherwise the single available text is returned
        twice.
    """
    # Truncate over-long user input to the tokenizer's limit instead of
    # failing inside the model.
    tokenized_text = tokenizer(input_text, return_tensors="pt", truncation=True)
    # Beam search can return at most as many sequences as it has beams, so set
    # both explicitly rather than relying on the checkpoint's generation config.
    translated = model.generate(
        **tokenized_text, num_beams=4, num_return_sequences=4
    )
    decoded = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # Different beams can decode to the same text; an identical pair is a
    # useless chosen/rejected example, so keep unique strings in beam order.
    unique = list(dict.fromkeys(decoded))
    if len(unique) < 2:
        return unique[0], unique[0]
    first, second = random.sample(unique, 2)
    return first, second

def get_top_translation(input_text):
    """Return the single best translation of *input_text*.

    Uses the model's default generation settings and decodes the first
    returned sequence.

    Args:
        input_text: Source (English) sentence.

    Returns:
        The decoded translation as a string.
    """
    tokenized_text = tokenizer(input_text, return_tensors="pt", truncation=True)
    translated = model.generate(**tokenized_text)
    # ``generate`` returns a batch of sequences (one row per output); decode
    # the first row, not the whole batch tensor.
    return tokenizer.decode(translated[0], skip_special_tokens=True)

# %% enro-app.ipynb 5
from datasets import load_dataset
# NOTE(review): trust_remote_code=True lets the hub-hosted loading script for
# "daily_dialog" run locally; keep it only if that script is trusted.
dataset = load_dataset(path="daily_dialog", trust_remote_code=True)

# %% enro-app.ipynb 6
import re

# Training split of daily_dialog; get_random_sentence draws example prompts from it.
dialogs = dataset["train"]
# Normalise whitespace around punctuation in space-tokenised dialogue text.
def clean_sentence(sentence):
    """Tidy the whitespace around punctuation in a tokenised sentence.

    Whitespace immediately before any of ``? . ! , " ' -`` is removed, any run
    of whitespace immediately after one of them is collapsed to a single
    space, and the result is stripped.

    >>> clean_sentence("Hello , how are you ? ")
    'Hello, how are you?'
    """
    punct = r'[?.!,"\'-]'
    # NOTE(review): quotes and hyphens are treated like closing punctuation,
    # so an opening quote or a spaced dash gets glued to the previous word
    # ('said "hi"' -> 'said"hi"', 'well - known' -> 'well- known').
    without_gap_before = re.sub(r'\s+(' + punct + r')', r'\1', sentence)
    single_gap_after = re.sub(r'(' + punct + r')\s+', r'\1 ', without_gap_before)
    return single_gap_after.strip()

# `dialogs` is the dataset's train split; its 'dialog' column holds one list of utterance strings per conversation.
# Example 'dialog' value: ["Hello, how are you?", "I'm fine, thank you!"]

# Function to randomly select one sentence from the dataset
def get_random_sentence():
    """Return one cleaned utterance picked at random from ``dialogs``.

    A dialogue row is chosen uniformly, then one utterance is chosen uniformly
    from its 'dialog' list and passed through ``clean_sentence``.

    Returns:
        The cleaned sentence as a string.
    """
    # Index a single row instead of materialising the whole 'dialog' column
    # (dialogs['dialog']) on every click; randrange(len(...)) draws a uniform
    # index exactly as random.choice does.
    random_dialogue = dialogs[random.randrange(len(dialogs))]['dialog']
    # Select a random sentence from the chosen dialogue
    random_sentence = random.choice(random_dialogue)
    return clean_sentence(random_sentence)



# %% enro-app.ipynb 7
import os
import firebase_admin
from firebase_admin import credentials
from firebase_admin import db
import json
import base64

# Firebase credentials arrive as base64-encoded service-account JSON in the
# FIREBASE_KEY environment variable.
encoded_key = os.getenv('FIREBASE_KEY')
# Fail fast with an actionable message; b64decode(None) would otherwise raise
# an opaque TypeError at import time.
if not encoded_key:
    raise RuntimeError(
        "FIREBASE_KEY environment variable is not set; it must hold the "
        "base64-encoded Firebase service-account JSON."
    )
decoded_bytes = base64.b64decode(encoded_key)
firebase_creds = json.loads(decoded_bytes.decode('utf-8'))
# Initialise the default app only once; get_app() raises ValueError when it
# does not exist yet (public API instead of the private firebase_admin._apps).
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://ro-en-llm-default-rtdb.firebaseio.com/'
    })

# %% enro-app.ipynb 8
def save_option_to_repo(trans_prompt, translation1, translation2, button):
    """Record the user's preference between two translations in Firebase.

    Each vote is pushed under the ``dpo_feedback`` node as a
    ``{'prompt', 'chosen', 'rejected'}`` record.

    Args:
        trans_prompt: Source sentence that was translated.
        translation1: Text currently in the "Translation 1" box.
        translation2: Text currently in the "Translation 2" box.
        button: Label of the clicked button ("Translation 1", "Equal" or
            "Translation 2").

    Nothing is written for "Equal" (or any other label). Nothing is written
    when the prompt or either translation is empty, e.g. a vote clicked
    before "Translate".
    """
    # Map the clicked label to (chosen, rejected).
    pairs = {
        "Translation 1": (translation1, translation2),
        "Translation 2": (translation2, translation1),
    }
    pair = pairs.get(button)
    if pair is None:
        return
    # Empty records would pollute the preference data.
    if not trans_prompt or not translation1 or not translation2:
        return
    chosen, rejected = pair
    # Push new data to the database
    db.reference('dpo_feedback').push({
        'prompt': trans_prompt,
        'chosen': chosen,
        'rejected': rejected,
    })



def update_prompt():
    """Gradio callback: return a fresh random example sentence for the prompt box."""
    sentence = get_random_sentence()
    return sentence

with gr.Blocks() as demo:
    # Components for the two candidate translations and the three vote buttons.
    translations = []
    option_buttons = []

    # Prompt row: free-text input plus a button that fills it from the dataset.
    with gr.Row():
        prompt = gr.components.Textbox(scale=4)
        example_sentence = gr.Button("Get Random Sentence", scale=1)

    with gr.Row():
        generate = gr.Button("Translate")

    with gr.Row(equal_height=True):
        translations.append(gr.components.Text(label="Translation 1"))
        translations.append(gr.components.Text(label="Translation 2"))

    with gr.Row():
        option_buttons.append(gr.Button(value="Translation 1"))
        option_buttons.append(gr.Button(value="Equal"))
        option_buttons.append(gr.Button(value="Translation 2"))

    generate.click(get_translations, inputs=prompt, outputs=translations)
    example_sentence.click(update_prompt, outputs=prompt)
    # Passing the button itself as an input hands its label to the callback,
    # which is how save_option_to_repo knows which option was picked.
    for button in option_buttons:
        button.click(
            save_option_to_repo,
            inputs=[prompt, translations[0], translations[1], button],
        )


demo.launch()