|
import os |
|
from dotenv import load_dotenv |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments |
|
from datasets import load_dataset, concatenate_datasets |
|
from huggingface_hub import login |
|
import time |
|
import uvicorn |
|
from fastapi import FastAPI |
|
|
|
# Load environment variables (e.g. HUGGINGFACE_TOKEN) from a local .env file, if present.
load_dotenv()

# Authenticate with the Hugging Face Hub so the push_to_hub() calls further down
# can write to the target repo. login(token=None) falls back to an interactive
# prompt, which blocks or fails in a headless server process, so only call it
# when a token is actually configured.
hf_token = os.getenv('HUGGINGFACE_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    print("HUGGINGFACE_TOKEN no está definido; se omite login() en Hugging Face Hub.")
|
|
|
# Base checkpoint to fine-tune.
model_name = 'gpt2'

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 ships without a padding token, so tokenizer(..., padding='max_length')
# raises ValueError. Reusing EOS as the pad token is the usual GPT-2 workaround.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(model_name)
# Keep the model config in sync with the tokenizer so anything reading
# model.config.pad_token_id (e.g. generate()) sees the same pad id.
model.config.pad_token_id = tokenizer.pad_token_id
|
|
|
|
|
# Raw training corpora. NOTE: openai_humaneval only publishes a 'test' split
# (requesting 'train' raises "Unknown split"), so that split is used here.
# Be aware this trains on a benchmark's evaluation data.
dataset_humanizado = load_dataset('daily_dialog', split='train', trust_remote_code=True)
dataset_codigo = load_dataset('code_search_net', split='train', trust_remote_code=True)
dataset_prompts = load_dataset('openai_humaneval', split='test', trust_remote_code=True)


def _as_text_dataset(dataset, build_text):
    """Return *dataset* reduced to a single string column named 'text'.

    concatenate_datasets() requires all inputs to share identical features,
    and tokenize_function() reads examples['text']. None of the raw corpora
    has that column, so each one is projected onto it here.

    Args:
        dataset: a datasets.Dataset to convert.
        build_text: callable mapping one example dict to a string.

    Returns:
        A new Dataset whose only column is 'text'.
    """
    return dataset.map(
        lambda example: {'text': build_text(example)},
        remove_columns=dataset.column_names,
    )


# Column names below follow each dataset's published schema (dataset card);
# verify them if the dataset versions change.
combined_dataset = concatenate_datasets([
    # daily_dialog: 'dialog' is a list of utterances -> one turn per line.
    _as_text_dataset(dataset_humanizado, lambda ex: '\n'.join(ex['dialog'])),
    # code_search_net: full function source, including its docstring.
    _as_text_dataset(dataset_codigo, lambda ex: ex['whole_func_string']),
    # openai_humaneval: task prompt followed by its reference solution.
    _as_text_dataset(dataset_prompts, lambda ex: ex['prompt'] + ex['canonical_solution']),
])
|
|
|
|
|
def tokenize_function(examples):
    """Tokenize a batch of examples for causal-LM fine-tuning.

    Args:
        examples: batch dict passed by ``Dataset.map(batched=True)``; must
            contain a 'text' column of strings.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``), truncated and
        padded to 512 tokens, plus ``labels``: a copy of ``input_ids`` where
        padding positions are replaced by -100.
    """
    encoded = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
    # GPT2LMHeadModel only computes a loss when `labels` is supplied, and
    # Trainer needs that loss. The model shifts labels internally, so the labels
    # are just the input ids. -100 is the cross-entropy ignore_index, which
    # keeps padding (attention_mask == 0) out of the loss.
    encoded['labels'] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encoded['input_ids'], encoded['attention_mask'])
    ]
    return encoded
|
|
|
# Apply the tokenizer over the combined corpus; batched=True hands
# tokenize_function a dict of column lists rather than single rows.
tokenized_dataset = combined_dataset.map(
    function=tokenize_function,
    batched=True,
)
|
|
|
# Hyper-parameters for each trainer.train() round in run_training().
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=100,
    per_device_eval_batch_size=100,
    # Was 0: with zero epochs trainer.train() performs no optimisation steps,
    # so every round re-uploaded an unchanged model.
    num_train_epochs=1,
    learning_rate=1e-5,
    # Was -1, which is not a valid interval (the argument must be a positive
    # int, or a float in [0, 1) read as a ratio of total steps).
    logging_steps=50,
    max_grad_norm=1.0,
    save_total_limit=1,
    seed=42,
    weight_decay=0.0,
    warmup_ratio=0.0,
    # evaluation_strategy="no" was dropped: "no" is already the default, and
    # that argument name was deprecated in favour of eval_strategy and is
    # rejected by newer transformers releases.
    optim="adamw_torch",
    lr_scheduler_type="constant",
)
|
|
|
# Bundle the tokenized corpus, hyper-parameters and model into a single
# Hugging Face Trainer that run_training() drives repeatedly.
trainer = Trainer(
    train_dataset=tokenized_dataset,
    args=training_args,
    model=model,
)
|
|
|
# HTTP app served by uvicorn in the __main__ block.
app = FastAPI()


@app.get("/")
async def root():
    """Root endpoint returning a fixed status message.

    Returns:
        A JSON-serialisable dict with a single "message" key.
    """
    # "ejecuci贸n" was mojibake (UTF-8 "ó" mis-decoded as GBK); restored to "ejecución".
    return {"message": "Modelo entrenado y en ejecución."}
|
|
|
try:
    # The `spaces` package (Hugging Face ZeroGPU) was used without being
    # imported, which caused a NameError as soon as the module was loaded. Its
    # decorator is spelled `spaces.GPU`, not `spaces.gpu`.
    import spaces
    gpu_decorator = spaces.GPU
except ImportError:
    def gpu_decorator(func):
        """No-op stand-in for spaces.GPU when the `spaces` package is absent."""
        return func


# NOTE(review): ZeroGPU allocations have a time limit; confirm that spaces.GPU
# suits a function that loops forever and is started from a background thread.
@gpu_decorator
def run_training():
    """Train, push the result to the Hub, and repeat forever.

    Each round calls trainer.train() on the module-level trainer, then uploads
    the module-level model and tokenizer to the Hub repo. Any exception is
    reported with its traceback and the round is retried after 10 seconds.
    This function never returns.
    """
    import traceback

    repo_id = 'Yhhxhfh/nombre_de_tu_modelo'
    while True:
        try:
            trainer.train()
            # Commit messages previously contained mojibake ("Actualizaci贸n").
            model.push_to_hub(repo_id, repo_type='model', use_temp_dir=True, commit_message="Actualización del modelo")
            tokenizer.push_to_hub(repo_id, repo_type='model', use_temp_dir=True, commit_message="Actualización del tokenizador")
            time.sleep(5)
        except Exception as e:
            # Top-level boundary of a long-running worker: report and retry
            # instead of letting the thread die silently.
            print(f"Error durante el entrenamiento: {e}. Reiniciando el proceso de entrenamiento...")
            traceback.print_exc()
            time.sleep(10)
|
|
|
if __name__ == "__main__":
    import threading

    # Run the endless training loop in the background while uvicorn serves the API.
    # daemon=True lets the process exit once uvicorn shuts down (e.g. on Ctrl+C);
    # a non-daemon thread running `while True` would keep the interpreter alive forever.
    training_thread = threading.Thread(target=run_training, name="training-loop", daemon=True)
    training_thread.start()

    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|