File size: 8,754 Bytes
f410e72
 
 
9dc7bb2
f410e72
9dc7bb2
f410e72
 
 
 
 
9dc7bb2
f410e72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc7bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
475b1c1
 
9dc7bb2
 
475b1c1
 
 
9dc7bb2
 
475b1c1
 
f410e72
 
 
 
9dc7bb2
 
 
 
f410e72
 
 
9dc7bb2
 
 
 
f410e72
 
 
9dc7bb2
 
 
 
 
f410e72
475b1c1
 
 
 
9dc7bb2
 
475b1c1
 
 
 
 
9dc7bb2
 
475b1c1
 
f410e72
e152c3e
f410e72
 
 
9dc7bb2
 
 
 
 
 
 
 
 
 
b7c9c3b
 
f410e72
9dc7bb2
 
f410e72
 
9dc7bb2
 
 
f410e72
 
 
 
9dc7bb2
 
 
f410e72
 
 
 
9dc7bb2
 
 
 
 
f410e72
 
 
 
9dc7bb2
f410e72
 
833cd26
 
 
 
 
 
 
 
 
9dc7bb2
 
 
 
 
 
 
 
 
 
 
f410e72
 
 
 
 
 
 
 
 
 
 
 
 
9dc7bb2
f410e72
9dc7bb2
f410e72
 
 
9dc7bb2
 
 
 
 
 
 
 
 
f410e72
 
9dc7bb2
 
f410e72
 
 
 
 
 
 
 
 
 
 
9dc7bb2
 
f410e72
 
 
 
 
 
 
 
 
 
 
 
 
9dc7bb2
c1d20db
9dc7bb2
 
 
 
 
 
b34543e
9dc7bb2
 
f410e72
9dc7bb2
f410e72
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import re
import shutil
import tempfile
import gradio as gr
import requests
from cogsgpt import CogsGPT


class Client:
    """Per-session conversation state wrapping a CogsGPT instance.

    Tracks the latest text input, the media files it referenced, the parsed
    task list and the task results, and renders messages/media into the
    gradio chatbot history format (list of (user, assistant) tuples, where
    a 1-tuple of a file path denotes a media message).
    """

    # Media references may be http(s) URLs or local/relative file paths.
    _IMAGE_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
    _AUDIO_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")

    def __init__(self):
        self._client = CogsGPT(temperature=0.2, verbose=True)
        # Defaults so parse_task()/execute_task()/generate_response() are
        # safe even if a UI event fires before add_text() has run.
        self._text_input = ""
        self._image_inputs = []
        self._audio_inputs = []
        self._task_list = []
        self._task_result_list = []

    @staticmethod
    def _unique_matches(pattern, message):
        """Return pattern's full matches in *message*, de-duplicated, in order."""
        # dict.fromkeys removes duplicates while preserving first-seen order.
        return list(dict.fromkeys(m.group(0) for m in pattern.finditer(message)))

    def _extract_medias(self, message):
        """Return (image_urls, audio_urls) referenced in *message*."""
        return (self._unique_matches(self._IMAGE_PATTERN, message),
                self._unique_matches(self._AUDIO_PATTERN, message))

    def _download_media(self, url):
        """Download *url* to a temp file and return the local file path.

        Raises requests.HTTPError on a non-2xx response rather than saving
        an error page to disk as if it were media.
        """
        ext = url.split('.')[-1]
        # Context manager closes the connection; the original leaked it.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            # With stream=True, response.raw is not content-decoded by
            # default; enable it so gzip-encoded bodies are unpacked.
            response.raw.decode_content = True
            with tempfile.NamedTemporaryFile(mode='w+b', suffix='.' + ext, delete=False) as media_file:
                shutil.copyfileobj(response.raw, media_file)
        return media_file.name

    def _resolve_media(self, url):
        """Return a local path for *url* (downloading remote ones), or None.

        Unreachable URLs and missing local files are skipped (best effort)
        instead of crashing the UI event handler.
        """
        if url.startswith('http'):
            try:
                url = self._download_media(url)
            except requests.RequestException:
                return None
        return url if os.path.exists(url) else None

    def add_text(self, chatbot, text_input):
        """Append the user's message, plus any media it references, to *chatbot*."""
        self._text_input = text_input
        if self._text_input == "":
            return chatbot

        chatbot += [(self._text_input, None)]

        self._image_inputs, self._audio_inputs = self._extract_medias(self._text_input)
        for url in self._image_inputs + self._audio_inputs:
            path = self._resolve_media(url)
            if path is not None:
                chatbot += [((path,), None)]

        return chatbot

    def parse_task(self):
        """Parse the pending text input into a task list; None on empty input."""
        if self._text_input == "":
            return

        self._task_list = self._client.parse_tasks(self._text_input)
        return self._task_list

    def execute_task(self):
        """Execute the previously parsed task list; None on empty input."""
        if self._text_input == "":
            return

        self._task_result_list = self._client.execute_tasks(self._task_list)
        return self._task_result_list

    def generate_response(self, chatbot):
        """Generate the assistant reply and append it (with any new media) to *chatbot*."""
        if self._text_input == "":
            return chatbot

        self._response = self._client.generate_response(self._text_input, self._task_result_list)
        chatbot += [(None, self._response)]

        image_outputs, audio_outputs = self._extract_medias(self._response)
        for url in image_outputs + audio_outputs:
            # Media the user already supplied is not rendered a second time.
            if url in self._image_inputs or url in self._audio_inputs:
                continue
            path = self._resolve_media(url)
            if path is not None:
                chatbot += [(None, (path,))]

        # self._client.save_context(self._text_input, self._response)

        return chatbot

def set_key(state, openai_api_key):
    """Persist OpenAI credentials in the environment and build the session client.

    Returns the updated session state and echoes the key back to the textbox.
    """
    credentials = {
        "OPENAI_API_TYPE": "openai",
        "OPENAI_API_KEY": openai_api_key,
        "OPENAI_MODEL_NAME": "gpt-3.5-turbo",
    }
    os.environ.update(credentials)
    state["client"] = Client()
    return state, openai_api_key

def add_text(state, chatbot, text_input):
    """Route the submitted text to the session client.

    Without a configured client, post a reminder and keep the textbox
    contents; otherwise delegate to the client and clear the textbox.
    """
    client = state.get("client")
    if client is None:
        chatbot += [(None, "Please set your OpenAI API key first!!!")]
        return chatbot, text_input
    return client.add_text(chatbot, text_input), ""

def parse_task(state, chatbot):
    """Ask the session client to parse tasks; yields None without a client."""
    client = state.get("client")
    task_list = client.parse_task() if client is not None else None
    return chatbot, task_list

def execute_task(state, chatbot):
    """Run the parsed tasks via the session client; yields None without a client."""
    client = state.get("client")
    task_result_list = client.execute_task() if client is not None else None
    return chatbot, task_result_list

def generate_response(state, chatbot):
    """Let the session client append its reply; pass chat through without one."""
    client = state.get("client")
    if client is not None:
        chatbot = client.generate_response(chatbot)
    return chatbot


# Fixed-height, scrollable styling for the JSON task panel.
css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
with gr.Blocks(css=css) as demo:
    # Per-session dict; set_key() stores the Client instance under "client".
    state = gr.State(value={})

    # Header / badges.
    gr.Markdown("<h1><center>CogsGPT</center></h1>")
    gr.Markdown("<p align='center' style='font-size: 20px;'>A conversational system which integrates ChatGPT with Azure Cognitive Services to achieve multimodal capabilities.</p>")
    gr.Markdown("<p align='center' style='font-size: 18px;'>If you find it useful, please consider giving it a star on <a href='https://github.com/whiskyboy/cogsgpt'>Github</a>! :)</p>")
    gr.Markdown("""
    <div style="text-align: center;">
        <img alt="GitHub watchers" src="https://img.shields.io/github/watchers/whiskyboy/cogsgpt?style=social" style="display: inline-block;">
        <img alt="GitHub forks" src="https://img.shields.io/github/forks/whiskyboy/cogsgpt?style=social" style="display: inline-block;">
        <img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/whiskyboy/cogsgpt?style=social" style="display: inline-block;">
    </div>
    """)

    # API key row: password textbox + submit button.
    with gr.Row():
        with gr.Column(scale=0.85):
            openai_api_key = gr.Textbox(
                show_label=False,
                placeholder="Set your OpenAI API key here and press Enter",
                lines=1,
                type="password"
            ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            set_key_btn = gr.Button("Submit")

    # Output Row: chat history on the left, parsed task JSON on the right.
    with gr.Row():
        with gr.Column(scale=0.6):
            chatbot = gr.Chatbot([], label="Chatbot").style(height=500)
        
        with gr.Column(scale=0.4):
            task_output = gr.JSON(label="Tasks", elem_classes="json")

    # Input Row: message textbox + send button.
    with gr.Row():
        with gr.Column(scale=0.85):
            text_input = gr.Textbox(lines=1, show_label=False, interactive=True,
                                    placeholder="Enter text and press enter. The url must contain the media type. e.g, https://example.com/example.jpg",
                                    ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            send_btn = gr.Button("Send", label="Send", interactive=True)

    # Event binding.
    # Both the Enter key and the Submit button configure the API key.
    openai_api_key.submit(
        fn=set_key,
        inputs=[state, openai_api_key],
        outputs=[state, openai_api_key])
    set_key_btn.click(
        fn=set_key,
        inputs=[state, openai_api_key],
        outputs=[state, openai_api_key])

    # Message pipeline, chained with .then() so the four stages run in
    # order: echo input -> parse tasks -> execute tasks -> final reply.
    # The same chain is bound to both Enter and the Send button.
    text_input.submit(
        fn=add_text,
        inputs=[state, chatbot, text_input],
        outputs=[chatbot, text_input]).then(
        fn=parse_task,
        inputs=[state, chatbot],
        outputs=[chatbot, task_output]).then(
        fn=execute_task,
        inputs=[state, chatbot],
        outputs=[chatbot, task_output]).then(
        fn=generate_response,
        inputs=[state, chatbot],
        outputs=[chatbot])
    send_btn.click(
        fn=add_text,
        inputs=[state, chatbot, text_input],
        outputs=[chatbot, text_input]).then(
        fn=parse_task,
        inputs=[state, chatbot],
        outputs=[chatbot, task_output]).then(
        fn=execute_task,
        inputs=[state, chatbot],
        outputs=[chatbot, task_output]).then(
        fn=generate_response,
        inputs=[state, chatbot],
        outputs=[chatbot])
    
    # Examples — clicking one fills the input textbox.
    # NOTE(review): the ./tests/examples/* paths assume the app is launched
    # from the repository root — confirm for other deployments.
    gr.Examples(
        examples=[
            # CV
            "What can I make with these ingredients? ./tests/examples/ingredients.png",
            "Extract the text from the image: ./tests/examples/handwritten-note.jpg",
            # Speech
            "Convert the text 'CogsGPT is a multi-modal LLM integrated ChatGPT with Azure Cognitive Service' into speech.",
            "Extract the content of audio: ./tests/examples/cogsgpt.wav",
            # Form
            "List all the items and their prices from the receipt: ./tests/examples/receipt.png",
            "List all the flights with China Eastern airline in the flight schedule table from the file: ./tests/examples/flight-schedule.png.",
            # Complex task
            "Summarize the content in the audio file: ./tests/examples/voa-1min-news.wav, and translate it into Chinese. Then read it out.",
        ],
        inputs=text_input,
    )

demo.launch(show_api=False)