Axel-Student committed on
Commit
c9fdad1
·
1 Parent(s): 6268c2b

change to api

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -3
  2. app.py +32 -18
Dockerfile CHANGED
@@ -2,8 +2,7 @@ FROM python:3.10
2
 
3
  WORKDIR /app
4
 
5
- RUN pip install --no-cache-dir torch torchvision torchaudio diffusers gradio
6
- RUN pip install -U diffusers
7
  RUN pip install git+https://github.com/huggingface/diffusers.git
8
  RUN pip install transformers
9
  RUN pip install accelerate
@@ -14,4 +13,4 @@ COPY . /app
14
 
15
  EXPOSE 7860
16
 
17
- CMD ["python", "app.py"]
 
2
 
3
  WORKDIR /app
4
 
5
+ RUN pip install --no-cache-dir torch torchvision torchaudio diffusers fastapi uvicorn
 
6
  RUN pip install git+https://github.com/huggingface/diffusers.git
7
  RUN pip install transformers
8
  RUN pip install accelerate
 
13
 
14
  EXPOSE 7860
15
 
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,26 +1,40 @@
1
  import os
2
- from diffusers import FluxPipeline # type: ignore
3
- import gradio as gr # type: ignore
4
- from huggingface_hub import login, InferenceClient
 
 
5
 
6
- token = os.getenv("HF_TOKEN")
7
 
8
- client = InferenceClient(
9
- provider="together",
10
- api_key=token,
11
- model="black-forest-labs/FLUX.1-dev")
12
 
13
- def generate_image(prompt):
14
- image = client.text_to_image(prompt)
 
 
 
 
 
 
 
 
15
  return image
16
 
17
- gradio_app = gr.Interface(
18
- fn=generate_image,
19
- inputs=gr.Textbox(label="Entrez une description"),
20
- outputs=gr.Image(label="Image générée"),
21
- title="Générateur d'images IA",
22
- description="Entrez une description et générez une image correspondante."
23
- )
 
 
 
 
24
 
25
  if __name__ == "__main__":
26
- gradio_app.launch()
 
 
1
  import os
2
+ import io
3
+ import torch
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.responses import StreamingResponse
6
+ from diffusers import FluxPipeline # type: ignore
7
 
8
+ app = FastAPI()
9
 
10
+ # Récupération du token et authentification
11
+ token = os.getenv("HF_TOKEN")
12
+ login(token=token)
 
13
 
14
+ def generate_image(prompt: str):
15
+ image = pipe(
16
+ prompt,
17
+ height=1024,
18
+ width=1024,
19
+ guidance_scale=3.5,
20
+ num_inference_steps=50,
21
+ max_sequence_length=512,
22
+ generator=torch.Generator("cpu").manual_seed(0)
23
+ ).images[0]
24
  return image
25
 
26
+ @app.get("/generate")
27
+ def generate(prompt: str):
28
+ try:
29
+ image = generate_image(prompt)
30
+ # On sauvegarde l'image dans un buffer en mémoire
31
+ buf = io.BytesIO()
32
+ image.save(buf, format="PNG")
33
+ buf.seek(0)
34
+ return StreamingResponse(buf, media_type="image/png")
35
+ except Exception as e:
36
+ raise HTTPException(status_code=500, detail=str(e))
37
 
38
  if __name__ == "__main__":
39
+ import uvicorn # type: ignore
40
+ uvicorn.run(app, host="0.0.0.0", port=7860)