# gcs / app.py
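"""FastAPI service that mirrors Hugging Face model files into a Google Cloud Storage
bucket and serves inference for several transformers pipeline tasks over HTTP."""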
import os
import json
import logging
import uuid
import threading
import io
from datetime import timedelta
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import uvicorn
import torch
import requests
from safetensors import safe_open
from dotenv import load_dotenv
load_dotenv()
API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
try:
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON or not GCS_BUCKET_NAME:
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON and GCS_BUCKET_NAME must be set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
    logger.info(f"Successfully connected to Google Cloud Storage. Bucket: {GCS_BUCKET_NAME}")
except (json.JSONDecodeError, KeyError, ValueError) as e:
    logger.error(f"Failed to load GCS credentials or bucket: {e}")
    raise RuntimeError(f"Failed to load GCS credentials or bucket: {e}")
app = FastAPI()
class DownloadModelRequest(BaseModel):
model_name: str
pipeline_task: str
input_text: str
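# Thin wrapper around the GCS bucket: existence checks, uploads/downloads,
# folder creation, and signed URLs returned to clients.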
class GCSHandler:
def __init__(self, bucket_name):
self.bucket = storage_client.bucket(bucket_name)
def file_exists(self, blob_name):
return self.bucket.blob(blob_name).exists()
def download_file(self, blob_name):
blob = self.bucket.blob(blob_name)
if not blob.exists():
raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
return blob.download_as_bytes()
def upload_file(self, blob_name, file_data):
blob = self.bucket.blob(blob_name)
blob.upload_from_file(file_data)
    def generate_signed_url(self, blob_name, expiration=3600):
        blob = self.bucket.blob(blob_name)
        # A bare integer expiration is treated as an absolute timestamp, so pass a timedelta.
        return blob.generate_signed_url(expiration=timedelta(seconds=expiration))
def create_folder(self, folder_name):
blob = self.bucket.blob(folder_name + "/")
blob.upload_from_string("") # Create an empty "folder"
def load_model_from_gcs(model_name: str, model_files: list):
    # transformers' from_pretrained expects files on disk, so mirror the GCS blobs
    # into a local cache directory (assumed writable) and load from there.
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    local_dir = os.path.join("/tmp/models", model_name)
    os.makedirs(local_dir, exist_ok=True)
    for file_name in model_files:
        local_path = os.path.join(local_dir, os.path.basename(file_name))
        if not os.path.exists(local_path):
            with open(local_path, "wb") as f:
                f.write(gcs_handler.download_file(f"{model_name}/{file_name}"))
    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer
def load_safetensors_model(model_path: str):
    # Read tensors from a .safetensors file on disk into a state dict; kept as a
    # utility since from_pretrained above already handles .safetensors files.
    with safe_open(model_path, framework="pt") as model_data:
        return {key: model_data.get_tensor(key) for key in model_data.keys()}
def get_model_files_from_gcs(model_name: str):
    # List the file names stored under the model's prefix (prefix stripped, sharded-index
    # files skipped); the loader re-prefixes them with the model name when downloading.
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    prefix = f"{model_name}/"
    blob_list = list(gcs_handler.bucket.list_blobs(prefix=prefix))
    model_files = sorted(
        blob.name[len(prefix):]
        for blob in blob_list
        if blob.name != prefix and "index" not in blob.name
    )
    return model_files
def download_model_from_huggingface(model_name):
url = f"https://huggingface.co/{model_name}/tree/main"
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
try:
response = requests.get(url, headers=headers)
if response.status_code == 200:
model_files = [
"pytorch_model.bin",
"config.json",
"tokenizer.json",
"model.safetensors",
]
            def download_file(file_name):
                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
                file_response = requests.get(file_url, headers=headers)
                # Repos typically ship either pytorch_model.bin or model.safetensors, so skip
                # files that are not present instead of uploading an error page to the bucket.
                if file_response.status_code != 200:
                    logger.warning(f"File '{file_name}' not available for '{model_name}', skipping.")
                    return
                blob = bucket.blob(f"{model_name}/{file_name}")
                blob.upload_from_string(file_response.content)
threads = [threading.Thread(target=download_file, args=(file_name,)) for file_name in model_files]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
        else:
            raise HTTPException(status_code=404, detail="Could not access the Hugging Face file tree for this model.")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")
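# Return the model's files from GCS, fetching them from Hugging Face first on a cache miss.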
def download_model_files(model_name: str):
model_files = get_model_files_from_gcs(model_name)
if not model_files:
download_model_from_huggingface(model_name)
model_files = get_model_files_from_gcs(model_name)
return model_files
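# /predict/ ensures the requested model is cached in GCS, builds a transformers pipeline
# for the requested task, and returns either the pipeline output or a signed URL to an
# artifact uploaded to the bucket.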
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
try:
gcs_handler = GCSHandler(GCS_BUCKET_NAME)
model_prefix = request.model_name
model_files = download_model_files(model_prefix)
model, tokenizer = load_model_from_gcs(model_prefix, model_files)
pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
if request.pipeline_task in ["text-generation", "translation", "summarization"]:
result = pipe(request.input_text)
return {"response": result[0]}
elif request.pipeline_task == "image-generation":
images = pipe(request.input_text)
image = images[0]
image_filename = f"{uuid.uuid4().hex}.png"
image_path = f"images/{image_filename}"
image.save(image_path)
gcs_handler.upload_file(image_path, open(image_path, "rb"))
image_url = gcs_handler.generate_signed_url(image_path)
return {"response": {"image_url": image_url}}
elif request.pipeline_task == "image-editing":
edited_images = pipe(request.input_text)
edited_image = edited_images[0]
edited_image_filename = f"{uuid.uuid4().hex}_edited.png"
edited_image.save(edited_image_filename)
gcs_handler.upload_file(f"images/{edited_image_filename}", open(edited_image_filename, "rb"))
edited_image_url = gcs_handler.generate_signed_url(f"images/{edited_image_filename}")
return {"response": {"edited_image_url": edited_image_url}}
elif request.pipeline_task == "image-to-image":
transformed_images = pipe(request.input_text)
transformed_image = transformed_images[0]
transformed_image_filename = f"{uuid.uuid4().hex}_transformed.png"
transformed_image.save(transformed_image_filename)
gcs_handler.upload_file(f"images/{transformed_image_filename}", open(transformed_image_filename, "rb"))
transformed_image_url = gcs_handler.generate_signed_url(f"images/{transformed_image_filename}")
return {"response": {"transformed_image_url": transformed_image_url}}
elif request.pipeline_task == "text-to-3d":
model_3d_filename = f"{uuid.uuid4().hex}.obj"
model_3d_path = f"3d-models/{model_3d_filename}"
with open(model_3d_path, "w") as f:
f.write("Simulated 3D model data")
gcs_handler.upload_file(f"3d-models/{model_3d_filename}", open(model_3d_path, "rb"))
model_3d_url = gcs_handler.generate_signed_url(f"3d-models/{model_3d_filename}")
return {"response": {"model_3d_url": model_3d_url}}
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {e}")
@app.on_event("startup")
async def startup_event():
logger.info("Iniciando la API...")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
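# Example request (a minimal sketch; the model name "gpt2" and the local host/port are
# illustrative assumptions, not part of this service's configuration):
#   import requests
#   payload = {"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}
#   print(requests.post("http://localhost:7860/predict/", json=payload).json())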