gnilets commited on
Commit
c4d3ac4
·
verified ·
1 Parent(s): 9591807

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from asyncio import sleep
2
+ from base64 import b64decode
3
+ from binascii import Error as BinasciiError
4
+ from contextlib import asynccontextmanager
5
+ from io import BytesIO
6
+ from json import dumps, loads
7
+ from logging import Formatter, INFO, StreamHandler, getLogger
8
+ from pathlib import Path
9
+ from random import choice
10
+ from typing import AsyncGenerator
11
+
12
+ from PIL.Image import open as image_open
13
+ from fastapi import FastAPI, Request
14
+ from fastapi.responses import HTMLResponse, JSONResponse
15
+ from httpx import AsyncClient
16
+ from patchright.async_api import async_playwright
17
+ from playwright.async_api import FilePayload, Request as PlaywrightRequest, async_playwright
18
+ from prlps_fakeua import UserAgent
19
+ from starlette.responses import Response
20
+
21
+ logger = getLogger('RHYMES_AI_API')
22
+ logger.setLevel(INFO)
23
+ handler = StreamHandler()
24
+ handler.setLevel(INFO)
25
+ formatter = Formatter('%(asctime)s | %(levelname)s : %(message)s', datefmt='%d.%m.%Y %H:%M:%S')
26
+ handler.setFormatter(formatter)
27
+ logger.addHandler(handler)
28
+
29
+ logger.info('инициализация приложения...')
30
+
31
+ ua = UserAgent(os=['windows', 'mac'])
32
+
33
+ workdir = Path(__file__).parent
34
+
35
+ infer_data = workdir / 'infer_data.json'
36
+
37
+ BASE_URL = 'https://akhaliq-anychat.hf.space'
38
+
39
+
40
+ def base64_to_jpeg_bytes(base64_str: str) -> bytes:
41
+ try:
42
+ if ',' not in base64_str:
43
+ raise ValueError("недопустимый формат строки base64")
44
+ base64_data = base64_str.split(',', 1)[1]
45
+ binary_data = b64decode(base64_data)
46
+ with image_open(BytesIO(binary_data)) as img:
47
+ with BytesIO() as jpeg_bytes:
48
+ img.convert('RGB').save(jpeg_bytes, format='JPEG', quality=90, optimize=True)
49
+ return jpeg_bytes.getvalue()
50
+ except (BinasciiError, OSError) as e:
51
+ raise ValueError('данные не являются корректным изображением') from e
52
+
53
+
54
+ def image_bytes(base64_image_str: str) -> FilePayload:
55
+ return FilePayload(
56
+ name=generate_random_string(12) + '.jpeg',
57
+ mimeType='image/jpeg',
58
+ buffer=base64_to_jpeg_bytes(base64_image_str)
59
+ )
60
+
61
+
62
+ def generate_random_string(length):
63
+ return ''.join(choice('abcdefghijklmnopqrstuvwxyz0123456789') for _ in range(length))
64
+
65
+
66
+ def get_infer_data() -> tuple[int, int, str]:
67
+ data = loads(infer_data.read_text())
68
+ logger.debug(f'загруженные из файла данные `get_infer_data`: {data}')
69
+ return data['fn_index'], data['trigger_id'], data['session_hash']
70
+
71
+
72
+ def prepare_data(gradio_file_path: str, question: str, fn_index: int, trigger_id: int, session_hash: str) -> dict:
73
+ return {
74
+ "data": [
75
+ None,
76
+ [[{
77
+ "file": {
78
+ "path": gradio_file_path,
79
+ "url": f"{BASE_URL}/gradio_api/file={gradio_file_path}",
80
+ "size": None, "orig_name": None, "mime_type": "image/jpeg", "is_stream": False,
81
+ "meta": {"_type": "gradio.FileData"}
82
+ },
83
+ "alt_text": None
84
+ }, None], [question, None]]
85
+ ], "event_data": None,
86
+ "fn_index": fn_index,
87
+ "trigger_id": trigger_id,
88
+ "session_hash": session_hash
89
+ }
90
+
91
+
92
+ async def fetch_result(base64_image_str: str, question: str) -> str | None:
93
+ fn_index, trigger_id, session_hash = get_infer_data()
94
+ async with AsyncClient(follow_redirects=True, timeout=40) as client:
95
+ image_file = image_bytes(base64_image_str)
96
+ boundary = f'----WebKitFormBoundary{generate_random_string(15).upper()}'
97
+ upload_response = await client.post(
98
+ f'{BASE_URL}/gradio_api/upload?upload_id={generate_random_string(11)}',
99
+ headers={
100
+ 'Content-Type': f'multipart/form-data; boundary={boundary}',
101
+ 'accept': '*/*'
102
+ },
103
+ content=(
104
+ f'--{boundary}\r\n'
105
+ f'Content-Disposition: form-data; name="files"; filename="{image_file.get('name')}"\r\n'
106
+ f'Content-Type: {image_file.get("mimeType")}\r\n\r\n'
107
+ f'{image_file.get("buffer").decode("latin1")}\r\n'
108
+ f'--{boundary}--\r\n'
109
+ ).encode('latin1')
110
+ )
111
+ upload_response.raise_for_status()
112
+ gradio_file_path = upload_response.json()[0]
113
+ logger.debug(f'gradio_file_path: {gradio_file_path}')
114
+ send_response = await client.post(
115
+ f'{BASE_URL}/gradio_api/queue/join',
116
+ headers={
117
+ 'accept': '*/*',
118
+ 'content-type': 'application/json'
119
+ },
120
+ json=prepare_data(gradio_file_path, question, fn_index, trigger_id, session_hash)
121
+ )
122
+ send_response.raise_for_status()
123
+ logger.debug(f'send_response: {send_response.text}')
124
+ async with client.stream(
125
+ 'GET',
126
+ f'{BASE_URL}/gradio_api/queue/data?session_hash={session_hash}',
127
+ headers={'accept': 'text/event-stream', 'content-type': 'application/json'
128
+ }) as result_response:
129
+ result_response.raise_for_status()
130
+ async for line in result_response.aiter_lines():
131
+ if line.startswith('data:'):
132
+ logger.debug(f'result_response line: {line}')
133
+ event_data = loads(line[6:])
134
+ if event_data.get('msg') == 'process_completed':
135
+ logger.debug(f'process_completed: {event_data}')
136
+ data = event_data.get('output', {}).get('data', [])
137
+ if data:
138
+ return data[0][1][1]
139
+ return None
140
+
141
+
142
+ def take_infer_data(request: PlaywrightRequest):
143
+ if request.url.startswith("https://akhaliq-anychat.hf.space/gradio_api/queue/join"):
144
+ try:
145
+ data = loads(request.post_data)
146
+ if data.get('data'):
147
+ fn_index = data.get('fn_index')
148
+ trigger_id = data.get('trigger_id')
149
+ session_hash = data.get('session_hash')
150
+ if fn_index and trigger_id and session_hash:
151
+ infer_data_json = {
152
+ 'fn_index': fn_index,
153
+ 'trigger_id': trigger_id,
154
+ 'session_hash': session_hash
155
+ }
156
+ infer_data.write_text(dumps(infer_data_json, indent=4))
157
+ logger.debug(f'полученные из браузера данные в `take_infer_data`: {infer_data_json}')
158
+ except Exception as ext:
159
+ logger.error(f'ошибка `take_infer_data`: {ext}')
160
+ pass
161
+
162
+
163
+ async def browser_request(base64_image_str: str, question: str) -> str | None:
164
+ async with async_playwright() as playwright:
165
+ browser = await playwright.chromium.launch(headless=True, args=['--disable-blink-features=AutomationControlled'])
166
+ context = await browser.new_context(
167
+ color_scheme='dark',
168
+ ignore_https_errors=True,
169
+ locale='en-US',
170
+ user_agent=ua.random,
171
+ no_viewport=True,
172
+ )
173
+ try:
174
+ page = await context.new_page()
175
+ image_file = image_bytes(base64_image_str)
176
+ page.on('request', take_infer_data)
177
+ await page.goto('https://akhaliq-anychat.hf.space/?__theme=light')
178
+ await page.get_by_role('tab', name='Grok').click()
179
+ await page.get_by_role('textbox', name='Type a message...').fill(question)
180
+ await page.get_by_role('group', name='Multimedia input field').get_by_test_id('file-upload').set_input_files(image_file)
181
+ await page.wait_for_selector('img.thumbnail-image')
182
+ submit_button = page.get_by_role('group', name='Multimedia input field').locator('.submit-button')
183
+ await submit_button.click()
184
+ await page.wait_for_selector('button[aria-label="Retry"]', state='visible')
185
+ await submit_button.wait_for(state='visible')
186
+ caption = ' '.join(await page.get_by_test_id('bot').all_text_contents()).strip()
187
+ await context.close()
188
+ await browser.close()
189
+ if caption:
190
+ logger.info('результат получен из `browser_request`')
191
+ return caption
192
+ except Exception as exc:
193
+ logger.error(f'ошибка `browser_request`: {exc}')
194
+ return None
195
+
196
+
197
+ async def httpx_request(base64_image_str: str, question: str) -> str | None:
198
+ try:
199
+ caption = await fetch_result(base64_image_str, question)
200
+ logger.debug(caption)
201
+ if caption:
202
+ logger.info('результат получен из `httpx_request`')
203
+ return caption
204
+ except Exception as exc:
205
+ logger.error(f'ошибка `browser_request`: {exc}')
206
+ return None
207
+
208
+
209
+ async def get_grok_caption(base64_image_str: str, question: str) -> str | None:
210
+ attempts = 3
211
+ for _ in range(attempts):
212
+ result = await httpx_request(base64_image_str, question)
213
+ if result:
214
+ return result
215
+ result = await browser_request(base64_image_str, question)
216
+ if result:
217
+ return result
218
+ await sleep(1.5)
219
+ logger.error(f'превышено максимальное количество попыток')
220
+ return None
221
+
222
+
223
+ @asynccontextmanager
224
+ async def app_lifespan(_) -> AsyncGenerator:
225
+ logger.info('запуск приложения')
226
+ try:
227
+ logger.info('старт API')
228
+ yield
229
+ finally:
230
+ logger.info('приложение завершено')
231
+
232
+
233
+ app = FastAPI(lifespan=app_lifespan, title='RHYMES_AI_API')
234
+
235
+ banned_endpoints = [
236
+ '/openapi.json',
237
+ '/docs',
238
+ '/docs/oauth2-redirect',
239
+ 'swagger_ui_redirect',
240
+ '/redoc',
241
+ ]
242
+
243
+
244
+ @app.middleware('http')
245
+ async def block_banned_endpoints(request: Request, call_next):
246
+ logger.debug(f'получен запрос: {request.url.path}')
247
+ if request.url.path in banned_endpoints:
248
+ logger.warning(f'запрещенный endpoint: {request.url.path}')
249
+ return Response(status_code=403)
250
+ response = await call_next(request)
251
+ return response
252
+
253
+
254
+ @app.post('/v1/describe')
255
+ async def describe_v1(request: Request):
256
+ logger.info('запрос `describe_v1`')
257
+ body = await request.json()
258
+ content_text = ''
259
+ image_data = ''
260
+
261
+ messages = body.get('messages', [])
262
+ for message in messages:
263
+ role = message.get('role')
264
+ content = message.get('content')
265
+
266
+ if role in ['system', 'user']:
267
+ if isinstance(content, str):
268
+ content_text += content + ' '
269
+ elif isinstance(content, list):
270
+ for item in content:
271
+ if item.get('type') == 'text':
272
+ content_text += item.get('text', '') + ' '
273
+ elif item.get('type') == 'image_url':
274
+ image_url = item.get('image_url', {})
275
+ url = image_url.get('url')
276
+ if url and url.startswith('data:image/'):
277
+ image_data = url
278
+ image_data, content_text = image_data.strip(), content_text.strip()
279
+
280
+ if not content_text or not image_data:
281
+ return JSONResponse({'caption': 'изображение должно быть передано как строка base64 `data:image/jpeg;base64,{base64_img}` а также текст'}, status_code=400)
282
+ try:
283
+ caption = await get_grok_caption(image_data, content_text)
284
+ return JSONResponse({'caption': caption}, status_code=200)
285
+ except Exception as e:
286
+ return JSONResponse({'caption': str(e)}, status_code=500)
287
+
288
+
289
+ @app.get('/')
290
+ async def root():
291
+ return HTMLResponse('ну пролапс, ну и что', status_code=200)
292
+
293
+
294
+ if __name__ == '__main__':
295
+ from uvicorn import run as uvicorn_run
296
+
297
+ logger.info('запуск сервера uvicorn')
298
+ uvicorn_run(app, host='0.0.0.0', port=7860)