CaesarCloudSync commited on
Commit
33b57a9
1 Parent(s): 1a27815

CaesarAIART test

Browse files
Files changed (2) hide show
  1. CaesarAIART/caesaraiart.py +26 -3
  2. main.py +5 -4
CaesarAIART/caesaraiart.py CHANGED
@@ -1,6 +1,9 @@
1
  import torch
2
  from diffusers import StableDiffusionPipeline
3
-
 
 
 
4
  class CaesarAIART:
5
  def __init__(self,CURRENT_DIR=""):
6
  self.model_id = "CompVis/stable-diffusion-v1-4"
@@ -11,8 +14,28 @@ class CaesarAIART:
11
  def generate(self,prompt):
12
  image = self.pipe(prompt).images[0]
13
  image.save(f"{self.CURRENT_DIR}/CaesarAIART/caesarart.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  if __name__ == "__main__":
16
  prompt = "a photo of an astronaut riding a horse on mars"
17
- caesaraiart = CaesarAIART()
18
- caesaraiart.generate()
 
 
 
 
1
  import torch
2
  from diffusers import StableDiffusionPipeline
3
+ import requests
4
+ # You can access the image with PIL.Image for example
5
+ import io
6
+ from PIL import Image
7
  class CaesarAIART:
8
  def __init__(self,CURRENT_DIR=""):
9
  self.model_id = "CompVis/stable-diffusion-v1-4"
 
14
  def generate(self,prompt):
15
  image = self.pipe(prompt).images[0]
16
  image.save(f"{self.CURRENT_DIR}/CaesarAIART/caesarart.png")
17
+ @staticmethod
18
+ def generate_api(prompt):
19
+
20
+
21
+ API_URL = "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
22
+ headers = {"Authorization": "Bearer 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"}
23
+
24
+ def query(payload):
25
+ response = requests.post(API_URL, headers=headers, json=payload)
26
+ return response.content
27
+ image_bytes = query({"inputs": prompt})
28
+
29
+ image = Image.open(io.BytesIO(image_bytes))
30
+ filtered_image = io.BytesIO()
31
+ image.save(filtered_image, "PNG")
32
+ filtered_image.seek(0)
33
+ return filtered_image
34
 
35
  if __name__ == "__main__":
36
  prompt = "a photo of an astronaut riding a horse on mars"
37
+ image = CaesarAIART.generate_api(prompt)
38
+
39
+
40
+ #caesaraiart = CaesarAIART()
41
+ #caesaraiart.generate()
main.py CHANGED
@@ -15,7 +15,7 @@ import speech_recognition as sr
15
  import uvicorn
16
  from fastapi import FastAPI, File, UploadFile,Depends, WebSocket, WebSocketDisconnect
17
  from fastapi.middleware.cors import CORSMiddleware
18
- from fastapi.responses import FileResponse
19
  from tqdm import tqdm
20
  from transformers import pipeline
21
 
@@ -30,6 +30,7 @@ from RequestModels import *
30
  from CaesarFaceRecognition.caesardeepface import CaesarDeepFace
31
  from CaesarAIART.caesaraiart import CaesarAIART
32
  #from CaesarAIMusicLoad.caesaraimusicload import CaesarAITelegramBOT
 
33
  importcsv = ImportCSV("CaesarAI")
34
  caesaryolo = CaesarYolo()
35
  caesarfacedetectmodel = CaesarFaceDetection()
@@ -38,7 +39,7 @@ caesarfacedetectmodel = CaesarFaceDetection()
38
  app = FastAPI()
39
 
40
  CURRENT_DIR = os.path.realpath(__file__).replace(f"/main.py","")
41
- caesaraiart = CaesarAIART(CURRENT_DIR)
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=["*"], # can alter with time
@@ -134,8 +135,8 @@ def caesarfacesnap(frames: CaesarOCRHTTPModel):
134
  def caesarart(promptjson: CaesarAIARTModel):
135
  try:
136
  promptjson = dict(promptjson)
137
- caesaraiart.generate(promptjson["prompt"])
138
- return FileResponse(f"{CURRENT_DIR}/CaesarAIART/caesarart.png")
139
 
140
  except Exception as ex:
141
  return {"error":f"{type(ex)},{ex}"}
 
15
  import uvicorn
16
  from fastapi import FastAPI, File, UploadFile,Depends, WebSocket, WebSocketDisconnect
17
  from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import FileResponse,StreamingResponse
19
  from tqdm import tqdm
20
  from transformers import pipeline
21
 
 
30
  from CaesarFaceRecognition.caesardeepface import CaesarDeepFace
31
  from CaesarAIART.caesaraiart import CaesarAIART
32
  #from CaesarAIMusicLoad.caesaraimusicload import CaesarAITelegramBOT
33
+
34
  importcsv = ImportCSV("CaesarAI")
35
  caesaryolo = CaesarYolo()
36
  caesarfacedetectmodel = CaesarFaceDetection()
 
39
  app = FastAPI()
40
 
41
  CURRENT_DIR = os.path.realpath(__file__).replace(f"/main.py","")
42
+ #caesaraiart = CaesarAIART(CURRENT_DIR)
43
  app.add_middleware(
44
  CORSMiddleware,
45
  allow_origins=["*"], # can alter with time
 
135
  def caesarart(promptjson: CaesarAIARTModel):
136
  try:
137
  promptjson = dict(promptjson)
138
+ image = CaesarAIART.generate_api(promptjson["prompt"])
139
+ return StreamingResponse(image, media_type="image/png")#FileResponse(f"{CURRENT_DIR}/CaesarAIART/caesarart.png")
140
 
141
  except Exception as ex:
142
  return {"error":f"{type(ex)},{ex}"}