# gcs / app.py — GCS-backed Hugging Face model-serving API.
# (Converted scraped page header to comments so the file parses:
#  author "Hjgugugjhuhjggg", commit "Update app.py" abeeac6 verified,
#  raw / history / blame links, size 5.7 kB.)
import os
import json
import threading
import logging
from google.cloud import storage
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
import requests
import uvicorn
from dotenv import load_dotenv
load_dotenv()
API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
storage_client = storage.Client.from_service_account_info(credentials_info)
bucket = storage_client.bucket(GCS_BUCKET_NAME)
app = FastAPI()
class DownloadModelRequest(BaseModel):
model_name: str
pipeline_task: str
input_text: str
class GCSHandler:
    """Thin I/O wrapper around one GCS bucket used for mirrored model files.

    All callers address blobs by full "models/<model>/<file>" paths; the
    handler reuses the module-level ``storage_client`` so credentials are
    parsed exactly once at import.
    """

    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True if a blob with exactly this name exists (one network round trip)."""
        return self.bucket.blob(blob_name).exists()

    def create_folder_if_not_exists(self, folder_name):
        """Create a zero-byte "<folder>/" placeholder blob if it is missing.

        GCS has no real folders; an empty "name/" object only makes the
        prefix visible in browsers. Bug fix: the old code checked for
        "folder_name" (no slash) but wrote "folder_name/", so the check
        never matched and the placeholder was re-uploaded on every call.
        """
        placeholder = folder_name + "/"
        if not self.file_exists(placeholder):
            self.bucket.blob(placeholder).upload_from_string("")

    def upload_file(self, blob_name, file_stream):
        """Upload a file-like object to ``blob_name``, ensuring its parent "folder" exists."""
        self.create_folder_if_not_exists(os.path.dirname(blob_name))
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Open ``blob_name`` for binary reading.

        Raises:
            HTTPException(404): if the blob does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return blob.open("rb")

    def generate_signed_url(self, blob_name, expiration=3600):
        """Return a signed download URL valid for ``expiration`` seconds from now.

        Bug fix: the GCS client interprets a bare int as an absolute epoch
        timestamp, so ``expiration=3600`` produced URLs that expired in
        1970; wrapping it in a timedelta makes it "seconds from now".
        """
        from datetime import timedelta  # local import: only needed here
        blob = self.bucket.blob(blob_name)
        return blob.generate_signed_url(expiration=timedelta(seconds=expiration))
def download_model_from_huggingface(model_name):
    """Mirror a Hugging Face model's core files into GCS under models/<model_name>/.

    Fixes vs. original: the per-file downloads now send the auth header
    (required for gated/private repos), files the repo does not ship are
    skipped instead of uploading the 404 response body as model data, and
    every request carries a timeout so a hung connection cannot wedge a
    download thread forever.

    Raises:
        HTTPException(404): if the model page itself is not reachable.
    """
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    url = f"https://huggingface.co/{model_name}/tree/main"
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        raise HTTPException(status_code=404, detail="Error accessing Hugging Face model files.")
    model_files = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    ]
    for file_name in model_files:
        file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
        file_response = requests.get(file_url, headers=headers, timeout=300)
        if file_response.status_code != 200:
            # Repos ship either pytorch_model.bin or model.safetensors, rarely
            # both — absentees are expected, not fatal.
            logger.warning("Skipping '%s' for '%s' (HTTP %s)", file_name, model_name, file_response.status_code)
            continue
        blob_name = f"models/{model_name}/{file_name}"
        bucket.blob(blob_name).upload_from_string(file_response.content)
def download_and_verify_model(model_name):
    """Ensure the model's files are mirrored in GCS, fetching from Hugging Face if any are absent."""
    expected = (
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    )
    handler = GCSHandler(GCS_BUCKET_NAME)
    missing = [name for name in expected
               if not handler.file_exists(f"models/{model_name}/{name}")]
    if missing:
        download_model_from_huggingface(model_name)
def load_model_from_gcs(model_name):
    """Load (model, tokenizer) for a model whose files are mirrored in GCS.

    Bug fix: the original handed open GCS file streams to
    ``from_pretrained``, which expects a local directory or hub id, so it
    could never load; and when no weight blob existed it passed ``None``,
    crashing with an error ``load_model`` does not catch. Files are now
    materialized into a local cache directory and loaded from there, and a
    missing weight file raises HTTPException(404) so ``load_model``'s
    fall-back-to-Hugging-Face path actually triggers.
    """
    model_files = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    ]
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    local_dir = os.path.join("/tmp", "models", model_name)
    os.makedirs(local_dir, exist_ok=True)
    have_weights = False
    for file_name in model_files:
        blob_name = f"models/{model_name}/{file_name}"
        if not gcs_handler.file_exists(blob_name):
            continue
        local_path = os.path.join(local_dir, file_name)
        with gcs_handler.download_file(blob_name) as src, open(local_path, "wb") as dst:
            dst.write(src.read())
        if file_name in ("pytorch_model.bin", "model.safetensors"):
            have_weights = True
    if not have_weights:
        raise HTTPException(status_code=404, detail=f"No weight files found for model '{model_name}'.")
    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer
def load_model(model_name):
    """Return (model, tokenizer), falling back to a Hugging Face download on a GCS miss.

    Minor fix: removed an unused ``GCSHandler`` local that was constructed
    and immediately discarded.
    """
    try:
        return load_model_from_gcs(model_name)
    except HTTPException:
        # Cache miss: mirror the files into GCS, then retry once.
        download_and_verify_model(model_name)
        return load_model_from_gcs(model_name)
@app.on_event("startup")
async def startup():
    """On boot, re-verify every model already mirrored in GCS, one thread per model.

    Fixes vs. original: the bare "models/" placeholder blob splits to an
    empty model name (spawning a pointless download thread for ""), so
    empty segments are filtered out; an unused ``GCSHandler`` local was
    removed. NOTE(review): joining all threads here blocks startup until
    every verification finishes — presumably intentional; confirm.
    """
    blobs = bucket.list_blobs(prefix="models/")
    # Blob names look like "models/<model>/<file>"; keep only non-empty model segments.
    model_names = set()
    for blob in blobs:
        parts = blob.name.split("/")
        if len(parts) > 1 and parts[1]:
            model_names.add(parts[1])

    def download_model_thread(model_name):
        # Swallow-and-log: one bad model must not take the app down at boot.
        try:
            download_and_verify_model(model_name)
        except Exception as e:
            logger.error(f"Error downloading model '{model_name}': {e}")

    threads = [threading.Thread(target=download_model_thread, args=(name,)) for name in model_names]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run the requested transformers pipeline task on the input text and return the result."""
    model, tokenizer = load_model(request.model_name)
    task_pipeline = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
    return {"result": task_pipeline(request.input_text)}
def download_all_models_in_background():
    """Best-effort mirror of the Hugging Face model listing into GCS.

    Fixes vs. original: the listing request now has a timeout, a non-200
    listing is logged instead of silently ignored, and one failing model
    no longer aborts the whole sweep. NOTE(review): this endpoint is
    paginated — only the first page is processed here; confirm intent.
    """
    models_url = "https://huggingface.co/api/models"
    response = requests.get(models_url, timeout=60)
    if response.status_code != 200:
        logger.error("Could not list Hugging Face models (HTTP %s)", response.status_code)
        return
    for model in response.json():
        try:
            download_model_from_huggingface(model["id"])
        except Exception as e:
            logger.error("Background download failed for '%s': %s", model.get("id"), e)
def run_in_background():
    """Kick off the full-catalog mirror on a daemon thread so the caller is not blocked."""
    worker = threading.Thread(target=download_all_models_in_background, daemon=True)
    worker.start()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)