khawir committed
Commit 648acd4
1 Parent(s): 6de967f

added image variations

app/api/generate.py CHANGED
@@ -140,17 +140,16 @@ def image_variations(model: Request, request: ImageVariations, db: Annotated[Ses
     if not current_user.is_active:
         raise HTTPException(status_code=403, detail="Forbidden")

-    # prompt = create_prompt(request.prompt, medium=request.medium, style=request.style, artist=request.artist, website=request.website, resolution=request.resolution, additional_details=request.additional_details, color=request.color, lightning=request.lightning)
-    # image = decode_image(request.image)
-    # image.resize((512, 512))
+    prompt = create_prompt(request.prompt, medium=request.medium, style=request.style, artist=request.artist, website=request.website, resolution=request.resolution, additional_details=request.additional_details, color=request.color, lightning=request.lightning)
+    image = decode_image(request.image)
+    image = image.resize((512, 512))

-    # if prompt:
-    #     crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
+    if prompt:
+        crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)

-    # images = model.state.iv_model.generate(pil_image=image, num_samples=request.num_samples, num_inference_steps=request.num_inference_steps,
-    #                                        seed=request.seed, prompt=prompt, scale=request.scale, negative_prompt=request.negative_prompt)
+    images = model.state.iv_model.generate(pil_image=image, num_samples=request.num_samples, num_inference_steps=request.num_inference_steps,
+                                           seed=request.seed, prompt=prompt, scale=request.scale, negative_prompt=request.negative_prompt)

-    # images = [encode_image(image) for image in images]
+    images = [encode_image(image) for image in images]

-    # return images
-    return ["Image Variations is not supported yet."]
+    return images

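The enabled endpoint round-trips images through decode_image/encode_image, so clients send and receive encoded image strings. A minimal client sketch, assuming base64-encoded PNGs and a hypothetical route path (the actual path is declared on generate.router, mounted under the /generate prefix):

    import base64
    import requests

    # Hypothetical route and host; check the router definition for the real path.
    URL = "http://localhost:8000/generate/image-variations"

    with open("reference.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()

    payload = {
        "image": image_b64,
        "prompt": "studio product photo",
        "num_samples": 2,
        "num_inference_steps": 30,
        "seed": 42,
        "scale": 0.7,
        "negative_prompt": "blurry, low quality",
    }

    resp = requests.post(URL, json=payload,
                         headers={"Authorization": "Bearer <access_token>"})
    resp.raise_for_status()

    # The endpoint returns a list of encoded images.
    for i, encoded in enumerate(resp.json()):
        with open(f"variation_{i}.png", "wb") as out:
            out.write(base64.b64decode(encoded))
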
app/config.py CHANGED
@@ -3,7 +3,7 @@ import os

 DATABASE_URL = 'sqlite:///./sql_app.db'

-SECRET_KEY = os.environ.get("SECRET_KEY")
+SECRET_KEY = os.environ.get("SECRET_KEY")
 ALGORITHM = os.environ.get("ALGORITHM")

 ACCESS_TOKEN_EXPIRE_MINUTES = 30
app/main.py CHANGED
@@ -6,7 +6,7 @@ from fastapi.security import OAuth2PasswordRequestForm
 from fastapi import APIRouter, FastAPI, HTTPException, Depends

 import torch
-# from ip_adapter import IPAdapterXL
+from ip_adapter import IPAdapterXL
 from transformers import AutoModelForImageSegmentation
 from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline

@@ -40,20 +40,20 @@ async def lifespan(app: FastAPI):
         "briaai/RMBG-1.4", trust_remote_code=True)
     br_model.to(DEVICE)

-    # sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
-    #     base_model_path,
-    #     torch_dtype=torch.float16,
-    #     add_watermarker=False,
-    # )
-    # iv_model = IPAdapterXL(sdxl_pipe, image_encoder_path, ip_ckpt, DEVICE)
+    sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
+        base_model_path,
+        torch_dtype=torch.float16,
+        add_watermarker=False,
+    )
+    iv_model = IPAdapterXL(sdxl_pipe, image_encoder_path, ip_ckpt, DEVICE)

-    yield {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model}  # , 'iv_model': iv_model
+    yield {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model, 'iv_model': iv_model}

     del ti_pipe
     del ii_pipe
     del br_model
-    # del sdxl_pipe
-    # del iv_model
+    del sdxl_pipe
+    del iv_model


 app = FastAPI(lifespan=lifespan)
@@ -93,7 +93,7 @@ async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db:
     return {"access_token": access_token, "token_type": "bearer"}


-router.include_router(user.router, prefix="/users")
-router.include_router(prompt.router, prefix="/prompts")
-router.include_router(generate.router, prefix="/generate")
+router.include_router(user.router, prefix="/users", tags=["users"])
+router.include_router(prompt.router, prefix="/prompts", tags=["prompts"])
+router.include_router(generate.router, prefix="/generate", tags=["generate"])
 app.include_router(router)
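
The lifespan change relies on Starlette's lifespan state: the mapping yielded from the lifespan handler is exposed as request.state on every request, which is how the endpoint in app/api/generate.py reaches the adapter through model.state.iv_model. A toy sketch of the pattern, with a string standing in for the pipelines:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI, Request

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Anything yielded here is copied onto request.state per request.
        yield {"iv_model": "stand-in for IPAdapterXL"}

    app = FastAPI(lifespan=lifespan)

    @app.get("/demo")
    async def demo(request: Request):
        # Mirrors model.state.iv_model in generate.py, where the Request
        # parameter happens to be named `model`.
        return {"loaded": request.state.iv_model}
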
requirements.txt CHANGED
@@ -12,4 +12,5 @@ safetensors
 torch
 torchvision
 pillow
-numpy
+einops
+git+https://github.com/tencent-ailab/IP-Adapter.git
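
ip_adapter is not published on PyPI, hence the pin straight to the Tencent AI Lab repository; einops is presumably pulled in as a runtime dependency of IP-Adapter. After pip install -r requirements.txt, the new import in app/main.py should resolve:

    # Smoke test that the git dependency installed correctly;
    # IPAdapterXL is the class imported in app/main.py.
    from ip_adapter import IPAdapterXL
    print(IPAdapterXL)
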
sdxl_models/image_encoder/config.json ADDED
@@ -0,0 +1,81 @@
+{
+  "architectures": [
+    "CLIPVisionModelWithProjection"
+  ],
+  "_name_or_path": "",
+  "add_cross_attention": false,
+  "architectures": null,
+  "attention_dropout": 0.0,
+  "bad_words_ids": null,
+  "begin_suppress_tokens": null,
+  "bos_token_id": null,
+  "chunk_size_feed_forward": 0,
+  "cross_attention_hidden_size": null,
+  "decoder_start_token_id": null,
+  "diversity_penalty": 0.0,
+  "do_sample": false,
+  "dropout": 0.0,
+  "early_stopping": false,
+  "encoder_no_repeat_ngram_size": 0,
+  "eos_token_id": null,
+  "exponential_decay_length_penalty": null,
+  "finetuning_task": null,
+  "forced_bos_token_id": null,
+  "forced_eos_token_id": null,
+  "hidden_act": "gelu",
+  "hidden_size": 1664,
+  "id2label": {
+    "0": "LABEL_0",
+    "1": "LABEL_1"
+  },
+  "image_size": 224,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "is_decoder": false,
+  "is_encoder_decoder": false,
+  "label2id": {
+    "LABEL_0": 0,
+    "LABEL_1": 1
+  },
+  "layer_norm_eps": 1e-05,
+  "length_penalty": 1.0,
+  "max_length": 20,
+  "min_length": 0,
+  "model_type": "clip_vision_model",
+  "no_repeat_ngram_size": 0,
+  "num_attention_heads": 16,
+  "num_beam_groups": 1,
+  "num_beams": 1,
+  "num_channels": 3,
+  "num_hidden_layers": 48,
+  "num_return_sequences": 1,
+  "output_attentions": false,
+  "output_hidden_states": false,
+  "output_scores": false,
+  "pad_token_id": null,
+  "patch_size": 14,
+  "prefix": null,
+  "problem_type": null,
+  "pruned_heads": {},
+  "remove_invalid_values": false,
+  "repetition_penalty": 1.0,
+  "return_dict": true,
+  "return_dict_in_generate": false,
+  "sep_token_id": null,
+  "suppress_tokens": null,
+  "task_specific_params": null,
+  "temperature": 1.0,
+  "tf_legacy_loss": false,
+  "tie_encoder_decoder": false,
+  "tie_word_embeddings": true,
+  "tokenizer_class": null,
+  "top_k": 50,
+  "top_p": 1.0,
+  "torch_dtype": null,
+  "torchscript": false,
+  "transformers_version": "4.24.0",
+  "typical_p": 1.0,
+  "use_bfloat16": false,
+  "projection_dim": 1280
+}
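
The parameters (48 hidden layers, hidden size 1664, 1280-dimensional projection) match the OpenCLIP ViT-bigG image encoder used with the SDXL IP-Adapter. IPAdapterXL loads it via transformers from image_encoder_path; a minimal sketch, assuming the directory layout added in this commit:

    import torch
    from transformers import CLIPVisionModelWithProjection

    # Load the encoder described by the config above; IPAdapterXL does the
    # equivalent internally from image_encoder_path.
    encoder = CLIPVisionModelWithProjection.from_pretrained(
        "sdxl_models/image_encoder", torch_dtype=torch.float16
    )
    print(encoder.config.hidden_size, encoder.config.projection_dim)  # 1664 1280

Note the duplicate "architectures" key in the JSON: Python's json module keeps the last value (null), which is harmless here because the model class is named explicitly at load time.
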
sdxl_models/image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:657723e09f46a7c3957df651601029f66b1748afb12b419816330f16ed45d64d
+size 3689912664
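
This and the remaining weight files are tracked with Git LFS, so the diff records only the three-line pointer (spec version, SHA-256 of the blob, and size in bytes) rather than the binary payload itself.
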
sdxl_models/image_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2999562fbc02f9dc0d9c0acb7cf0970ec3a9b2a578d7d05afe82191d606d2d80
+size 3690112753
sdxl_models/ip-adapter-plus-face_sdxl_vit-h.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50e886d82940b3c5873d80c2b06d8a4b0d0fccec70bc44fd53f16ac3cfd7fc36
+size 1013454761
sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:677ad8860204f7d0bfba12d29e6c31ded9beefdf3e4bbd102518357d31a292c1
+size 847517512
sdxl_models/ip-adapter-plus_sdxl_vit-h.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec70edb7cc8e769c9388d94eeaea3e4526352c9fae793a608782d1d8951fde90
+size 1013454427
sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f5062b8400c94b7159665b21ba5c62acdcd7682262743d7f2aefedef00e6581
+size 847517512
sdxl_models/ip-adapter_sdxl.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7525f2731e9e86d1368e0b68467615d55dda459691965bdd7d37fa3d7fd84c12
+size 702585097
sdxl_models/ip-adapter_sdxl.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba1002529e783604c5f326d49f0122025392d1d20ac8d573b3eeb3e6dea4ebb6
+size 702585376
sdxl_models/ip-adapter_sdxl_vit-h.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b382e2501d0ab3fe2e09312e561a59cd3f21262aff25373700e0cd62c635929
+size 698390793
sdxl_models/ip-adapter_sdxl_vit-h.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebf05d918348aec7abb02a5e9ecef77e0aaea6914a5c4ea13f50d45eb1681831
+size 698391064