craftgamesnetwork commited on
Commit
6a07040
1 Parent(s): e1fc40e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +139 -20
main.py CHANGED
@@ -5,16 +5,77 @@ from gradio_client import Client
5
  from huggingface_hub import create_repo, upload_file
6
 
7
  app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Função para chamar a API de hospedagem de imagens
10
- def host_image(image_url):
11
- api_url = "https://wosocial.bubbleapps.io/version-test/api/1.1/wf/save"
12
- payload = {'file': image_url}
13
- response = requests.post(api_url, data=payload)
14
- if response.status_code == 200:
15
- return response.json()["response"]["result"]
16
- else:
17
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  @app.route('/faceswapper', methods=['GET'])
20
  def faceswapper():
@@ -25,22 +86,80 @@ def faceswapper():
25
 
26
  # Chamar a API Gradio
27
  client = Client(endpoint, upload_files=True)
28
- result_path = client.predict(
29
  user_photo,
30
  result_photo,
31
  api_name="/predict"
32
  )
33
 
34
- # Mesclar o endpoint com o caminho do arquivo
35
- full_url = endpoint + "/file=" + result_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Hospedar a imagem e obter a URL
38
- hosted_url = host_image(full_url)
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- if hosted_url:
41
- return jsonify({"result_url": hosted_url})
42
- else:
43
- return jsonify({"error": "Falha ao hospedar a imagem."}), 500
 
 
44
 
45
- if __name__ == '__main__':
46
- app.run(debug=True)
 
5
  from huggingface_hub import create_repo, upload_file
6
 
7
  app = Flask(__name__)
8
+
9
+ @app.route('/run', methods=['POST'])
10
+ def run_model():
11
+ # Obter parâmetros da consulta da URL
12
+ endpoint = request.args.get('endpoint', default='https://pierroromeu-zbilatuca2testzz.hf.space')
13
+ prompt = request.args.get('prompt', default='Hello!!')
14
+ negative_prompt = request.args.get('negative_prompt', default='Hello!!')
15
+ prompt_2 = request.args.get('prompt_2', default='Hello!!')
16
+ negative_prompt_2 = request.args.get('negative_prompt_2', default='Hello!!')
17
+ use_negative_prompt = request.args.get('use_negative_prompt', type=bool, default=True)
18
+ use_prompt_2 = request.args.get('use_prompt_2', type=bool, default=True)
19
+ use_negative_prompt_2 = request.args.get('use_negative_prompt_2', type=bool, default=False)
20
+ seed = request.args.get('seed', type=int, default=0)
21
+ width = request.args.get('width', type=int, default=256)
22
+ height = request.args.get('height', type=int, default=256)
23
+ guidance_scale = request.args.get('guidance_scale', type=float, default=5.5)
24
+ num_inference_steps = request.args.get('num_inference_steps', type=int, default=50)
25
+ strength = request.args.get('strength', type=float, default=0.7)
26
+ use_vae_str = request.args.get('use_vae', default='false') # Obtém use_vae como string
27
+ use_vae = use_vae_str.lower() == 'true' # Converte para booleano
28
+ use_lora_str = request.args.get('use_lora', default='false') # Obtém use_lora como string
29
+ use_lora = use_lora_str.lower() == 'true' # Converte para booleano
30
+ use_img2img_str = request.args.get('use_img2img', default='false') # Obtém use_vae como string
31
+ use_img2img = use_img2img_str.lower() == 'true' # Converte para booleano
32
+ model = request.args.get('model', default='stabilityai/stable-diffusion-xl-base-1.0')
33
+ vaecall = request.args.get('vaecall', default='madebyollin/sdxl-vae-fp16-fix')
34
+ lora = request.args.get('lora', default='amazonaws-la/sdxl')
35
+ lora_scale = request.args.get('lora_scale', type=float, default=0.7)
36
+ url = request.args.get('url', default='https://example.com/image.png')
37
 
38
+ # Chamar a API Gradio
39
+ client = Client(endpoint)
40
+ result = client.predict(
41
+ prompt, negative_prompt, prompt_2, negative_prompt_2,
42
+ use_negative_prompt, use_prompt_2, use_negative_prompt_2,
43
+ seed, width, height,
44
+ guidance_scale,
45
+ num_inference_steps,
46
+ strength,
47
+ use_vae,
48
+ use_lora,
49
+ model,
50
+ vaecall,
51
+ lora,
52
+ lora_scale,
53
+ use_img2img,
54
+ url,
55
+ api_name="/run"
56
+ )
57
+
58
+ return jsonify(result)
59
+
60
+ @app.route('/predict', methods=['POST'])
61
+ def predict_gan():
62
+ # Obter parâmetros da consulta da URL
63
+ endpoint = request.args.get('endpoint', default='https://pierroromeu-gfpgan.hf.space/--replicas/dgwcd/')
64
+ hf_token = request.args.get('hf_token', default='')
65
+ filepath = request.args.get('filepath', default='')
66
+ version = request.args.get('version', default='v1.4')
67
+ rescaling_factor = request.args.get('rescaling_factor', type=float, default=2.0)
68
+
69
+ # Chamar a API Gradio
70
+ client = Client(endpoint, hf_token=hf_token)
71
+ result = client.predict(
72
+ filepath,
73
+ version,
74
+ rescaling_factor,
75
+ api_name="/predict"
76
+ )
77
+
78
+ return jsonify(result)
79
 
80
  @app.route('/faceswapper', methods=['GET'])
81
  def faceswapper():
 
86
 
87
  # Chamar a API Gradio
88
  client = Client(endpoint, upload_files=True)
89
+ result = client.predict(
90
  user_photo,
91
  result_photo,
92
  api_name="/predict"
93
  )
94
 
95
+ return jsonify(result)
96
+
97
+ @app.route('/train', methods=['POST'])
98
+ def answer():
99
+ # Obter parâmetros da consulta da URL
100
+ token = request.args.get('token', default='')
101
+ endpoint = request.args.get('endpoint', default='https://pierroromeu-gfpgan.hf.space/--replicas/dgwcd/')
102
+ dataset_id=request.args.get('dataset_id', default='')
103
+ output_model_folder_name=request.args.get('output_model_folder_name', default='')
104
+ concept_prompt=request.args.get('concept_prompt', default='')
105
+ max_training_steps=request.args.get('max_training_steps', type=int, default=0)
106
+ checkpoints_steps=request.args.get('checkpoints_steps', type=int, default=0)
107
+ remove_gpu_after_training_str = request.args.get('remove_gpu_after_training', default='false') # Obtém como string
108
+ remove_gpu_after_training = remove_gpu_after_training_str.lower() == 'true' # Converte para booleano
109
+
110
+ # Chamar a API Gradio
111
+ client = Client(endpoint, hf_token=token)
112
+ result = client.predict(
113
+ dataset_id,
114
+ output_model_folder_name,
115
+ concept_prompt,
116
+ max_training_steps,
117
+ checkpoints_steps,
118
+ remove_gpu_after_training,
119
+ api_name="/main"
120
+ )
121
+
122
+ return jsonify(result)
123
+
124
+ @app.route('/verify', methods=['GET'])
125
+ # ‘/’ URL is bound with hello_world() function.
126
+ def hello_world():
127
+ return jsonify('Check')
128
+
129
+ @app.route('/upload_model', methods=['POST'])
130
+ def upload_model():
131
+ # Parâmetros
132
+ file_name= request.args.get('file_name', default='')
133
+ repo = request.args.get('repo', default='')
134
+ url = request.args.get('url', default='')
135
+ token = request.args.get('token', default='')
136
+
137
+ try:
138
+ # Crie o repositório
139
+ repo_id = repo
140
+ create_repo(repo_id=repo_id, token=token)
141
 
142
+ # Faça o download do conteúdo da URL em memória
143
+ response = requests.get(url)
144
+ if response.status_code == 200:
145
+ # Obtenha o conteúdo do arquivo em bytes
146
+ file_content = response.content
147
+ # Crie um objeto de arquivo em memória
148
+ file_obj = BytesIO(file_content)
149
+ # Faça o upload do arquivo
150
+ upload_file(
151
+ path_or_fileobj=file_obj,
152
+ path_in_repo=file_name,
153
+ repo_id=repo_id,
154
+ token=token
155
+ )
156
 
157
+ # Mensagem de sucesso
158
+ return jsonify({"message": "Sucess"})
159
+ else:
160
+ return jsonify({"error": "Failed"}), 500
161
+ except Exception as e:
162
+ return jsonify({"error": str(e)}), 500
163
 
164
+ if __name__ == "__main__":
165
+ app.run(host="0.0.0.0", port=7860)