# GPT2-PBE / app.py
import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # render headlessly; the app runs server-side
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from datasets import load_dataset
from tokenizers import Tokenizer

from train_tokenizer import train_tokenizer
def create_iterator(files=None, dataset_name=None, split="train", streaming=True):
    """Yield text lines from uploaded files or from a (streamed) HF dataset."""
    if dataset_name:
        dataset = load_dataset(dataset_name, split=split, streaming=streaming)
        for example in dataset:
            yield example["text"]
    elif files:
        for file in files:
            with open(file.name, "r", encoding="utf-8") as f:
                for line in f:
                    yield line.strip()
def enhanced_validation(tokenizer, test_text):
    """Round-trip the test text and compute simple tokenizer quality metrics."""
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded.ids)
    # Unknown-token analysis
    unknown_tokens = sum(1 for t in encoded.tokens if t == "<unk>")
    unknown_percent = (unknown_tokens / len(encoded.tokens) * 100) if encoded.tokens else 0
    # Token length distribution
    token_lengths = [len(t) for t in encoded.tokens]
    avg_length = np.mean(token_lengths) if token_lengths else 0
    # Code-symbol coverage: does each symbol present in the test text also
    # survive as a standalone token? (Approximate: BPE may merge a symbol
    # into a larger token, so False here is not necessarily an error.)
    code_symbols = ['{', '}', '(', ')', ';', '//', 'printf']
    code_coverage = {sym: sym in test_text and sym in encoded.tokens for sym in code_symbols}
    # Histogram of token lengths, rendered to an in-memory PNG
    fig = plt.figure()
    plt.hist(token_lengths, bins=20)
    plt.xlabel('Token Length')
    plt.ylabel('Frequency')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close(fig)
    img_buffer.seek(0)
    return {
        "roundtrip_success": test_text == decoded,
        "unknown_tokens": f"{unknown_tokens} ({unknown_percent:.2f}%)",
        "average_token_length": f"{avg_length:.2f}",
        "code_coverage": code_coverage,
        # A PIL image, so the value can be passed straight to gr.Image
        "token_length_distribution": Image.open(img_buffer)
    }
def train_and_test(files, dataset_name, split, vocab_size, min_freq, test_text):
    # Validate inputs
    if not files and not dataset_name:
        raise gr.Error("You must provide files or a dataset name!")
    # Build a streaming iterator over the training corpus
    iterator = create_iterator(files, dataset_name, split)
    try:
        # Sliders may deliver floats; the trainer expects integers
        tokenizer = train_tokenizer(iterator, int(vocab_size), int(min_freq))
    except Exception as e:
        raise gr.Error(f"Training error: {e}")
    # Save and reload, so the serialized tokenizer is what gets validated
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f:
        tokenizer.save(f.name)
    trained_tokenizer = Tokenizer.from_file(f.name)
    os.unlink(f.name)
    # Extended validation on the test text
    validation = enhanced_validation(trained_tokenizer, test_text)
    # Return (metrics, image) to match the two Gradio output components
    metrics = {k: v for k, v in validation.items() if k != "token_length_distribution"}
    return metrics, validation["token_length_distribution"]
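# ---------------------------------------------------------------------------
# Reference sketch (assumption): train_tokenizer is imported from a sibling
# module that is not shown in this file. The hypothetical stand-in below,
# _train_tokenizer_sketch, only illustrates the contract the app relies on:
# it consumes a text iterator plus (vocab_size, min_freq) and returns a
# tokenizers.Tokenizer. The real implementation may differ.
# ---------------------------------------------------------------------------
def _train_tokenizer_sketch(iterator, vocab_size=32000, min_frequency=2):
    from tokenizers import models, pre_tokenizers, trainers, decoders
    tok = Tokenizer(models.BPE(unk_token="<unk>"))
    # Byte-level pre-tokenization handles mixed Greek text and source code
    # without out-of-vocabulary characters.
    tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tok.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=["<unk>", "<pad>", "<s>", "</s>"],
    )
    tok.train_from_iterator(iterator, trainer=trainer)
    return tok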
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Advanced BPE Tokenizer Trainer")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Local Files"):
                file_input = gr.File(file_count="multiple", label="Upload files")
            with gr.Tab("Hugging Face Dataset"):
                dataset_name = gr.Textbox(label="Dataset name (e.g. 'wikitext', 'codeparrot/github-code')")
                split = gr.Textbox(value="train", label="Split")
            vocab_size = gr.Slider(1000, 100000, value=32000, label="Vocabulary size")
            min_freq = gr.Slider(1, 100, value=2, label="Minimum frequency")
            test_text = gr.Textbox(
                # Greek mixed with code, kept as-is to exercise both scripts
                value='function helloWorld() { console.log("Γειά σου Κόσμε!"); } // Ελληνικά + κώδικας',
                label="Test text (Greek + code)"
            )
            train_btn = gr.Button("Train Tokenizer", variant="primary")
        with gr.Column():
            results_json = gr.JSON(label="Metrics")
            results_plot = gr.Image(label="Token length distribution")
    train_btn.click(
        fn=train_and_test,
        inputs=[file_input, dataset_name, split, vocab_size, min_freq, test_text],
        outputs=[results_json, results_plot]
    )
if __name__ == "__main__":
    demo.launch()
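# Usage note: run with `python app.py` and open the printed local URL.
# For a quick headless smoke test of the pipeline (hypothetical dataset
# name shown; the dataset must expose a 'text' field):
#
#   it = create_iterator(dataset_name="your-org/your-text-dataset", split="train")
#   tok = train_tokenizer(it, 8000, 2)
#   print(enhanced_validation(tok, 'printf("hello");')["roundtrip_success"])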