Azure99 committed on
Commit
1c29543
·
verified ·
1 Parent(s): eaaf5ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import spaces
4
+ from huggingface_hub import hf_hub_download
5
+ from llama_cpp import Llama
6
+ from transformers import AutoTokenizer
7
+
8
+ MAX_NEW_TOKENS = 8192
9
+ MODEL_NAME = "Azure99/Blossom-V6.1-32B"
10
+ MODEL_GGUF_REPO = f"{MODEL_NAME}-GGUF"
11
+ MODEL_FILE = "blossom-v6.1-32b-q8_0.gguf"
12
+ MODEL_LOCAL_DIR = "./"
13
+
14
+ hf_hub_download(repo_id=MODEL_GGUF_REPO, filename=MODEL_FILE, local_dir=MODEL_LOCAL_DIR)
15
+
16
+ llm: Llama = None
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
+
19
+
20
def get_messages(user, history):
    """Build the message list to send to the model.

    Power users can bypass the UI history by submitting a JSON object such as
    ``{"by_json_str": true, "messages": [...]}`` as the message text, in which
    case the embedded ``messages`` list is used verbatim.  Any other input is
    appended to the existing ``history`` as a new user turn.

    :param user: raw text typed into the chat box (possibly a JSON payload)
    :param history: prior messages from the ChatInterface, or None
    :return: list of ``{"role": ..., "content": ...}`` dicts
    """
    try:
        parsed_body = json.loads(user)
        if parsed_body.get("by_json_str"):
            return parsed_body["messages"]
    # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit still
    # propagate.  Exception covers both invalid JSON (JSONDecodeError) and
    # non-dict JSON (AttributeError on .get) — either way, fall through and
    # treat the input as a plain chat message.
    except Exception:
        pass

    messages = []
    messages.extend(history or [])
    messages.append({"role": "user", "content": user})
    return messages
32
+
33
+
34
@spaces.GPU(duration=120)
def chat(user, history, temperature, top_p, repetition_penalty):
    """Stream a model reply for the Gradio ChatInterface.

    Yields the accumulated reply text after each generated chunk so the UI
    updates incrementally.  Runs inside a ZeroGPU worker (120 s budget); the
    llama.cpp model is constructed lazily on first call so the GPU is only
    claimed when a request actually arrives.
    """
    global llm
    if llm is None:
        # n_gpu_layers=-1 offloads every layer to the GPU; 16K-token context.
        llm = Llama(
            model_path=MODEL_FILE, n_gpu_layers=-1, flash_attn=True, n_ctx=16384
        )

    messages = get_messages(user, history)
    print(f"Messages: {messages}")
    # Render the conversation into token ids via the model's chat template.
    # Fix: add_generation_prompt=True appends the assistant-turn header so the
    # model starts a fresh reply instead of continuing the user's turn (the
    # transformers default is False).
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    generate_config = dict(
        temperature=temperature,
        top_p=top_p,
        repeat_penalty=repetition_penalty,
        top_k=50,
        stream=True,
        max_tokens=MAX_NEW_TOKENS,
    )

    outputs = ""
    # llama-cpp-python accepts a list of token ids as the prompt; with
    # stream=True each chunk carries an incremental text delta.
    for chunk in llm(input_ids, **generate_config):
        outputs += chunk["choices"][0]["text"]
        yield outputs
58
+
59
+
60
def _param_slider(label, value, minimum, maximum, step, info):
    """One interactive slider for a sampling hyper-parameter."""
    return gr.Slider(
        label=label,
        value=value,
        minimum=minimum,
        maximum=maximum,
        step=step,
        interactive=True,
        info=info,
    )


# Sampling controls rendered inside the "Config" accordion of the chat UI.
additional_inputs = [
    _param_slider(
        "Temperature", 0.5, 0.0, 1.0, 0.05,
        "Controls randomness in choosing words.",
    ),
    _param_slider(
        "Top-P", 0.85, 0.0, 1.0, 0.05,
        "Picks words until their combined probability is at least top_p.",
    ),
    _param_slider(
        "Repetition penalty", 1.05, 1.0, 1.2, 0.01,
        "Repetition Penalty: Controls how much repetition is penalized.",
    ),
]
89
+
90
# --- Assemble and launch the demo UI ---------------------------------------
_chatbot = gr.Chatbot(
    show_label=False,
    height=500,
    show_copy_button=True,
    render_markdown=True,
    type="messages",
    latex_delimiters=[{"left": "\\[", "right": "\\]", "display": True}],
)
_textbox = gr.Textbox(placeholder="", container=False, scale=7)
_examples = [
    ["Hello"],
    ["What is MBTI"],
    ["用Python实现二分查找"],
    ["为switch写一篇小红书种草文案,带上emoji"],
]

_demo = gr.ChatInterface(
    chat,
    type="messages",
    chatbot=_chatbot,
    textbox=_textbox,
    title=f"{MODEL_NAME} Demo",
    description="Hello, I am Blossom, an open source conversational large language model.🌠"
    '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
    theme="soft",
    examples=_examples,
    cache_examples=False,
    additional_inputs=additional_inputs,
    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
)
_demo.queue().launch()