File size: 4,743 Bytes
ad2821c
 
 
 
 
 
d749765
 
 
 
ad2821c
0458e89
 
ad2821c
d749765
ad2821c
e21e3fa
ad2821c
 
 
 
 
 
 
d749765
 
 
 
 
1c36440
ad2821c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c36440
ad2821c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08f93b8
ad2821c
08f93b8
1bc29e0
290fc46
 
 
 
08f93b8
09f34cb
ef46641
09f34cb
 
 
 
ad2821c
290fc46
ef46641
d749765
 
 
 
 
 
 
 
 
 
 
 
ad2821c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d749765
ad2821c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr
import openai
import requests
import os
import fileinput
from dotenv import load_dotenv
import io
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# --- Gradio UI text -----------------------------------------------------------
# Title shown at the top of the Gradio interface.
title="ochyAI recipe generator"
# Input box label. English: "Tell me what kind of dish you want and I'll come
# up with a new recipe."
inputs_label="どんな料理か教えてくれれば,新しいレシピを考えます"
# Reply box label. English: "ochyAI will reply."
outputs_label="ochyAIが返信をします"
# Image output label. English: "Image of the dish."
visual_outputs_label="料理のイメージ"
# Markdown notice passed as the interface description. English, roughly:
# "Keep input/output to about 1000 characters. A reply takes about 120
# seconds. If an error occurs, open the log and send the error message --
# ochyAI will be glad."
description="""
- ※入出力の文字数は最大1000文字程度までを目安に入力してください。解答に120秒くらいかかります.エラーが出た場合はログを開いてエラーメッセージを送ってくれるとochyAIが喜びます
"""

# Extra Markdown passed to the interface as `article` (currently empty).
article = """
"""

# --- API clients ----------------------------------------------------------------
# Load variables from a .env file (if present) so the os.getenv() calls below
# can see OPENAI_API_KEY / STABILITY_KEY.
load_dotenv()
# OpenAI key; read back from openai.api_key when building the Bearer header
# for the ChatCompletion HTTP request.
openai.api_key = os.getenv('OPENAI_API_KEY')
# gRPC endpoint for the Stability API.
# NOTE(review): presumably consumed by stability_sdk -- the client below is not
# given a host explicitly; confirm the SDK actually reads this variable.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
# Shared Stability image-generation client, authenticated with STABILITY_KEY.
# verbose=True presumably enables the SDK's own logging -- TODO confirm.
stability_api = client.StabilityInference(
    key=os.getenv('STABILITY_KEY'), 
    verbose=True,
)
# OpenAI chat model name.
MODEL = "gpt-4"

# Process-wide memo store used by get_filetext() when the caller does not pass
# its own cache. Entries are never evicted.
_FILETEXT_CACHE = {}


def get_filetext(filename, cache=None):
    """Return the text content of *filename*, memoized per path.

    The file is read from disk only on the first call for a given path; later
    calls return the cached text, so edits to the file are not seen until the
    process restarts (or a fresh *cache* dict is supplied).

    Args:
        filename: Path of the text file to read (decoded as UTF-8).
        cache: Optional dict used as the memo store, keyed by *filename*.
            When omitted, a shared module-level dict is used, so all default
            calls share one cache.

    Returns:
        The file contents as a str.

    Raises:
        ValueError: If the file does not exist.
    """
    if cache is None:
        cache = _FILETEXT_CACHE

    if filename in cache:
        return cache[filename]

    try:
        # Explicit encoding: the prompt files may contain non-ASCII (e.g.
        # Japanese) text, and the platform default encoding is not guaranteed
        # to be UTF-8. EAFP also avoids the exists()/open() race.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        # Keep ValueError as the public exception type for existing callers.
        raise ValueError(f"ファイル '{filename}' が見つかりませんでした") from err

    cache[filename] = text
    return text

class OpenAI:
    """Generate a recipe via the OpenAI ChatCompletion REST API plus an image.

    The model's reply is expected to contain a visual-prompt section; the text
    of that section is sent to the Stability client to render a picture.
    """

    # Heading line in the model's reply; the text after it is used as the
    # Stability image prompt.
    VISUAL_PROMPT_HEADER = "### Prompt for Visual Expression"
    # Image prompt used when the reply has no usable visual section.
    DEFAULT_VISUAL_PROMPT = "vacant dish"
    # (connect, read) timeouts in seconds for the ChatCompletion request.
    # The read timeout is generous because a full non-streamed completion can
    # take minutes; without any timeout a stalled connection hangs forever.
    REQUEST_TIMEOUT = (10, 600)

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Ask the chat model for a recipe and render an image for it.

        Args:
            prompt: User message sent to the model.
            start_with: Content of the trailing assistant message.

        Returns:
            A two-element list ``[content, image]``: the stripped reply text
            and a PIL image, or ``None`` in place of the image when the
            Stability API yields no image artifact.

        Raises:
            RuntimeError: If the ChatCompletion response has no ``choices``
                (e.g. an API error payload).
        """
        content = cls._request_completion(prompt, start_with)
        visualize_prompt = cls._extract_visual_prompt(content)
        print("visualize_prompt:" + visualize_prompt)
        return [content, cls._generate_image(visualize_prompt)]

    @classmethod
    def _request_completion(cls, prompt, start_with):
        """POST the conversation to ChatCompletion and return the reply text."""
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")

        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "system", "content": template},
                {"role": "assistant", "content": "Sure!"},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }

        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}",
            },
            json=data,
            timeout=cls.REQUEST_TIMEOUT,
        )

        result = response.json()
        print(result)

        if "choices" not in result:
            # Surface the API's own error details instead of a bare KeyError.
            raise RuntimeError(
                "ChatCompletion API request failed "
                f"(HTTP {response.status_code}): {result.get('error', result)}"
            )
        return result["choices"][0]["message"]["content"].strip()

    @classmethod
    def _extract_visual_prompt(cls, content):
        """Return the text under the visual-prompt heading, or the default.

        The heading must occupy its own line (trailing spaces and ``\\r`` are
        tolerated); the prompt runs until the next occurrence of the heading
        or the end of *content*.
        """
        sections = content.split(cls.VISUAL_PROMPT_HEADER)
        if len(sections) > 1:
            rest_of_line, newline, body = sections[1].partition("\n")
            if newline and not rest_of_line.strip():
                candidate = body.strip()
                if candidate:
                    return candidate
        return cls.DEFAULT_VISUAL_PROMPT

    @classmethod
    def _generate_image(cls, visualize_prompt):
        """Return the first image artifact from Stability as a PIL image.

        Returns None when the response stream contains no image artifact.
        """
        answers = stability_api.generate(prompt=visualize_prompt)
        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    # Safety filter triggered; only logged -- an image
                    # artifact is still returned if present.
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))
        return None
        
class NajiminoAI:
    """Gradio-facing entry point: turn a dish description into a recipe."""

    @classmethod
    def generate_emo_prompt(cls, user_message, template=None):
        """Build the user prompt asking the model to fill in the template.

        Args:
            user_message: Free-text description of the dish from the UI.
            template: Template text to embed. When None (the default) it is
                loaded from ``template.md`` via get_filetext().

        Returns:
            The prompt: the user's message, a ``---`` separator, the Japanese
            instruction "fill in the template below based on the above",
            another ``---`` separator, and the template, each on its own line.
        """
        if template is None:
            template = get_filetext(filename="template.md")
        # Built from flush-left pieces rather than an indented triple-quoted
        # f-string, so source indentation does not leak into the prompt
        # (lines indented 4+ spaces read as code blocks in Markdown).
        return (
            f"{user_message}\n"
            "---\n"
            "上記を元に、下記テンプレートを埋めてください。\n"
            "---\n"
            f"{template}\n"
        )

    @classmethod
    def generate_emo(cls, user_message):
        """Generate a recipe (and image) for *user_message*.

        Returns:
            The result of OpenAI.chat_completion for the built prompt; the
            Gradio interface in main() maps it to its text and image outputs.
        """
        prompt = cls.generate_emo_prompt(user_message)
        return OpenAI.chat_completion(prompt=prompt, start_with="")

def main():
    """Build the Gradio interface around NajiminoAI.generate_emo and launch it.

    One text input; two outputs (the recipe text and the dish image).
    """
    iface = gr.Interface(
        fn=NajiminoAI.generate_emo,
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            # The reply box gets its own label, not the input box's label.
            gr.Textbox(label=outputs_label),
            gr.Image(label=visual_outputs_label),
        ],
        title=title,
        description=description,
        article=article,
        allow_flagging='never',
    )

    iface.launch()


if __name__ == '__main__':
    main()