File size: 4,275 Bytes
98e9d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
Here is an improved and fixed version of the `app.py` script with better structure, error handling, and documentation:

```python
import uuid
import gradio as gr
import re
from diffusers.utils import load_image
import requests
from awesome_chat import chat_huggingface
import os

# Media saved by the app is written below public/, one folder per media type;
# create the folders up front so later writes do not fail on a missing directory.
for _media_dir in ("public/images", "public/audios", "public/videos"):
    os.makedirs(_media_dir, exist_ok=True)

# Default credentials come from the environment; either value is None when the
# corresponding variable is not set.
OPENAI_KEY = os.getenv("OPENAI_KEY")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

class Client:
    def __init__(self) -> None:
        """Create a client with an empty history and the module-level default credentials."""
        # Conversation history as a list of {"role": ..., "content": ...} dicts.
        self.all_messages = []
        # Credentials start from the module-level values and can be replaced
        # later through set_key / set_token.
        self.HUGGINGFACE_TOKEN = HUGGINGFACE_TOKEN
        self.OPENAI_KEY = OPENAI_KEY

    def set_key(self, openai_key):
        self.OPENAI_KEY = openai_key
        return self.OPENAI_KEY

    def set_token(self, huggingface_token):
        self.HUGGINGFACE_TOKEN = huggingface_token
        return self.HUGGINGFACE_TOKEN
    
    def add_message(self, content, role):
        message = {"role": role, "content": content}
        self.all_messages.append(message)

    def extract_medias(self, message):
        """Find media references (URLs or paths ending in a known extension) in ``message``.

        Returns a 4-tuple ``(urls, image_urls, audio_urls, video_urls)`` of
        lists of matched substrings, each in order of appearance. ``urls``
        covers every supported extension; the other three are limited to the
        image (jpg/jpeg/tiff/gif/png), audio (flac/wav) and video (mp4)
        extensions respectively.
        """

        def _all_matches(regex):
            # Full text of every non-overlapping match, left to right.
            return [found.group() for found in re.finditer(regex, message)]

        every_url = _all_matches(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png|flac|wav|mp4)")
        images = _all_matches(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
        audios = _all_matches(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
        videos = _all_matches(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")

        return every_url, images, audios, videos

    def add_text(self, messages, message):
        if not self.OPENAI_KEY or not self.OPENAI_KEY.startswith("sk-") or not self.HUGGINGFACE_TOKEN or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
            return messages, "Please set your OpenAI API key and Hugging Face token first!"
        
        self.add_message(message, "user")
        messages.append((message, None))
        urls, image_urls, audio_urls, video_urls = self.extract_medias(message)

        for image_url in image_urls:
            image_url = "public/" + image_url if not image_url.startswith("http") else image_url
            image = load_image(image_url)
            name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
            image.save(name)
            messages.append(((f"{name}",), None))
        
        for audio_url in audio_urls:
            audio_url = "public/" + audio_url if not audio_url.startswith("http") else audio_url
            ext = audio_url.split(".")[-1]
            name = f"public/audios/{str(uuid.uuid4())[:4]}.{ext}"
            response = requests.get(audio_url)
            with open(name, "wb") as f:
                f.write(response.content)
            messages.append(((f"{name}",), None))
        
        for video_url in video_urls:
            video_url = "public/" + video_url if not video_url.startswith("http") else video_url
            ext = video_url.split(".")[-1]
            name = f"public/videos/{str(uuid.uuid4())[:4]}.{ext}"
            response = requests.get(video_url)
            with open(name, "wb") as f:
                f.write(response.content)
            messages.append(((f"{name}",), None))
        
        return messages, ""

    def bot(self, messages):
        if not self.OPENAI_KEY or not self.OPENAI_KEY.startswith("sk-") or not self.HUGGINGFACE_TOKEN or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
            return messages, {}
        
        message, results = chat_huggingface(self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN)
        _, image_urls, audio_urls, video_urls = self.extract_medias(message)
        self.add_message(message, "assistant")
        messages[-1][1] = message

        for image_url in image_urls:
            if not image_url.startswith("http"):
                image_url = "public/" + image_url.replace("public/", "")
            messages.append((None, f