File size: 1,917 Bytes
cb4762a
 
 
 
 
 
 
 
 
f83fa14
cb4762a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fa14
f5d7c54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from gradio_client import Client
import os
import io
# ASGI application object; the middleware and route below are registered on it.
app = FastAPI()

# Hugging Face access token taken from the environment (None when HF_TOKEN is unset).
hf_token = os.environ.get('HF_TOKEN')
# Gradio client for the remote Space, built once at import time and shared by
# every request handled by upload_file.
client = Client("https://ashrafb-ifur.hf.space/", hf_token=hf_token)

import tempfile



import base64
from fastapi.middleware.cors import CORSMiddleware

# Register CORS handling that accepts any origin, HTTP method and request header.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins/methods/headers with allow_credentials=True — confirm whether
# credentialed requests are actually needed; if so, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust as needed, '*' allows requests from any origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
    """Forward an uploaded file to the remote ``/predict`` endpoint.

    The upload is spooled to a temporary file on disk, whose path is passed to
    the shared Gradio ``client`` together with ``version`` and ``scale``.

    Args:
        file: The uploaded file (multipart form field).
        version: Form value passed through unchanged to the remote endpoint.
        scale: Form value passed through unchanged to the remote endpoint.

    Returns:
        On success, a dict with ``sketch_image_base64`` (the file at
        ``result[0]`` encoded as a ``data:image/png;base64,...`` URI) and
        ``result_file`` (``result[1]`` exactly as returned by the client).
        On any failure during prediction or result handling, a dict with a
        single ``error`` key; the HTTP status is still 200 in that case.
    """
    # Function-scope imports keep this block self-contained; fastapi is
    # already a dependency of this module.
    import contextlib
    import logging

    from fastapi.concurrency import run_in_threadpool

    payload = await file.read()
    temp_file_path = None
    try:
        # delete=False because the path must outlive the `with` so the client
        # can read it; removal is handled in the `finally` below, which now
        # also covers a failed write.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(payload)

        try:
            # client.predict is a blocking call; run it in a worker thread so
            # the event loop can keep serving other requests meanwhile.
            result = await run_in_threadpool(
                client.predict, temp_file_path, version, scale, api_name="/predict"
            )

            # Require an actual 2-item sequence; a bare string of length 2
            # would otherwise pass a plain len() check.
            if isinstance(result, (list, tuple)) and len(result) == 2:
                # result[0] is opened as a local file path and embedded inline.
                # NOTE(review): the MIME type is hard-coded to PNG — confirm
                # the remote endpoint always produces PNG output.
                with open(result[0], "rb") as image_file:
                    image_data = base64.b64encode(image_file.read()).decode("utf-8")

                return {
                    "sketch_image_base64": f"data:image/png;base64,{image_data}",
                    "result_file": result[1]
                }
            return {"error": "Invalid result from the prediction API."}
        except Exception as e:
            # Keep the existing error contract for clients, but record the
            # traceback so failures are diagnosable server-side.
            logging.getLogger(__name__).exception("Prediction request failed")
            return {"error": str(e)}
    finally:
        if temp_file_path is not None:
            # The file may already be gone; that is not an error here.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(temp_file_path)