"""Continuously fine-tune GPT-2 on a mixed dialogue/code corpus and push it to the Hub.

A FastAPI app (single health endpoint) is served on port 7860 while a
background thread repeatedly trains the model and uploads the model and
tokenizer to the Hugging Face Hub.
"""
import logging
import os
import threading
import time

from dotenv import load_dotenv
import torch  # noqa: F401  (kept from the original script)
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
import uvicorn
from fastapi import FastAPI

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HUB_REPO_ID = 'Yhhxhfh/nombre_de_tu_modelo'

load_dotenv()

# Only log in when a token is configured: login(token=None) falls back to an
# interactive prompt, which would block a headless deployment.
_hf_token = os.getenv('HUGGINGFACE_TOKEN')
if _hf_token:
    login(token=_hf_token)
else:
    logger.warning("HUGGINGFACE_TOKEN is not set; pushing to the Hub will fail.")

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 ships without a padding token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_name)
model.config.pad_token_id = tokenizer.eos_token_id


def _as_text(dataset, to_texts):
    """Return *dataset* reduced to a single 'text' column.

    *to_texts* receives a batch (dict of column -> list) and returns the list
    of strings for that batch. All other columns are dropped so datasets with
    different schemas can be concatenated.
    """
    return dataset.map(
        lambda batch: {'text': to_texts(batch)},
        batched=True,
        remove_columns=dataset.column_names,
    )


# Load the source datasets and normalise each one to a plain 'text' column.
dataset_humanizado = _as_text(
    load_dataset('daily_dialog', split='train', trust_remote_code=True),
    lambda batch: ['\n'.join(turns) for turns in batch['dialog']],
)
dataset_codigo = _as_text(
    load_dataset('code_search_net', split='train', trust_remote_code=True),
    lambda batch: batch['whole_func_string'],
)
# openai_humaneval only publishes a 'test' split.
dataset_prompts = _as_text(
    load_dataset('openai_humaneval', split='test', trust_remote_code=True),
    lambda batch: [p + s for p, s in zip(batch['prompt'], batch['canonical_solution'])],
)
combined_dataset = concatenate_datasets([
    dataset_humanizado,
    dataset_codigo,
    dataset_prompts,
])


def tokenize_function(examples):
    """Tokenize a batch of 'text' values, truncating each to 512 tokens.

    Padding is left to the data collator, so each batch is only padded to
    its own longest sequence.
    """
    return tokenizer(examples['text'], truncation=True, max_length=512)


tokenized_dataset = combined_dataset.map(
    tokenize_function, batched=True, remove_columns=['text']
)

# mlm=False -> causal LM: the collator pads each batch and adds `labels`
# (a copy of input_ids with padding set to -100), which is what makes
# GPT2LMHeadModel return a loss for the Trainer.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./results',
    # NOTE(review): 100 sequences of up to 512 tokens per device is likely to
    # exhaust GPU memory (and the retry loop below would repeat the failure
    # forever); raise again if the hardware allows.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,  # 0 meant trainer.train() ran zero steps.
    learning_rate=1e-5,
    logging_steps=500,  # -1 is not a meaningful logging interval.
    max_grad_norm=1,
    save_total_limit=1,
    seed=42,
    weight_decay=0,
    warmup_ratio=0.0,
    # Evaluation is disabled by default; the old `evaluation_strategy`
    # keyword is deprecated/removed in recent transformers releases.
    optim="adamw_torch",
    lr_scheduler_type="constant",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

app = FastAPI()


@app.get("/")
async def root():
    """Health-check endpoint."""
    return {"message": "Modelo entrenado y en ejecución."}


# NOTE(review): the original decorated this with `@spaces.gpu`, but `spaces`
# was never imported, so the script died with NameError at import time. If
# this must run on ZeroGPU, import `spaces` and apply its GPU decorator, after
# confirming its time budget suits an endless loop.
def run_training():
    """Train and push to the Hub in an endless loop.

    Each iteration keeps training the in-memory model, then uploads the
    model and tokenizer to HUB_REPO_ID. Any failure is logged with its
    traceback, and the loop retries after 10 seconds.
    """
    while True:
        try:
            trainer.train()
            model.push_to_hub(
                HUB_REPO_ID,
                use_temp_dir=True,
                commit_message="Actualización del modelo",
            )
            tokenizer.push_to_hub(
                HUB_REPO_ID,
                use_temp_dir=True,
                commit_message="Actualización del tokenizador",
            )
            time.sleep(5)
        except Exception as e:  # deliberate best-effort: keep the service alive
            logger.exception(
                "Error durante el entrenamiento: %s. Reiniciando el proceso de entrenamiento...",
                e,
            )
            time.sleep(10)


if __name__ == "__main__":
    # daemon=True so stopping the web server also ends the training thread
    # instead of keeping the process alive.
    threading.Thread(target=run_training, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=7860)