radames commited on
Commit
ca95568
·
1 Parent(s): c9267e5

add json input

Browse files
Files changed (2) hide show
  1. app.py +74 -29
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,57 +1,102 @@
1
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
2
- import torch
3
  from PIL import Image
4
  import gradio as gr
5
 
 
 
 
 
6
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
7
  dtype = torch.float16
8
- nsfw_pipe = pipeline("image-classification",
9
- model= AutoModelForImageClassification.from_pretrained("carbon225/vit-base-patch16-224-hentai"),
10
- feature_extractor=AutoFeatureExtractor.from_pretrained("carbon225/vit-base-patch16-224-hentai"),
11
- device=device,
12
- torch_dtype=dtype)
 
 
13
 
14
 
15
- style_pipe = pipeline("image-classification",
16
- model= AutoModelForImageClassification.from_pretrained("cafeai/cafe_style"),
17
- feature_extractor=AutoFeatureExtractor.from_pretrained("cafeai/cafe_style"),
18
- device=device,
19
- torch_dtype=dtype)
 
 
20
 
21
- aesthetic_pipe = pipeline("image-classification",
22
- model= AutoModelForImageClassification.from_pretrained("cafeai/cafe_aesthetic"),
23
- feature_extractor=AutoFeatureExtractor.from_pretrained("cafeai/cafe_aesthetic"),
 
 
24
  device=device,
25
  torch_dtype=dtype)
26
 
27
- def predict(image, files=None):
28
- images_paths = [image]
29
- if not files == None:
30
- images_paths = list(map(lambda x: x.name, files))
31
- pil_images = [Image.open(image_path).convert("RGB") for image_path in images_paths]
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  style = style_pipe(pil_images)
34
  aesthetic = aesthetic_pipe(pil_images)
35
  nsfw = nsfw_pipe(pil_images)
36
- results = [ a + b + c for (a,b,c) in zip(style, aesthetic, nsfw)]
37
  label_data = {}
38
  if image is not None:
39
- label_data = { row["label"]:row["score"] for row in results[0] }
40
-
41
- return label_data, results
 
42
 
43
  with gr.Blocks() as blocks:
44
  with gr.Row():
45
  with gr.Column():
46
  image = gr.Image(label="Image to test", type="filepath")
47
- files = gr.File(label="Multipls Images", file_types=["image"], file_count="multiple")
 
 
 
 
 
 
48
  with gr.Column():
49
  label = gr.Label(label="style")
50
  results = gr.JSON(label="Results")
51
- # gallery = gr.Gallery().style(grid=[2], height="auto")
52
  btn = gr.Button("Run")
53
-
54
- btn.click(fn=predict, inputs=[image, files], outputs=[label, results], api_name="inference")
 
55
 
56
  blocks.queue()
57
- blocks.launch(debug=True,inline=True)
 
1
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
2
+ import torch
3
  from PIL import Image
4
  import gradio as gr
5
 
6
+ import aiohttp
7
+ import asyncio
8
+ from io import BytesIO
9
+
10
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
  dtype = torch.float16
12
+ nsfw_pipe = pipeline("image-classification",
13
+ model=AutoModelForImageClassification.from_pretrained(
14
+ "carbon225/vit-base-patch16-224-hentai"),
15
+ feature_extractor=AutoFeatureExtractor.from_pretrained(
16
+ "carbon225/vit-base-patch16-224-hentai"),
17
+ device=device,
18
+ torch_dtype=dtype)
19
 
20
 
21
+ style_pipe = pipeline("image-classification",
22
+ model=AutoModelForImageClassification.from_pretrained(
23
+ "cafeai/cafe_style"),
24
+ feature_extractor=AutoFeatureExtractor.from_pretrained(
25
+ "cafeai/cafe_style"),
26
+ device=device,
27
+ torch_dtype=dtype)
28
 
29
+ aesthetic_pipe = pipeline("image-classification",
30
+ model=AutoModelForImageClassification.from_pretrained(
31
+ "cafeai/cafe_aesthetic"),
32
+ feature_extractor=AutoFeatureExtractor.from_pretrained(
33
+ "cafeai/cafe_aesthetic"),
34
  device=device,
35
  torch_dtype=dtype)
36
 
37
+
38
+ async def fetch_image(session, image_url):
39
+ print(f"fetching image {image_url}")
40
+ async with session.get(image_url) as response:
41
+ if response.status == 200 and response.headers['content-type'].startswith('image'):
42
+ pil_image = Image.open(BytesIO(await response.read())).convert('RGB')
43
+ # resize image proportional
44
+ # image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
45
+
46
+ return pil_image
47
+ return None
48
+
49
+
50
+ async def fetch_images(image_urls):
51
+ async with aiohttp.ClientSession() as session:
52
+ tasks = [asyncio.ensure_future(fetch_image(
53
+ session, image_url)) for image_url in image_urls]
54
+ return await asyncio.gather(*tasks)
55
+
56
+
57
+ async def predict(json=None, enable_gallery=True, image=None, files=None):
58
+ print(json)
59
+
60
+ if image or files:
61
+ if image is not None:
62
+ images_paths = [image]
63
+ elif files is not None:
64
+ images_paths = list(map(lambda x: x.name, files))
65
+ pil_images = [Image.open(image_path).convert("RGB")
66
+ for image_path in images_paths]
67
+ elif json is not None:
68
+ pil_images = await fetch_images(json["urls"])
69
+
70
  style = style_pipe(pil_images)
71
  aesthetic = aesthetic_pipe(pil_images)
72
  nsfw = nsfw_pipe(pil_images)
73
+ results = [a + b + c for (a, b, c) in zip(style, aesthetic, nsfw)]
74
  label_data = {}
75
  if image is not None:
76
+ label_data = {row["label"]: row["score"] for row in results[0]}
77
+
78
+ return results, label_data, pil_images if enable_gallery else None
79
+
80
 
81
  with gr.Blocks() as blocks:
82
  with gr.Row():
83
  with gr.Column():
84
  image = gr.Image(label="Image to test", type="filepath")
85
+ files = gr.File(label="Multipls Images", file_types=[
86
+ "image"], file_count="multiple")
87
+ enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
88
+ json = gr.JSON(label="Results", value={"urls": [
89
+ 'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
90
+ 'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
91
+ 'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
92
  with gr.Column():
93
  label = gr.Label(label="style")
94
  results = gr.JSON(label="Results")
95
+ gallery = gr.Gallery().style(grid=[2], height="auto")
96
  btn = gr.Button("Run")
97
+
98
+ btn.click(fn=predict, inputs=[json, enable_gallery, image, files],
99
+ outputs=[results, label, gallery], api_name="inference")
100
 
101
  blocks.queue()
102
+ blocks.launch(debug=True, inline=True)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers
2
  gradio
3
  --extra-index-url https://download.pytorch.org/whl/cu113
4
- torch
 
 
1
  transformers
2
  gradio
3
  --extra-index-url https://download.pytorch.org/whl/cu113
4
+ torch
5
+ aiohttp