File size: 2,948 Bytes
b649a34
 
 
 
 
 
 
 
156cdff
b649a34
 
 
7c52de8
114aae8
b649a34
38da75e
b649a34
 
 
870b2f7
 
 
b649a34
 
 
 
 
 
 
 
 
 
 
 
 
589d8b5
b649a34
 
 
 
 
156cdff
c94c8e7
b649a34
 
 
 
 
 
 
eee0d11
 
c94c8e7
eee0d11
b649a34
eee0d11
 
 
639b64b
eee0d11
 
 
b649a34
 
 
6f32094
c94c8e7
6f32094
 
b649a34
 
 
 
 
 
 
 
 
 
 
 
 
6f32094
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json 
import gradio as gr
import os
import requests

# Endpoint location and credentials are taken from the environment so that
# no secret is stored in the source file.
hf_token = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')

# HTTP headers sent with every inference request.
headers = {'Content-Type': 'application/json'}
if hf_token:
    # os.getenv returns None when HF_TOKEN is unset; concatenating that to
    # 'Bearer ' would raise TypeError at import time, so the Authorization
    # header is only added when a token is actually configured.
    headers['Authorization'] = 'Bearer ' + hf_token

# Text placed inside the <<SYS>> section of every prompt built by predict().
system_message = "\nTesting by KelvinLo UD\n"

# Strings handed to gr.ChatInterface when the UI is constructed.
title = "Llama-2 Chatbot"
description = """
Demo by Kelvin Lo, UD
"""
# Hides elements with the `toast-wrap` class (presumably Gradio's toast
# pop-ups, which would include gr.Warning messages -- confirm if warnings
# are expected to be visible).
css = """.toast-wrap { display: none !important } """
# Sample prompts offered in the chat UI.
examples = [
    'Can you write a javascripts sample to print the time now?',
    '可以用中文字作詩比我?',
    "Write a 100-word article on 'Benefits of private AI Server'",
]


def _build_llama2_prompt(system_prompt, history, message):
    """Return the [INST]/<<SYS>> formatted prompt for *history* plus *message*."""
    past_turns = "".join(
        str(turn[0]) + " [/INST] " + str(turn[1]) + " </s><s> [INST] "
        for turn in history
    )
    return (
        f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
        + past_turns
        + str(message)
        + " [/INST] "
    )


def _unwrap_payload(payload):
    """Return the result object from a decoded response line, or None.

    A list is reduced to its first element; anything that does not end up
    being a dict (empty list, scalar, list of non-dicts) yields None.
    """
    if isinstance(payload, list):
        payload = payload[0] if payload else None
    return payload if isinstance(payload, dict) else None


def predict(message, chatbot):
    """Send the conversation to the inference endpoint and yield the reply.

    Args:
        message: The new user message; converted with ``str()``.
        chatbot: Iterable of previous turns, each indexable as
            ``turn[0]`` (user text) and ``turn[1]`` (model text).

    Yields:
        The accumulated generated text after each response line that carries
        a ``generated_text`` key, or an error string when a response line
        carries an ``error`` key.

    Lines that are not valid JSON, or whose JSON has neither key, are
    reported through ``gr.Warning`` and skipped.
    """
    input_prompt = _build_llama2_prompt(system_message, chatbot, message)

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": 1000,
            "do_sample": True,
            "top_p": 0.6,
            "temperature": 0.9,
        },
    }

    # The context manager releases the streamed connection even when the
    # consumer stops iterating this generator early.
    with requests.post(api_url, headers=headers, data=json.dumps(data), stream=True) as response:
        partial_message = ""
        for line in response.iter_lines():
            if not line:  # filter out keep-alive new lines
                continue

            json_line = line.decode('utf-8')

            try:
                payload = json.loads(json_line)
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue

            # Accept both a list of results and a bare object: indexing a
            # dict body such as {"error": ...} with [0] used to raise
            # KeyError, which made the error branch below unreachable.
            json_obj = _unwrap_payload(payload)
            if json_obj is None:
                gr.Warning(f"Unexpected response payload: {payload}")
                continue

            if 'generated_text' in json_obj:
                partial_message = partial_message + json_obj['generated_text']
                yield partial_message
            elif 'error' in json_obj:
                # str() keeps this working if the error value is not a plain string.
                yield str(json_obj['error']) + '. Please refresh and try again with an appropriate smaller input prompt.'
            else:
                gr.Warning(f"Neither 'generated_text' nor 'error' found in this JSON object: {json_obj}")

# Build the chat UI around predict(), enable the request queue with
# concurrency_count=75, and start serving. cache_examples=True asks Gradio
# to cache outputs for the example prompts (presumably by running predict()
# on them -- confirm the endpoint is reachable at startup).
(
    gr.ChatInterface(
        predict,
        title=title,
        description=description,
        css=css,
        examples=examples,
        cache_examples=True,
    )
    .queue(concurrency_count=75)
    .launch()
)