Hjgugugjhuhjggg committed on
Commit
db17ba5
·
verified ·
1 Parent(s): 3e20aa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -1,8 +1,8 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  import os
4
  import json
5
  import requests
 
 
6
  from google.cloud import storage
7
  from google.auth import exceptions
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -49,28 +49,24 @@ class GCSHandler:
49
  return BytesIO(blob.download_as_bytes())
50
 
51
  def download_model_from_huggingface(model_name):
52
- file_patterns = [
53
- "pytorch_model.bin",
54
- "config.json",
55
- "tokenizer.json",
56
- "model.safetensors",
57
- ]
58
- for i in range(1, 100):
59
- file_patterns.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])
60
 
61
- # Descargar los archivos del modelo
62
- for filename in file_patterns:
63
- url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
64
- headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
65
- try:
66
- response = requests.get(url, headers=headers, stream=True)
67
- if response.status_code == 200:
 
 
68
  blob_name = f"{model_name}/{filename}"
69
- bucket.blob(blob_name).upload_from_file(BytesIO(response.content))
70
- else:
71
- raise HTTPException(status_code=404, detail=f"File {filename} not found on Hugging Face.")
72
- except Exception as e:
73
- raise HTTPException(status_code=500, detail=f"Error downloading {filename} from Hugging Face: {e}")
74
 
75
  @app.post("/predict/")
76
  async def predict(request: DownloadModelRequest):
@@ -83,8 +79,6 @@ async def predict(request: DownloadModelRequest):
83
  "tokenizer.json",
84
  "model.safetensors",
85
  ]
86
- for i in range(1, 100):
87
- model_files.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])
88
 
89
  # Verificar si los archivos del modelo están en GCS
90
  model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)
 
 
 
1
  import os
2
  import json
3
  import requests
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
  from google.cloud import storage
7
  from google.auth import exceptions
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
49
  return BytesIO(blob.download_as_bytes())
50
 
51
  def download_model_from_huggingface(model_name):
52
+ url = f"https://huggingface.co/{model_name}/tree/main"
53
+ headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
 
 
 
 
 
 
54
 
55
+ # Intentar obtener el árbol de archivos
56
+ try:
57
+ response = requests.get(url, headers=headers)
58
+ if response.status_code == 200:
59
+ # Extraer la lista de archivos del árbol (parseo HTML o JSON depende de la respuesta)
60
+ # Aquí asumimos que el archivo de modelos está disponible
61
+ file_urls = [] # Aquí agregarías la lógica para extraer los enlaces correctos del HTML de la página
62
+ for file_url in file_urls:
63
+ filename = file_url.split("/")[-1]
64
  blob_name = f"{model_name}/{filename}"
65
+ bucket.blob(blob_name).upload_from_file(BytesIO(requests.get(file_url).content))
66
+ else:
67
+ raise HTTPException(status_code=404, detail="Error al acceder al árbol de archivos de Hugging Face.")
68
+ except Exception as e:
69
+ raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}")
70
 
71
  @app.post("/predict/")
72
  async def predict(request: DownloadModelRequest):
 
79
  "tokenizer.json",
80
  "model.safetensors",
81
  ]
 
 
82
 
83
  # Verificar si los archivos del modelo están en GCS
84
  model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)