import os
import json
import uuid
import logging
import datetime
import tempfile

import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.auth import exceptions as google_auth_exceptions
from google.cloud import storage
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
    logger.info(f"Connected to Google Cloud Storage. Bucket: {GCS_BUCKET_NAME}")
except (google_auth_exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    logger.error(f"Error loading credentials or bucket: {e}")
    raise RuntimeError(f"Error loading credentials or bucket: {e}")

app = FastAPI()


class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str


class GCSHandler:
    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        exists = self.bucket.blob(blob_name).exists()
        logger.debug(f"Checking existence of file '{blob_name}': {exists}")
        return exists

    def download_file(self, blob_name):
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            logger.error(f"File '{blob_name}' not found in GCS.")
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        logger.debug(f"Downloading file '{blob_name}' from GCS.")
        return blob

    def upload_file(self, blob_name, file_obj):
        # The /predict/ endpoint below calls this method, but it was never defined.
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_obj)
        logger.debug(f"Uploaded file to '{blob_name}'.")

    def generate_signed_url(self, blob_name, expiration=3600):
        blob = self.bucket.blob(blob_name)
        # A bare int is interpreted as an absolute POSIX timestamp by
        # generate_signed_url, so wrap it in a timedelta to mean "seconds from now".
        url = blob.generate_signed_url(expiration=datetime.timedelta(seconds=expiration))
        logger.debug(f"Generated signed URL for '{blob_name}': {url}")
        return url


def load_model_from_gcs(model_name: str, model_files: list):
    # from_pretrained expects a local directory (or hub id), not a GCS blob, so
    # download the files into a temp dir first. from_pretrained then picks up
    # either pytorch_model.bin or model.safetensors automatically, which also
    # covers the safetensors case the original code tried to handle separately.
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    local_dir = os.path.join(tempfile.gettempdir(), model_name.replace("/", "_"))
    os.makedirs(local_dir, exist_ok=True)
    for file_name in model_files:
        blob = gcs_handler.download_file(f"{model_name}/{file_name}")
        blob.download_to_filename(os.path.join(local_dir, file_name))
    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer


def get_model_files_from_gcs(model_name: str):
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    prefix = f"{model_name}/"
    blob_list = list(gcs_handler.bucket.list_blobs(prefix=prefix))
    # Return names relative to the prefix (config.json, tokenizer.json, weights)
    # so load_model_from_gcs can rebuild the full blob paths without duplicating it.
    model_files = sorted(blob.name[len(prefix):] for blob in blob_list)
    return model_files


@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    logger.info(f"Starting prediction for model '{request.model_name}' with task '{request.pipeline_task}'...")
    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        model_files = get_model_files_from_gcs(model_prefix)
        if not model_files:
            logger.error(f"No model files found in GCS for '{model_prefix}'.")
            raise HTTPException(status_code=404, detail="Model files not found in GCS.")
        model, tokenizer = load_model_from_gcs(model_prefix, model_files)
        pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
        if request.pipeline_task in ["text-generation", "translation", "summarization"]:
            result = pipe(request.input_text)
            logger.info(f"Result generated for task '{request.pipeline_task}': {result[0]}")
            return {"response": result[0]}
        elif request.pipeline_task == "image-generation":
            # Assumes the pipeline returns PIL images.
            images = pipe(request.input_text)
            image = images[0]
            image_filename = f"{uuid.uuid4().hex}.png"
            image_path = f"images/{image_filename}"
            os.makedirs("images", exist_ok=True)
            image.save(image_path)
            with open(image_path, "rb") as image_file:
                gcs_handler.upload_file(image_path, image_file)
            image_url = gcs_handler.generate_signed_url(image_path)
            return {"response": {"image_url": image_url}}
        elif request.pipeline_task == "image-editing":
            edited_images = pipe(request.input_text)
            edited_image = edited_images[0]
            edited_image_filename = f"{uuid.uuid4().hex}_edited.png"
            edited_image.save(edited_image_filename)
            with open(edited_image_filename, "rb") as image_file:
                gcs_handler.upload_file(f"images/{edited_image_filename}", image_file)
            edited_image_url = gcs_handler.generate_signed_url(f"images/{edited_image_filename}")
            return {"response": {"edited_image_url": edited_image_url}}
        elif request.pipeline_task == "image-to-image":
            transformed_images = pipe(request.input_text)
            transformed_image = transformed_images[0]
            transformed_image_filename = f"{uuid.uuid4().hex}_transformed.png"
            transformed_image.save(transformed_image_filename)
            with open(transformed_image_filename, "rb") as image_file:
                gcs_handler.upload_file(f"images/{transformed_image_filename}", image_file)
            transformed_image_url = gcs_handler.generate_signed_url(f"images/{transformed_image_filename}")
            return {"response": {"transformed_image_url": transformed_image_url}}
        elif request.pipeline_task == "text-to-3d":
            model_3d_filename = f"{uuid.uuid4().hex}.obj"
            model_3d_path = f"3d-models/{model_3d_filename}"
            os.makedirs("3d-models", exist_ok=True)
            with open(model_3d_path, "w") as f:
                f.write("Simulated 3D model data")
            with open(model_3d_path, "rb") as model_file:
                gcs_handler.upload_file(model_3d_path, model_file)
            model_3d_url = gcs_handler.generate_signed_url(model_3d_path)
            return {"response": {"model_3d_url": model_3d_url}}
        else:
            # Without this branch the endpoint silently returned None for unknown tasks.
            raise HTTPException(status_code=400, detail=f"Unsupported pipeline task: {request.pipeline_task}")
    except HTTPException as e:
        logger.error(f"HTTPException: {e.detail}")
        raise e
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {e}")


def download_model_from_huggingface(model_name):
    url = f"https://huggingface.co/{model_name}/tree/main"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    try:
        logger.info(f"Downloading model '{model_name}' from Hugging Face...")
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            model_files = [
                "pytorch_model.bin",
                "config.json",
                "tokenizer.json",
                "model.safetensors",
            ]
            for file_name in model_files:
                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
                # Pass the auth header here too, and skip files the repo does not
                # ship (most repos have either .bin or .safetensors, not both).
                file_response = requests.get(file_url, headers=headers)
                if file_response.status_code != 200:
                    logger.warning(f"File '{file_name}' not available for '{model_name}'; skipping.")
                    continue
                blob_name = f"{model_name}/{file_name}"
                blob = bucket.blob(blob_name)
                blob.upload_from_string(file_response.content)
                logger.info(f"File '{file_name}' uploaded successfully to the GCS bucket.")
        else:
            logger.error(f"Error accessing the Hugging Face file tree for '{model_name}'.")
            raise HTTPException(status_code=404, detail="Error accessing the Hugging Face file tree.")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error downloading files from Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")
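
# Alternative sketch, not wired into the API above: the `huggingface_hub` client
# (an extra dependency this file does not currently declare) handles auth, retries,
# and local caching for exactly this kind of download, so it is usually preferable
# to hand-rolled requests. Illustrative only; the helper name below is made up.
def download_model_file_via_hub(model_name: str, file_name: str) -> str:
    from huggingface_hub import hf_hub_download  # assumed installed

    # Returns the local filesystem path of the cached file, which could then be
    # uploaded to GCS with bucket.blob(...).upload_from_filename(path).
    return hf_hub_download(repo_id=model_name, filename=file_name, token=HF_API_TOKEN)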
@app.on_event("startup")
async def startup_event():
    logger.info("Starting the API...")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
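
# Example request against a running instance (hypothetical model name; assumes the
# model files were previously uploaded under "gpt2/" in the GCS bucket):
#
#   curl -X POST http://localhost:7860/predict/ \
#     -H "Content-Type: application/json" \
#     -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'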