jree423 commited on
Commit
dd7be29
·
verified ·
1 Parent(s): 8613050

Fix: Update API to better handle text-to-image requests

Browse files
Files changed (1) hide show
  1. api.py +29 -7
api.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from typing import Dict, Any, Optional, List, Union
4
  import base64
@@ -7,6 +7,7 @@ from PIL import Image
7
  import torch
8
  import os
9
  import sys
 
10
 
11
  # Import the handler
12
  from handler import EndpointHandler
@@ -17,15 +18,36 @@ app = FastAPI()
17
  # Initialize the model
18
  model = EndpointHandler(model_dir="/code")
19
 
20
- class TextToImageRequest(BaseModel):
21
- inputs: Union[str, Dict[str, Any]]
22
- parameters: Optional[Dict[str, Any]] = None
23
-
24
  @app.post("/")
25
- async def text_to_image(request: TextToImageRequest):
26
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Process the request
28
- result = model(request.dict())
29
  return result
30
  except Exception as e:
31
  raise HTTPException(status_code=500, detail=str(e))
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
  from pydantic import BaseModel
3
  from typing import Dict, Any, Optional, List, Union
4
  import base64
 
7
  import torch
8
  import os
9
  import sys
10
+ import json
11
 
12
  # Import the handler
13
  from handler import EndpointHandler
 
18
  # Initialize the model
19
  model = EndpointHandler(model_dir="/code")
20
 
 
 
 
 
21
  @app.post("/")
22
+ async def process_request(request: Request):
23
  try:
24
+ # Get the raw request body
25
+ body = await request.body()
26
+
27
+ # Try to parse as JSON
28
+ try:
29
+ data = json.loads(body)
30
+ except:
31
+ # If not JSON, treat as plain text
32
+ data = {"inputs": body.decode("utf-8")}
33
+
34
+ # Handle different input formats
35
+ if isinstance(data, dict):
36
+ if "inputs" in data:
37
+ # Standard format
38
+ pass
39
+ elif "text" in data:
40
+ # Text field directly
41
+ data = {"inputs": data["text"]}
42
+ else:
43
+ # No recognized fields, use the whole dict as input
44
+ data = {"inputs": str(data)}
45
+ else:
46
+ # Not a dict, use as is
47
+ data = {"inputs": str(data)}
48
+
49
  # Process the request
50
+ result = model(data)
51
  return result
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))