Ashrafb commited on
Commit
f44bb3a
โ€ข
1 Parent(s): 2a94a11

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -71
main.py CHANGED
@@ -1,75 +1,41 @@
 
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import HTMLResponse
4
- from pydantic import BaseModel
5
- import gradio as gr
6
- import random
7
- import string
8
- import time
9
- from queue import Queue
10
- from threading import Thread
11
 
12
  app = FastAPI()
13
-
14
- # Mount the static directory for serving HTML and CSS files
15
- app.mount("/static", StaticFiles(directory="static"), name="static")
16
-
17
- queue = Queue()
18
- queue_threshold = 100
19
-
20
- def restart_script_periodically():
21
- while True:
22
- random_time = random.randint(540, 600)
23
- time.sleep(random_time)
24
- os.execl(sys.executable, sys.executable, *sys.argv)
25
-
26
- restart_thread = Thread(target=restart_script_periodically, daemon=True)
27
- restart_thread.start()
28
-
29
- def add_random_noise(prompt, noise_level=0.00):
30
- if noise_level == 0:
31
- noise_level = 0.00
32
- percentage_noise = noise_level * 5
33
- num_noise_chars = int(len(prompt) * (percentage_noise / 100))
34
- noise_indices = random.sample(range(len(prompt)), num_noise_chars)
35
- prompt_list = list(prompt)
36
- noise_chars = list(string.ascii_letters + string.punctuation + ' ' + string.digits)
37
- noise_chars.extend(['๐Ÿ˜', '๐Ÿ’ฉ', '๐Ÿ˜‚', '๐Ÿค”', '๐Ÿ˜Š', '๐Ÿค—', '๐Ÿ˜ญ', '๐Ÿ™„', '๐Ÿ˜ท', '๐Ÿคฏ', '๐Ÿคซ', '๐Ÿฅด', '๐Ÿ˜ด', '๐Ÿคฉ', '๐Ÿฅณ', '๐Ÿ˜”', '๐Ÿ˜ฉ', '๐Ÿคช', '๐Ÿ˜‡', '๐Ÿคข', '๐Ÿ˜ˆ', '๐Ÿ‘น', '๐Ÿ‘ป', '๐Ÿค–', '๐Ÿ‘ฝ', '๐Ÿ’€', '๐ŸŽƒ', '๐ŸŽ…', '๐ŸŽ„', '๐ŸŽ', '๐ŸŽ‚', '๐ŸŽ‰', '๐ŸŽˆ', '๐ŸŽŠ', '๐ŸŽฎ', 'โค๏ธ', '๐Ÿ’”', '๐Ÿ’•', '๐Ÿ’–', '๐Ÿ’—', '๐Ÿถ', '๐Ÿฑ', '๐Ÿญ', '๐Ÿน', '๐ŸฆŠ', '๐Ÿป', '๐Ÿจ', '๐Ÿฏ', '๐Ÿฆ', '๐Ÿ˜', '๐Ÿ”ฅ', '๐ŸŒง๏ธ', '๐ŸŒž', '๐ŸŒˆ', '๐Ÿ’ฅ', '๐ŸŒด', '๐ŸŒŠ', '๐ŸŒบ', '๐ŸŒป', '๐ŸŒธ', '๐ŸŽจ', '๐ŸŒ…', '๐ŸŒŒ', 'โ˜๏ธ', 'โ›ˆ๏ธ', 'โ„๏ธ', 'โ˜€๏ธ', '๐ŸŒค๏ธ', 'โ›…๏ธ', '๐ŸŒฅ๏ธ', '๐ŸŒฆ๏ธ', '๐ŸŒง๏ธ', '๐ŸŒฉ๏ธ', '๐ŸŒจ๏ธ', '๐ŸŒซ๏ธ', 'โ˜”๏ธ', '๐ŸŒฌ๏ธ', '๐Ÿ’จ', '๐ŸŒช๏ธ', '๐ŸŒˆ'])
38
- for index in noise_indices:
39
- prompt_list[index] = random.choice(noise_chars)
40
- return "".join(prompt_list)
41
-
42
- # Define FastAPI models
43
- class PromptInput(BaseModel):
44
- prompt: str
45
- noise_level: float
46
-
47
- # Define FastAPI endpoints
48
- @app.post("/get_prompt/")
49
- async def get_prompt(prompt_input: PromptInput):
50
- text_gen = gr.Interface.load("models/Gustavosta/MagicPrompt-Stable-Diffusion")
51
- prompt_text = text_gen("dreamlikeart, " + prompt_input.prompt)
52
- return {"prompt": prompt_text}
53
-
54
- @app.post("/generate_image/")
55
- async def generate_image(prompt_input: PromptInput):
56
- proc1 = gr.Interface.load("models/dreamlike-art/dreamlike-diffusion-1.0", temp_files_path="static/tmp")
57
- prompt_with_noise = add_random_noise(prompt_input.prompt, prompt_input.noise_level)
58
- while queue.qsize() >= queue_threshold:
59
- time.sleep(2)
60
- queue.put(prompt_with_noise)
61
- output_dict = proc1(prompt_with_noise)
62
- output1 = output_dict['output'][0]
63
- return {"image": output1}
64
-
65
- # Serve the HTML frontend
66
- @app.get("/", response_class=HTMLResponse)
67
- async def serve_frontend():
68
- with open("static/index.html", "r") as file:
69
- html_content = file.read()
70
- return HTMLResponse(content=html_content)
71
-
72
- # Run the FastAPI server
73
- if __name__ == "__main__":
74
- import uvicorn
75
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, HTTPException
2
  from fastapi import FastAPI
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import FileResponse
5
+ import os
6
+ import requests
 
 
 
 
 
7
 
8
  app = FastAPI()
9
+ API_URL = "https://ashrafb-dreamlikeart-diffusion-1-0.hf.space/"
10
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+
12
+ def make_prediction(prompt, noise_level=0.0, fn_index=0):
13
+ headers={"Authorization": f"Bearer {HF_TOKEN}"}
14
+ data = {"prompt": prompt, "noise_level": noise_level}
15
+ response = requests.post(API_URL, headers=headers, json=data)
16
+ if response.status_code == 200:
17
+ return response.json()
18
+ else:
19
+ raise HTTPException(status_code=response.status_code, detail=response.text)
20
+
21
+ @app.get("/short-prompt/")
22
+ async def short_prompt(prompt: str):
23
+ try:
24
+ result = make_prediction(prompt)
25
+ return {"result": result}
26
+ except Exception as e:
27
+ raise HTTPException(status_code=500, detail=str(e))
28
+
29
+ @app.get("/long-prompt/")
30
+ async def long_prompt(prompt: str, noise_level: float = 0.0):
31
+ try:
32
+ result = make_prediction(prompt, noise_level, fn_index=1)
33
+ return {"result": result}
34
+ except Exception as e:
35
+ raise HTTPException(status_code=500, detail=str(e))
36
+
37
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
38
+
39
+ @app.get("/")
40
+ def index() -> FileResponse:
41
+ return FileResponse(path="/app/static/index.html", media_type="text/html")