Spaces:
Runtime error
Runtime error
import gradio as gr | |
import openai | |
import requests | |
import os | |
import fileinput | |
from dotenv import load_dotenv | |
import io | |
from PIL import Image | |
from stability_sdk import client | |
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation | |
title="Storytelling-AI-1-test" | |
inputs_label="あなたが入力に応じてストーリーを生成します" | |
outputs_label="AIが生成したストーリー" | |
visual_outputs_label="AIが生成したビジュアルイメージ" | |
description=""" | |
- ストーリーの元になるアイデアを入力してください。エラーが発生した場合や、出力された内容が気に入らない場合は、再度送信するか、違う内容を入力して送信してください。 | |
""" | |
article = """ | |
<ul> | |
<li style="font-size: small;">出力されたタイトルと、選択肢ABのいずれかの出力をコピーして、次のステップに進みます→<a href="https://huggingface.co/spaces/Masa-digital-art/Storytelling-AI-2-test">https://huggingface.co/spaces/Masa-digital-art/Storytelling-AI-2-test</a></li> | |
</ul> | |
<h5>リリースノート</h5> | |
<ul> | |
<li style="font-size: small;">2023-08-31 v1.0</li> | |
<li style="font-size: small;">2023-09-09 v1.2</li> | |
<li style="font-size: small;">2023-09-13 v1.4</li> | |
<li style="font-size: small;">2023-09-17 v1.5</li> | |
</ul> | |
<h5>注意事項</h5> | |
<ul> | |
<li style="font-size: small;">当サービスでは、2023/3/14にリリースされたOpenAI社のChatGPT APIのgpt-4と、2022/4/13にリリースされたSability AI社のStable Diffusion XL 'sAPIを使用しております。</li> | |
<li style="font-size: small;">当サービスで生成されたテキストは、OpenAI が提供する人工知能によるものであり、当サービスやOpenAI がその正確性や信頼性を保証するものではありません。</li> | |
<li style="font-size: small;">当サービスで生成されたイメージは、Stability AI が提供する人工知能によるものであり、当サービスやStabiliy AI がその信頼性を保証するものではありません。</li> | |
<li style="font-size: small;"><a href="https://platform.openai.com/docs/usage-policies">OpenAI の利用規約</a>に従い、データ保持しない方針です(ただし諸般の事情によっては変更する可能性はございます)。 | |
<li style="font-size: small;">当サービスで生成されたコンテンツは事実確認をした上で、コンテンツ生成者およびコンテンツ利用者の責任において利用してください。</li> | |
<li style="font-size: small;">当サービスでの使用により発生したいかなる損害についても、当社は一切の責任を負いません。</li> | |
<li style="font-size: small;">当サービスはβ版のため、予告なくサービスを終了する場合がございます。</li> | |
</ul> | |
""" | |
load_dotenv() | |
openai.api_key = os.getenv('OPENAI_API_KEY') | |
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443' | |
stability_api = client.StabilityInference( | |
key=os.getenv('STABILITY_KEY'), | |
engine="stable-diffusion-xl-1024-v1-0", | |
verbose=True, | |
) | |
MODEL = "gpt-4" | |
def get_filetext(filename, cache={}): | |
if filename in cache: | |
# キャッシュに保存されている場合は、キャッシュからファイル内容を取得する | |
return cache[filename] | |
else: | |
if not os.path.exists(filename): | |
raise ValueError(f"ファイル '{filename}' が見つかりませんでした") | |
with open(filename, "r") as f: | |
text = f.read() | |
# ファイル内容をキャッシュする | |
cache[filename] = text | |
return text | |
class OpenAI: | |
def chat_completion(cls, prompt, start_with=""): | |
constraints = get_filetext(filename = "constraints.md") | |
template = get_filetext(filename = "template.md") | |
# ChatCompletion APIに渡すデータを定義する | |
data = { | |
"model": "gpt-4", | |
"messages": [ | |
{"role": "system", "content": constraints} | |
,{"role": "system", "content": template} | |
,{"role": "assistant", "content": "Sure!"} | |
,{"role": "user", "content": prompt} | |
,{"role": "assistant", "content": start_with} | |
], | |
} | |
# ChatCompletion APIを呼び出す | |
response = requests.post( | |
"https://api.openai.com/v1/chat/completions", | |
headers={ | |
"Content-Type": "application/json", | |
"Authorization": f"Bearer {openai.api_key}" | |
}, | |
json=data | |
) | |
# ChatCompletion APIから返された結果を取得する | |
result = response.json() | |
print(result) | |
content = result["choices"][0]["message"]["content"].strip() | |
visualize_prompt = content.split("## Prompt for Visual Expression\n\n")[1] | |
answers = stability_api.generate( | |
prompt=("high quality illustlation, Stunning detail, crisp images, high-contrast images, cinematic lighting, sharp focus, imaginative concept art, fantastic colors, impressive shading, establishing shot, image board, wide shot, image of the beginning of the story" + visualize_prompt), | |
steps=50, | |
width=768, | |
height=512, | |
) | |
for resp in answers: | |
for artifact in resp.artifacts: | |
if artifact.finish_reason == generation.FILTER: | |
print("NSFW") | |
if artifact.type == generation.ARTIFACT_IMAGE: | |
img = Image.open(io.BytesIO(artifact.binary)) | |
return [content, img] | |
class MasasanAI: | |
def generate_vision_prompt(cls, user_message): | |
template = get_filetext(filename="template.md") | |
prompt = f""" | |
{user_message} | |
--- | |
上記を元に、下記テンプレートを埋めてください。 | |
--- | |
{template} | |
""" | |
return prompt | |
def generate_vision(cls, user_message): | |
prompt = MasasanAI.generate_vision_prompt(user_message); | |
start_with = "" | |
result = OpenAI.chat_completion(prompt=prompt, start_with=start_with) | |
return result | |
def main(): | |
iface = gr.Interface(fn=MasasanAI.generate_vision, | |
inputs=gr.Textbox(label=inputs_label), | |
outputs=[gr.Textbox(label=inputs_label), | |
gr.Image(label=visual_outputs_label)], | |
title=title, | |
description=description, | |
article=article, | |
allow_flagging='never' | |
) | |
iface.launch() | |
if __name__ == '__main__': | |
main() | |