File size: 3,657 Bytes
0408f67
726ba73
 
 
fb1f1d4
0408f67
17025ed
 
0408f67
17025ed
0408f67
 
 
 
17025ed
485bfef
17025ed
6593855
 
 
 
17025ed
0408f67
 
 
 
 
 
 
 
 
17025ed
0408f67
17025ed
 
0408f67
 
 
17025ed
0408f67
 
 
 
 
 
17025ed
 
 
 
 
 
0408f67
17025ed
0408f67
 
 
 
 
17025ed
 
 
 
 
0408f67
 
 
485bfef
fb1f1d4
0408f67
 
 
 
 
 
fb1f1d4
0408f67
 
17025ed
fb1f1d4
17025ed
0408f67
17025ed
 
0408f67
 
 
 
17025ed
 
fb1f1d4
0408f67
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import subprocess
import sys

# Bootstrap third-party dependencies at startup.
# NOTE(review): installing packages at import time is fragile; a
# requirements.txt / environment spec is the usual fix. Kept because the
# hosting environment for this script appears to rely on it.
#
# Fix vs. the original `os.system("pip install ...")`: invoke the pip that
# belongs to the *running* interpreter (sys.executable -m pip) and pass an
# argument list instead of a shell string, avoiding shell parsing entirely.
# check=False preserves the original best-effort behavior (os.system's
# return code was ignored).
subprocess.run([sys.executable, '-m', 'pip', 'install', '--upgrade', 'pip'], check=False)
subprocess.run([sys.executable, '-m', 'pip', 'install', 'gradio', 'transformers', 'torch'], check=False)

import gradio as gr
from typing import List, Optional, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch

# Default system prompt, used when the user has not set one in the UI.
default_system = 'You are JOSIE which is an acronym for "Just an Outstandingly Smart Intelligent Entity", a private and super-intelligent AI assistant, created by Gökdeniz Gülmez.'

# Type aliases for the chat plumbing:
#   History  — list of (user_text, assistant_text) turn pairs (Gradio chatbot format)
#   Messages — OpenAI-style list of {'role': ..., 'content': ...} dicts
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

# Load model and tokenizer
# NOTE(review): this repo id is an MLX community 8-bit build; loading it via
# transformers' AutoModelForCausalLM may not be the intended runtime — confirm.
model_name = 'mlx-community/J.O.S.I.E.3-Beta12-7B-slerp-8-bit'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ignore_mismatched_sizes=True silently re-initializes any weights whose
# checkpoint shape disagrees with the architecture — TODO confirm intentional.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
)

def clear_session() -> Tuple[str, History]:
    """Reset the chat UI.

    Returns:
        A 2-tuple ``('', [])``: an empty string for the input textbox and an
        empty history for the chatbot, matching the clear button's
        ``outputs=[textbox, chatbot]`` wiring.
    """
    # Fix: the original annotated the return as ``History`` although two
    # values are returned (textbox value and history).
    return '', []

def modify_system_session(system: str) -> Tuple[str, str, History]:
    """Apply a new system prompt and clear the chat history.

    Falls back to ``default_system`` when *system* is ``None`` or empty.

    Args:
        system: The prompt text entered by the user (may be ``None``/empty).

    Returns:
        The prompt twice — once per ``system_state`` output slot in the UI
        wiring — plus an empty history to clear the chatbot.
    """
    # Fix: the original annotated the return as ``str`` although a 3-tuple
    # is returned. ``not system`` covers both None and the empty string.
    if not system:
        system = default_system
    return system, system, []

def history_to_messages(history: History, system: str) -> Messages:
    """Render Gradio chat history as an OpenAI-style message list.

    The system prompt becomes the first message; each (user, assistant)
    pair in *history* is expanded into two consecutive role messages.
    """
    rendered = [{'role': 'system', 'content': system}]
    for user_turn, assistant_turn in history:
        rendered.append({'role': 'user', 'content': user_turn})
        rendered.append({'role': 'assistant', 'content': assistant_turn})
    return rendered

def messages_to_history(messages: Messages) -> Tuple[str, History]:
    """Split an OpenAI-style message list back into (system, history).

    The first entry must be the system message; the remaining entries are
    consumed as alternating user/assistant pairs.
    """
    assert messages[0]['role'] == 'system'
    system_prompt = messages[0]['content']
    turns = messages[1:]
    # Pair each user message (even offsets in `turns`) with the assistant
    # message that follows it; Gradio expects [user, assistant] lists.
    history = [
        [user_msg['content'], assistant_msg['content']]
        for user_msg, assistant_msg in zip(turns[::2], turns[1::2])
    ]
    return system_prompt, history

def generate_response(messages: Messages) -> str:
    """Generate the assistant's reply for a conversation.

    Renders *messages* into a flat ``"role: content"`` prompt (one line per
    message), runs the module-level model, and returns only the text after
    the last ``'assistant:'`` marker in the decoded output.

    Args:
        messages: Full conversation including the system message.

    Returns:
        The newest assistant turn as plain text.
    """
    prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)
    inputs = tokenizer(prompt, return_tensors='pt')
    # Fixes vs. the original:
    #  - pass the attention mask along with the ids (transformers warns and
    #    may misbehave without it);
    #  - bound *generated* tokens with max_new_tokens — the original
    #    max_length=512 counted the prompt too, so long histories left no
    #    room for a reply;
    #  - no_grad: inference only, avoids building an autograd graph.
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=512,
            pad_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text contains the whole prompt; keep only the part after
    # the final 'assistant:' marker.
    return response.split('assistant:')[-1].strip()

def model_chat(query: Optional[str], history: Optional[History], system: str) -> Tuple[str, History, str]:
    """Run one chat turn for the Gradio submit handler.

    Builds the full message list from the stored history plus the new user
    query, generates a reply, and returns the values the UI wiring expects:
    a cleared textbox, the updated history, and the system prompt.

    Args:
        query: The user's input text (``None`` is treated as empty).
        history: Prior (user, assistant) pairs (``None`` is treated as empty).
        system: Current system prompt.

    Returns:
        ``('', updated_history, system)`` — matching
        ``outputs=[textbox, chatbot, system_state]``.
    """
    # Fix: the original return annotation was Tuple[str, str, History],
    # but the values are returned in (str, History, str) order.
    if query is None:
        query = ''
    if history is None:
        history = []
    messages = history_to_messages(history, system)
    messages.append({'role': 'user', 'content': query})
    response = generate_response(messages)
    messages.append({'role': 'assistant', 'content': response})
    system, history = messages_to_history(messages)
    return '', history, system

# Build the Gradio UI: header, hidden system-prompt state, chat area, and
# the three event handlers (send, clear, set-system-prompt).
with gr.Blocks() as demo:
    # Header banners.
    gr.Markdown("""<center><font size=8>J.O.S.I.E.3-Beta12 Preview👾</center>""")
    gr.Markdown("""<center><font size=4>J.O.S.I.E. is also multilingual (German, Spanish, Chinese, Japanese, French)</center>""")

    with gr.Row():
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history", scale=2)
        # Hidden textbox that carries the current system prompt between turns.
        system_state = gr.Textbox(value=default_system, visible=False)
    chatbot = gr.Chatbot(label='J.O.S.I.E.3-Beta12-7B')
    textbox = gr.Textbox(lines=2, label='Input')

    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")

    # Send: model_chat consumes (query, history, system) and returns three
    # values that are written back into these same three components.
    # NOTE(review): model_chat is annotated Tuple[str, str, History] but
    # returns ('', history, system) — the wiring below relies on the actual
    # (textbox, chatbot, system_state) order; confirm the annotation.
    submit.click(model_chat,
                 inputs=[textbox, chatbot, system_state],
                 outputs=[textbox, chatbot, system_state],
                 concurrency_limit=5)
    # Clear: clear_session returns ('', []) to empty the textbox and chat.
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot])
    # Set system prompt: returns the prompt twice (system_state is listed as
    # two output slots) plus an empty history to clear the chatbot.
    modify_system.click(fn=modify_system_session,
                        inputs=[system_state],
                        outputs=[system_state, system_state, chatbot])

# Enable request queueing (public API endpoints disabled) and start the app.
demo.queue(api_open=False)
demo.launch(max_threads=5)