Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -72,7 +72,8 @@ from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
|
|
| 72 |
from jose import JWTError, jwt
|
| 73 |
from passlib.context import CryptContext
|
| 74 |
from datetime import datetime, timedelta
|
| 75 |
-
from
|
|
|
|
| 76 |
|
| 77 |
#setting up logging
|
| 78 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
|
@@ -153,15 +154,14 @@ class GenerateRequest(BaseModel):
|
|
| 153 |
mask_image: Optional[UploadFile] = None # for image inpainting
|
| 154 |
low_res_image: Optional[UploadFile] = None # for image super-resolution
|
| 155 |
|
| 156 |
-
|
| 157 |
-
@validator("task_type")
|
| 158 |
def validate_task_type(cls, value):
|
| 159 |
allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
|
| 160 |
if value not in allowed_types:
|
| 161 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
| 162 |
return value
|
| 163 |
|
| 164 |
-
@
|
| 165 |
def check_input(cls, values):
|
| 166 |
task_type = values.get("task_type")
|
| 167 |
if task_type == "text" and values.get("input_text") is None:
|
|
@@ -182,8 +182,6 @@ class GenerateRequest(BaseModel):
|
|
| 182 |
raise ValueError("low_res_image is required for image super-resolution.")
|
| 183 |
return values
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
class S3ModelLoader:
|
| 188 |
def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
|
| 189 |
self.bucket_name = bucket_name
|
|
@@ -688,4 +686,4 @@ if __name__ == "__main__":
|
|
| 688 |
|
| 689 |
create_db_and_table() # Initialize database on startup
|
| 690 |
|
| 691 |
-
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
|
| 72 |
from jose import JWTError, jwt
|
| 73 |
from passlib.context import CryptContext
|
| 74 |
from datetime import datetime, timedelta
|
| 75 |
+
from pydantic import BaseModel, field_validator, model_validator, Field, EmailStr, constr, ValidationError
|
| 76 |
+
from typing import Optional, List, Union
|
| 77 |
|
| 78 |
#setting up logging
|
| 79 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
|
|
|
| 154 |
mask_image: Optional[UploadFile] = None # for image inpainting
|
| 155 |
low_res_image: Optional[UploadFile] = None # for image super-resolution
|
| 156 |
|
| 157 |
+
@field_validator('task_type')
|
|
|
|
| 158 |
def validate_task_type(cls, value):
|
| 159 |
allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
|
| 160 |
if value not in allowed_types:
|
| 161 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
| 162 |
return value
|
| 163 |
|
| 164 |
+
@model_validator(mode='after')
|
| 165 |
def check_input(cls, values):
|
| 166 |
task_type = values.get("task_type")
|
| 167 |
if task_type == "text" and values.get("input_text") is None:
|
|
|
|
| 182 |
raise ValueError("low_res_image is required for image super-resolution.")
|
| 183 |
return values
|
| 184 |
|
|
|
|
|
|
|
| 185 |
class S3ModelLoader:
|
| 186 |
def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
|
| 187 |
self.bucket_name = bucket_name
|
|
|
|
| 686 |
|
| 687 |
create_db_and_table() # Initialize database on startup
|
| 688 |
|
| 689 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|