Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
2 |
-
from pydantic import BaseModel
|
3 |
import os
|
4 |
import json
|
5 |
import requests
|
|
|
|
|
6 |
from google.cloud import storage
|
7 |
from google.auth import exceptions
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
@@ -49,28 +49,24 @@ class GCSHandler:
|
|
49 |
return BytesIO(blob.download_as_bytes())
|
50 |
|
51 |
def download_model_from_huggingface(model_name):
|
52 |
-
|
53 |
-
|
54 |
-
"config.json",
|
55 |
-
"tokenizer.json",
|
56 |
-
"model.safetensors",
|
57 |
-
]
|
58 |
-
for i in range(1, 100):
|
59 |
-
file_patterns.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])
|
60 |
|
61 |
-
#
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
68 |
blob_name = f"{model_name}/{filename}"
|
69 |
-
bucket.blob(blob_name).upload_from_file(BytesIO(
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
@app.post("/predict/")
|
76 |
async def predict(request: DownloadModelRequest):
|
@@ -83,8 +79,6 @@ async def predict(request: DownloadModelRequest):
|
|
83 |
"tokenizer.json",
|
84 |
"model.safetensors",
|
85 |
]
|
86 |
-
for i in range(1, 100):
|
87 |
-
model_files.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])
|
88 |
|
89 |
# Verificar si los archivos del modelo están en GCS
|
90 |
model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)
|
|
|
|
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import requests
|
4 |
+
from fastapi import FastAPI, HTTPException
|
5 |
+
from pydantic import BaseModel
|
6 |
from google.cloud import storage
|
7 |
from google.auth import exceptions
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
49 |
return BytesIO(blob.download_as_bytes())
|
50 |
|
51 |
def download_model_from_huggingface(model_name):
|
52 |
+
url = f"https://huggingface.co/{model_name}/tree/main"
|
53 |
+
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
# Intentar obtener el árbol de archivos
|
56 |
+
try:
|
57 |
+
response = requests.get(url, headers=headers)
|
58 |
+
if response.status_code == 200:
|
59 |
+
# Extraer la lista de archivos del árbol (parseo HTML o JSON depende de la respuesta)
|
60 |
+
# Aquí asumimos que el archivo de modelos está disponible
|
61 |
+
file_urls = [] # Aquí agregarías la lógica para extraer los enlaces correctos del HTML de la página
|
62 |
+
for file_url in file_urls:
|
63 |
+
filename = file_url.split("/")[-1]
|
64 |
blob_name = f"{model_name}/{filename}"
|
65 |
+
bucket.blob(blob_name).upload_from_file(BytesIO(requests.get(file_url).content))
|
66 |
+
else:
|
67 |
+
raise HTTPException(status_code=404, detail="Error al acceder al árbol de archivos de Hugging Face.")
|
68 |
+
except Exception as e:
|
69 |
+
raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}")
|
70 |
|
71 |
@app.post("/predict/")
|
72 |
async def predict(request: DownloadModelRequest):
|
|
|
79 |
"tokenizer.json",
|
80 |
"model.safetensors",
|
81 |
]
|
|
|
|
|
82 |
|
83 |
# Verificar si los archivos del modelo están en GCS
|
84 |
model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)
|