Afrinetwork7 committed on
Commit
9237c74
·
verified ·
1 Parent(s): ed07d30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -19
app.py CHANGED
@@ -3,11 +3,13 @@ from pydantic import BaseModel
3
  import numpy as np
4
  import random
5
  import torch
6
- from diffusers import DiffusionPipeline
7
  import boto3
8
  from io import BytesIO
9
  import time
10
  import os
 
 
 
11
 
12
  # S3 Configuration
13
  S3_BUCKET = "afri"
@@ -21,9 +23,25 @@ s3_client = boto3.client('s3',
21
  aws_access_key_id=S3_ACCESS_KEY_ID,
22
  aws_secret_access_key=S3_SECRET_ACCESS_KEY)
23
 
24
- dtype = torch.bfloat16
 
 
 
 
 
 
 
 
 
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
 
 
 
 
 
 
27
  MAX_SEED = np.iinfo(np.int32).max
28
  MAX_IMAGE_SIZE = 2048
29
 
@@ -32,24 +50,33 @@ app = FastAPI()
32
  class InferenceRequest(BaseModel):
33
  prompt: str
34
  seed: int = 42
35
- randomize_seed: bool = False
36
  width: int = 1024
37
  height: int = 1024
38
- guidance_scale: float = 5.0
39
- num_inference_steps: int = 28
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def save_image_to_s3(image):
42
  img_byte_arr = BytesIO()
43
  image.save(img_byte_arr, format='PNG')
44
  img_byte_arr = img_byte_arr.getvalue()
45
-
46
  filename = f"generated_image_{int(time.time())}.png"
47
-
48
  s3_client.put_object(Bucket=S3_BUCKET,
49
  Key=filename,
50
  Body=img_byte_arr,
51
  ContentType='image/png')
52
-
53
  url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
54
  return url
55
 
@@ -59,18 +86,18 @@ async def infer(request: InferenceRequest):
59
  seed = random.randint(0, MAX_SEED)
60
  else:
61
  seed = request.seed
62
-
63
  generator = torch.Generator().manual_seed(seed)
64
 
65
  try:
66
- image = pipe(
67
- prompt=request.prompt,
68
- width=request.width,
69
- height=request.height,
70
- num_inference_steps=request.num_inference_steps,
71
- generator=generator,
72
- guidance_scale=request.guidance_scale
73
- ).images[0]
 
74
 
75
  image_url = save_image_to_s3(image)
76
 
@@ -80,4 +107,4 @@ async def infer(request: InferenceRequest):
80
 
81
  @app.get("/")
82
  async def root():
83
- return {"message": "Welcome to the IG API"}
 
3
  import numpy as np
4
  import random
5
  import torch
 
6
  import boto3
7
  from io import BytesIO
8
  import time
9
  import os
10
+ from safetensors.torch import load_file
11
+ from huggingface_hub import hf_hub_download
12
+ from diffusers import FluxPipeline
13
 
14
  # S3 Configuration
15
  S3_BUCKET = "afri"
 
23
  aws_access_key_id=S3_ACCESS_KEY_ID,
24
  aws_secret_access_key=S3_SECRET_ACCESS_KEY)
25
 
26
+ # Set up cache path
27
+ cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
28
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
29
+ os.environ["HF_HUB_CACHE"] = cache_path
30
+ os.environ["HF_HOME"] = cache_path
31
+
32
+ if not os.path.exists(cache_path):
33
+ os.makedirs(cache_path, exist_ok=True)
34
+
35
+ # Set up CUDA and model
36
+ torch.backends.cuda.matmul.allow_tf32 = True
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ # Initialize FluxPipeline
40
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
41
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
42
+ pipe.fuse_lora(lora_scale=0.125)
43
+ pipe.to(device=device, dtype=torch.bfloat16)
44
+
45
  MAX_SEED = np.iinfo(np.int32).max
46
  MAX_IMAGE_SIZE = 2048
47
 
 
50
  class InferenceRequest(BaseModel):
51
  prompt: str
52
  seed: int = 42
53
+ randomize_seed: bool = True
54
  width: int = 1024
55
  height: int = 1024
56
+ guidance_scale: float = 3.5
57
+ num_inference_steps: int = 8
58
+
59
+ class Timer:
60
+ def __init__(self, method_name="timed process"):
61
+ self.method = method_name
62
+
63
+ def __enter__(self):
64
+ self.start = time.time()
65
+ print(f"{self.method} starts")
66
+
67
+ def __exit__(self, exc_type, exc_val, exc_tb):
68
+ end = time.time()
69
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
70
 
71
  def save_image_to_s3(image):
72
  img_byte_arr = BytesIO()
73
  image.save(img_byte_arr, format='PNG')
74
  img_byte_arr = img_byte_arr.getvalue()
 
75
  filename = f"generated_image_{int(time.time())}.png"
 
76
  s3_client.put_object(Bucket=S3_BUCKET,
77
  Key=filename,
78
  Body=img_byte_arr,
79
  ContentType='image/png')
 
80
  url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
81
  return url
82
 
 
86
  seed = random.randint(0, MAX_SEED)
87
  else:
88
  seed = request.seed
 
89
  generator = torch.Generator().manual_seed(seed)
90
 
91
  try:
92
+ with Timer("Image generation"):
93
+ image = pipe(
94
+ prompt=request.prompt,
95
+ width=request.width,
96
+ height=request.height,
97
+ num_inference_steps=request.num_inference_steps,
98
+ generator=generator,
99
+ guidance_scale=request.guidance_scale
100
+ ).images[0]
101
 
102
  image_url = save_image_to_s3(image)
103
 
 
107
 
108
  @app.get("/")
109
  async def root():
110
+ return {"message": "Welcome to the IG API"}