File size: 1,524 Bytes
dd7be29
eaa1a7e
 
 
 
 
 
ac9a037
eaa1a7e
dd7be29
ac9a037
eaa1a7e
 
ac9a037
eaa1a7e
ac9a037
 
eaa1a7e
 
 
ac9a037
dd7be29
ac9a037
dd7be29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa1a7e
dd7be29
eaa1a7e
ac9a037
 
 
eaa1a7e
ac9a037
eaa1a7e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import base64
import io
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel

# Import the project-local endpoint handler
from handler import EndpointHandler

# Initialize the app
app = FastAPI()

# Build the endpoint handler once at import time; the route below reuses this
# single instance for every request. The model directory defaults to "/code"
# and can be overridden via the MODEL_DIR environment variable (an empty value
# is treated as unset).
model = EndpointHandler(model_dir=os.environ.get("MODEL_DIR") or "/code")

logger = logging.getLogger(__name__)


def _parse_payload(body: bytes) -> Dict[str, Any]:
    """Normalize a raw request body into the ``{"inputs": ...}`` dict for the handler.

    Accepted shapes:
      * JSON object containing ``"inputs"`` -> returned unchanged (extra keys kept).
      * JSON object containing ``"text"`` (but not ``"inputs"``)
        -> ``{"inputs": obj["text"]}``.
      * Any other JSON value (object without those keys, list, scalar)
        -> ``{"inputs": str(value)}``.
      * Body that is not valid JSON -> decoded as UTF-8 text and wrapped as
        ``{"inputs": text}``.

    Args:
        body: Raw request body bytes.

    Returns:
        A dict containing at least the key ``"inputs"``.

    Raises:
        HTTPException: 400 if the body is neither valid JSON nor valid UTF-8.
    """
    try:
        data = json.loads(body)
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # narrowed from a bare ``except:`` so unrelated errors are not hidden.
        try:
            return {"inputs": body.decode("utf-8")}
        except UnicodeDecodeError as exc:
            # Undecodable bytes are a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="Request body must be JSON or UTF-8 text",
            ) from exc

    if isinstance(data, dict):
        if "inputs" in data:
            return data
        if "text" in data:
            return {"inputs": data["text"]}
    # NOTE(review): str() yields Python repr, e.g. JSON null -> "None" and
    # {"a": 1} -> "{'a': 1}"; kept as-is for backward compatibility.
    return {"inputs": str(data)}


@app.post("/")
async def process_request(request: Request):
    """Run the model on the request body and return its result.

    The body is normalized by ``_parse_payload`` and passed to ``model``.

    Raises:
        HTTPException: 400 for an undecodable body; any ``HTTPException``
            raised while handling the request is re-raised unchanged; any
            other error becomes a 500 whose detail is the exception message.
    """
    try:
        body = await request.body()
        data = _parse_payload(body)
        # NOTE(review): model(...) is a synchronous call inside an async route,
        # so it blocks the event loop while it runs. Consider offloading it to
        # a worker thread once the handler is confirmed to be thread-safe.
        return model(data)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) rather than
        # rewrapping them as 500s.
        raise
    except Exception as exc:
        logger.exception("Inference request failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

# Health check endpoint: reports that the server is up and serving requests;
# it does not call the model.
@app.get("/health")
async def health():
    """Return a fixed status payload for health checks."""
    payload = {"status": "ok"}
    return payload