File size: 4,982 Bytes
319a292
 
f84a20c
1c3034c
319a292
 
 
b5bc6a9
f173552
319a292
 
 
 
 
 
 
 
 
f84a20c
319a292
09dbcd2
b5bc6a9
 
09dbcd2
 
b5bc6a9
 
 
 
 
 
 
 
 
319a292
f84a20c
319a292
 
 
09dbcd2
efa228b
319a292
 
 
 
 
 
 
 
b5bc6a9
319a292
 
 
 
 
 
b5bc6a9
319a292
b5bc6a9
319a292
b5bc6a9
319a292
b5bc6a9
 
 
319a292
f173552
b5bc6a9
f173552
 
 
 
b5bc6a9
f173552
 
b5bc6a9
f173552
 
 
 
 
 
 
b5bc6a9
 
efa228b
f173552
319a292
 
 
b5bc6a9
f173552
b5bc6a9
 
 
 
 
319a292
b5bc6a9
 
 
f173552
b5bc6a9
 
 
 
 
 
 
 
 
319a292
 
 
 
 
d84cd10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
import re
import tempfile
from io import BytesIO

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth import exceptions
from google.cloud import storage
from google.cloud.exceptions import NotFound
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Populate os.environ from a .env file (via python-dotenv) before reading settings.
load_dotenv()

# Settings read from the environment; each is None when the variable is unset.
# NOTE(review): API_KEY is not referenced anywhere else in this file.
(
    API_KEY,
    GCS_BUCKET_NAME,
    GOOGLE_APPLICATION_CREDENTIALS_JSON,
    HF_API_TOKEN,
) = (
    os.environ.get(setting_name)
    for setting_name in (
        "API_KEY",
        "GCS_BUCKET_NAME",
        "GOOGLE_APPLICATION_CREDENTIALS_JSON",
        "HF_API_TOKEN",
    )
)

def validate_bucket_name(bucket_name):
    """Check that *bucket_name* is a usable GCS bucket name and return it unchanged.

    Accepted names are 3-63 characters of lowercase letters, digits and '-',
    starting and ending with a letter or digit. Dotted (domain-style) names are
    not supported.

    Raises:
        ValueError: if the name is missing, not a string, or malformed.
    """
    # The value usually comes from an env var that may be unset (None). Report
    # that as a ValueError, which the startup code catches, rather than letting
    # the regex call fail with a TypeError.
    if not isinstance(bucket_name, str) or not bucket_name:
        raise ValueError(f"Invalid bucket name '{bucket_name}'. A non-empty string is required.")
    if not 3 <= len(bucket_name) <= 63:
        raise ValueError(f"Invalid bucket name '{bucket_name}'. Must be between 3 and 63 characters long.")
    # fullmatch, not match(...$): '$' also matches just before a trailing newline.
    if not re.fullmatch(r"[a-z0-9][a-z0-9\-]*[a-z0-9]", bucket_name):
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Must contain only lowercase letters, digits or '-', "
            "and start and end with a letter or number."
        )
    return bucket_name

def validate_huggingface_repo_name(repo_name):
    """Check that *repo_name* is a valid Hugging Face repo id and return it unchanged.

    Accepts ``name`` or ``namespace/name`` (e.g. ``gpt2`` or
    ``meta-llama/Llama-2-7b-hf``). Each part may use ASCII letters, digits,
    '-', '_' and '.', must not start or end with '-' or '.', must not contain
    '..', and may be at most 96 characters long.

    Raises:
        ValueError: if the id is missing, not a string, or malformed.
    """
    if not isinstance(repo_name, str) or not repo_name:
        raise ValueError(f"Invalid repository name '{repo_name}'. A non-empty string is required.")
    parts = repo_name.split("/")
    # Most Hub ids are "namespace/name". Any further '/' would change the URL
    # and blob path structure the callers build from this value.
    if len(parts) > 2:
        raise ValueError(f"Invalid repository name '{repo_name}'. Expected 'name' or 'namespace/name'.")
    for part in parts:
        # fullmatch, not match(...$): '$' also matches just before a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9_.-]+", part):
            raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
        # Also blocks '.'/'..' segments, which could otherwise climb out of the URL path.
        if part.startswith(('-', '.')) or part.endswith(('-', '.')) or '..' in part:
            raise ValueError(f"Invalid repository name '{repo_name}'. Cannot start or end with '-' or '.', or contain '..'.")
        if len(part) > 96:
            raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
    return repo_name

# Validate the bucket and build the storage client once, at import time.
# Any configuration problem aborts startup with a single RuntimeError.
try:
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        # json.loads(None) raises TypeError; give an explicit reason instead.
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
    # TypeError covers non-string settings (e.g. an unset GCS_BUCKET_NAME)
    # reaching the validators. The message reads "Error loading credentials or
    # bucket"; the original cause is chained for the traceback.
    raise RuntimeError(f"Error al cargar credenciales o bucket: {e}") from e

app = FastAPI()

# Request body for POST /predict/. Despite the name, it carries both the model
# to fetch and the inference input.
# (Comments rather than a docstring: pydantic would publish a docstring as the
# schema description.)
class DownloadModelRequest(BaseModel):
    model_name: str  # Hugging Face repo id; also used as the GCS object prefix for the model's files
    pipeline_task: str  # task string passed as the first argument to transformers.pipeline
    input_text: str  # text fed to the constructed pipeline

class GCSHandler:
    """Small convenience wrapper around one Google Cloud Storage bucket.

    Args:
        bucket_name: Name of the bucket to operate on.
        client: Optional ``google.cloud.storage.Client``. Defaults to the
            module-level ``storage_client``, so existing callers are unchanged.
            Pass one explicitly to target another project or for testing.
    """

    def __init__(self, bucket_name, client=None):
        if client is None:
            client = storage_client
        self.bucket = client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True if an object named ``blob_name`` exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def upload_file(self, blob_name, file_stream):
        """Upload everything readable from ``file_stream`` to ``blob_name``."""
        self.bucket.blob(blob_name).upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Download ``blob_name`` and return its bytes as a ``BytesIO``.

        Raises:
            HTTPException: 404 if the object does not exist.
        """
        # One request instead of exists() + download. This saves a round trip
        # and closes the window where the object could vanish in between.
        try:
            data = self.bucket.blob(blob_name).download_as_bytes()
        except NotFound as exc:
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.") from exc
        return BytesIO(data)

# Files mirrored from the Hub when the repo has them (404s are skipped).
# Sharded checkpoints are discovered from the *.index.json files rather than
# guessed, since shard names encode the total shard count
# ("model-00001-of-00003.safetensors").
_HF_MODEL_FILES = (
    "config.json",
    "generation_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
    "pytorch_model.bin",
    "model.safetensors",
    "pytorch_model.bin.index.json",
    "model.safetensors.index.json",
)
_HF_REQUEST_TIMEOUT = (10, 60)  # (connect, read) seconds; the read timeout applies per socket read
_HF_CHUNK_BYTES = 8 * 1024 * 1024
_HF_SPOOL_MAX_BYTES = 64 * 1024 * 1024  # bigger downloads spill from memory to a temp file


def _mirror_hf_file(model_name, filename, headers):
    """Copy one Hub file into ``bucket``; return the shard names listed if it is an index file.

    Returns an empty list when the repo has no such file (HTTP 404) or when the
    file is not a ``*.index.json``.

    Raises:
        HTTPException: 500 on any other HTTP error, network, upload or parse failure.
    """
    url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
    try:
        # `with` closes the connection. stream=True + iter_content avoids
        # holding a multi-GB weight file in memory all at once.
        with requests.get(url, headers=headers, stream=True, timeout=_HF_REQUEST_TIMEOUT) as response:
            if response.status_code == 404:
                return []
            # 401/403 (private, gated or missing repo) must not be silently skipped.
            response.raise_for_status()
            with tempfile.SpooledTemporaryFile(max_size=_HF_SPOOL_MAX_BYTES) as spool:
                for chunk in response.iter_content(chunk_size=_HF_CHUNK_BYTES):
                    spool.write(chunk)
                size = spool.tell()
                spool.seek(0)
                bucket.blob(f"{model_name}/{filename}").upload_from_file(spool, size=size)
                if not filename.endswith(".index.json"):
                    return []
                spool.seek(0)
                weight_map = json.load(spool).get("weight_map", {})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading {filename} from Hugging Face: {e}") from e
    # Keep only plain file names so a malformed index cannot point outside the repo prefix.
    return [
        shard
        for shard in set(weight_map.values())
        if isinstance(shard, str) and shard and "/" not in shard and not shard.startswith(".")
    ]


def download_model_from_huggingface(model_name):
    """Mirror a model's files from the Hugging Face Hub into the GCS bucket.

    Every file in ``_HF_MODEL_FILES`` that exists on the ``main`` revision is
    uploaded to ``<model_name>/<filename>``. The shards referenced by any
    ``*.index.json`` ``weight_map`` are then mirrored as well.

    Raises:
        ValueError: if ``model_name`` is not a valid repository name.
        HTTPException: 500 if any download or upload fails.
    """
    model_name = validate_huggingface_repo_name(model_name)
    # Send Authorization only when a token is configured; otherwise the Hub
    # would receive the literal header "Bearer None".
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    shard_files = set()
    for filename in _HF_MODEL_FILES:
        shard_files.update(_mirror_hf_file(model_name, filename, headers))
    for filename in sorted(shard_files.difference(_HF_MODEL_FILES)):
        _mirror_hf_file(model_name, filename, headers)

@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run a transformers pipeline over ``request.input_text``.

    Model files are read from GCS under ``<model_name>/``. If ``config.json``
    is not there yet, the model is first mirrored from the Hugging Face Hub.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: the original status for HTTPExceptions raised while
            fetching files, 500 for any other failure.
    """
    try:
        # Downloads, model loading and inference all block. Run them in a
        # worker thread so the event loop keeps serving other requests.
        return await run_in_threadpool(_run_prediction, request)
    except HTTPException:
        raise  # keep deliberate status codes (e.g. 404) instead of rewrapping them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e


def _run_prediction(request):
    """Blocking body of predict(): ensure files are in GCS, load the model, run the pipeline."""
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    model_prefix = request.model_name
    blob_prefix = f"{model_prefix}/"

    blobs = _list_model_blobs(gcs_handler, blob_prefix)
    if "config.json" not in blobs:
        download_model_from_huggingface(model_prefix)
        blobs = _list_model_blobs(gcs_handler, blob_prefix)
    if "config.json" not in blobs or "tokenizer.json" not in blobs:
        raise HTTPException(status_code=500, detail="Required model files missing.")

    # from_pretrained expects a directory path, not file-like objects, so
    # materialise the mirrored files on local disk for the duration of the call.
    # NOTE(review): the model is reloaded on every request; consider caching
    # loaded pipelines if latency matters.
    with tempfile.TemporaryDirectory() as model_dir:
        for filename, blob in blobs.items():
            blob.download_to_filename(os.path.join(model_dir, filename))
        model = AutoModelForCausalLM.from_pretrained(model_dir)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
        result = pipeline_(request.input_text)
    return {"response": result}


def _list_model_blobs(gcs_handler, blob_prefix):
    """Map file name -> blob for the objects stored directly under ``blob_prefix``."""
    found = {}
    # One list request replaces an exists() call per candidate file name.
    for blob in gcs_handler.bucket.list_blobs(prefix=blob_prefix):
        filename = blob.name[len(blob_prefix):]
        # Skip nested objects, directory placeholders and dot-names so every
        # entry is a safe plain file name inside the local model directory.
        if filename and "/" not in filename and not filename.startswith("."):
            found[filename] = blob
    return found

if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all interfaces.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)