# (Scraped Hugging Face "Spaces" page header — Space status: Sleeping)
import base64
import io
import itertools
import json
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import requests
from gradio.themes.utils import sizes
from PIL import Image
def respond(message, history): | |
if len(message.strip()) == 0: | |
return "質問を入力してください" | |
local_token = os.getenv('API_TOKEN') | |
local_endpoint = os.getenv('API_ENDPOINT') | |
if local_token is None or local_endpoint is None: | |
return "ERROR missing env variables" | |
# Add your API token to the headers | |
headers = { | |
'Content-Type': 'application/json', | |
'Authorization': f'Bearer {local_token}' | |
} | |
#prompt = list(itertools.chain.from_iterable(history)) | |
#prompt.append(message) | |
# プロンプトの作成 | |
prompt = pd.DataFrame( | |
{"prompt": [message], "num_inference_steps": 25} | |
) | |
print(prompt) | |
ds_dict = {"dataframe_split": prompt.to_dict(orient="split")} | |
data_json = json.dumps(ds_dict, allow_nan=True) | |
embed_image_markdown = "" | |
try: | |
# モデルサービングエンドポイントに問い合わせ | |
response = requests.request(method="POST", headers=headers, url=local_endpoint, data=data_json) | |
response_data = response.json() | |
#print(response_data["predictions"]) | |
# numpy arrayに変換 | |
im_array = np.array(response_data["predictions"], dtype=np.uint8) | |
#print(im_array) | |
# 画像に変換 | |
im = Image.fromarray(im_array, 'RGB') | |
# debug | |
#image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ec/Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.jpg/687px-Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.jpg" | |
#print("image_url:", image_url) | |
#im = Image.open(io.BytesIO(requests.get(image_url).content)) | |
#numpydata = np.asarray(im) | |
rawBytes = io.BytesIO() | |
im.save(rawBytes, "PNG") | |
rawBytes.seek(0) # ファイルの先頭に移動 | |
# base64にエンコード | |
image_encoded = base64.b64encode(rawBytes.read()).decode('ascii') | |
#print(image_encoded) | |
# マークダウンに埋め込み | |
embed_image_markdown = f"![](data:image/png;base64,{image_encoded})" | |
#print(embed_image_markdown) | |
except Exception as error: | |
response_data = f"ERROR status_code: {type(error).__name__}" | |
#+ str(response.status_code) + " response:" + response.text | |
return embed_image_markdown | |
# Compact variant of Gradio's built-in "Soft" theme: small text, small
# corner radius and tight spacing.
theme = gr.themes.Soft(
    spacing_size=sizes.spacing_sm,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_sm,
)
# Chat UI wired to ``respond``: each submitted message produces one reply.
# NOTE(review): ``retry_btn``/``undo_btn``/``clear_btn`` and
# ``bubble_full_width`` are Gradio 4.x arguments that later major versions
# removed or deprecated — pin gradio accordingly, or update these on upgrade.
# NOTE(review): the title and examples describe a Databricks Q&A bot, but
# ``respond`` returns a generated image — confirm the intended UI copy.
demo = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(
        show_label=False,
        container=False,
        show_copy_button=True,
        bubble_full_width=True,
    ),
    textbox=gr.Textbox(
        placeholder="質問を入力してください",
        container=False,
        scale=7,
    ),
    title="Databricks QAチャットボット",
    description="TBD",
    # One single-input example row per sample prompt.
    examples=[
        [question]
        for question in (
            "Databricksクラスターとは?",
            "Unity Catalogの有効化方法",
            "リネージの保持期間",
        )
    ],
    cache_examples=False,
    theme=theme,
    # None disables the retry/undo buttons; keep a plain "Clear" button.
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()