xujz0703 commited on
Commit
f08695f
1 Parent(s): 935b23c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -26
app.py CHANGED
@@ -2,27 +2,15 @@ import os
2
 
3
  import gradio as gr
4
  import torch
5
- from diffusers import StableDiffusionPipeline
6
  from PIL import Image
7
-
8
- import ImageReward as RM
9
-
10
- # initialize
11
-
12
- model_id = "runwayml/stable-diffusion-v1-5"
13
- pipe = StableDiffusionPipeline.from_pretrained(
14
- model_id,
15
- torch_dtype=torch.float16,
16
- )
17
-
18
- model = RM.load("ImageReward-v1.0")
19
 
20
  images_in_gallery = []
21
  rewards_in_gallery = []
22
 
23
- # event functions
24
-
25
-
26
  def generate_images(
27
  prompt, magic_words, num, height, width, num_inference_steps, guidance_scale
28
  ):
@@ -31,14 +19,33 @@ def generate_images(
31
  if magic_words is not None:
32
  prompt += ", ".join(magic_words)
33
 
34
- images_in_gallery = pipe(
35
- prompt,
36
- height=height,
37
- width=width,
38
- num_inference_steps=num_inference_steps,
39
- guidance_scale=guidance_scale,
40
- num_images_per_prompt=num,
41
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  rewards_in_gallery = [None] * len(images_in_gallery)
43
  return list(zip(images_in_gallery, rewards_in_gallery))
44
 
@@ -50,9 +57,25 @@ def score_and_rank(prompt):
50
 
51
  if num_not_scored > 0:
52
  images_to_score = images_in_gallery[-num_not_scored:]
 
 
 
 
 
 
 
 
53
  with torch.no_grad():
54
- ranks, rewards = model.inference_rank(prompt, images_to_score)
55
-
 
 
 
 
 
 
 
 
56
  if not isinstance(rewards, list):
57
  rewards = [rewards]
58
  rewards_in_gallery = rewards_in_gallery[:-num_not_scored] + rewards
 
2
 
3
  import gradio as gr
4
  import torch
 
5
  from PIL import Image
6
+ import io
7
+ import base64
8
+ import requests
9
+ import json
 
 
 
 
 
 
 
 
10
 
11
  images_in_gallery = []
12
  rewards_in_gallery = []
13
 
 
 
 
14
  def generate_images(
15
  prompt, magic_words, num, height, width, num_inference_steps, guidance_scale
16
  ):
 
19
  if magic_words is not None:
20
  prompt += ", ".join(magic_words)
21
 
22
+ # post 请求发送到服务器
23
+
24
+ # 定义请求的 URL 和数据
25
+ url = 'https://tianqi.aminer.cn/image_reward_hf/generate_image'
26
+ data = {'prompt': prompt,
27
+ 'height': height,
28
+ 'width':width,
29
+ 'num_inference_steps':num_inference_steps,
30
+ 'guidance_scale':guidance_scale,
31
+ 'num':num
32
+ }
33
+ headers = {'Content-Type': 'application/json'}
34
+
35
+ # 发送 POST 请求
36
+ data = json.dumps(data)
37
+ response = requests.post(url, data=data, headers=headers)
38
+ image_ls = response.json()['image_list']
39
+
40
+ images_in_gallery = []
41
+ for base_image in image_ls:
42
+ image_bytes = base64.b64decode(base_image)
43
+ # 创建 BytesIO 对象并读取图像字节流
44
+ image_stream = io.BytesIO(image_bytes)
45
+ # 打开图像
46
+ image = Image.open(image_stream)
47
+ images_in_gallery.append(image)
48
+
49
  rewards_in_gallery = [None] * len(images_in_gallery)
50
  return list(zip(images_in_gallery, rewards_in_gallery))
51
 
 
57
 
58
  if num_not_scored > 0:
59
  images_to_score = images_in_gallery[-num_not_scored:]
60
+ image_ls = []
61
+ for image in images_to_score:
62
+ image_bytes = io.BytesIO()
63
+ image.save(image_bytes, format='JPEG')
64
+ image_bytes.seek(0)
65
+ # 将字节流转换为 Base64 编码
66
+ base64_image = base64.b64encode(image_bytes.read()).decode('utf-8')
67
+ image_ls.append(base64_image)
68
  with torch.no_grad():
69
+ # post 请求发送到服务器
70
+ url = 'https://tianqi.aminer.cn/image_reward_hf/score_and_rank'
71
+ data = {'images_to_score': image_ls, 'prompt':prompt}
72
+ data = json.dumps(data)
73
+ headers = {'Content-Type': 'application/json'}
74
+
75
+ # 发送 POST 请求
76
+ response = requests.post(url, data=data, headers=headers)
77
+ rewards = response.json()['rewards']
78
+
79
  if not isinstance(rewards, list):
80
  rewards = [rewards]
81
  rewards_in_gallery = rewards_in_gallery[:-num_not_scored] + rewards