vilarin committed
Commit 00adabe
1 Parent(s): 5c3a975

Update app.py

Files changed (1)
  1. app.py +70 -128
app.py CHANGED
@@ -1,59 +1,24 @@
 import os
-import signal
-import threading
 import time
-import subprocess
-import asyncio
-
-OLLAMA = os.path.expanduser("~/ollama")
-process = None
-OLLAMA_SERVICE_THREAD = None
-
-if not os.path.exists(OLLAMA):
-    subprocess.run("curl -L https://ollama.com/download/ollama-linux-amd64 -o ~/ollama", shell=True)
-    os.chmod(OLLAMA, 0o755)
-
-def ollama_service_thread():
-    global process
-    process = subprocess.Popen("~/ollama serve", shell=True, preexec_fn=os.setsid)
-    process.wait()
-
-def terminate():
-    global process, OLLAMA_SERVICE_THREAD
-    if process:
-        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
-    if OLLAMA_SERVICE_THREAD:
-        OLLAMA_SERVICE_THREAD.join()
-    process = None
-    OLLAMA_SERVICE_THREAD = None
-    print("Ollama service stopped.")
-
-# Uncomment and modify the model to what you want locally
-# model = "moondream"
-# model = os.environ.get("MODEL")
-
-# subprocess.run(f"~/ollama pull {model}", shell=True)
-
-import ollama
+import spaces
+import torch
+from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
-from ollama import AsyncClient
-client = AsyncClient(host='http://localhost:11434', timeout=120)
+from threading import Thread
 
+MODEL_LIST = ["allenai/OLMoE-1B-7B-0924-Instruct"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
+MODEL = os.environ.get("MODEL_ID")
 
-TITLE = "<h1><center>ollama-Chat</center></h1>"
+TITLE = "<h1><center>OLMoE</center></h1>"
 
-DESCRIPTION = f"""
+PLACEHOLDER = """
 <center>
-<p>Feel free to test models with ollama.
-<br>
-First run please type <em>/init</em> to launch process.
-<br>
-Type <em>/pull model_name</em> to pull model.
-</p>
+<p>Fully open, state-of-the-art Mixture of Expert model with 1.3 billion active and 6.9 billion total parameters.</p>
 </center>
 """
 
+
 CSS = """
 .duplicate-button {
     margin: auto !important;
@@ -65,86 +30,68 @@ h3 {
     text-align: center;
 }
 """
-INIT_SIGN = ""
-
-def init():
-    global OLLAMA_SERVICE_THREAD
-    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
-    OLLAMA_SERVICE_THREAD.start()
-    print("Giving ollama serve a moment")
-    time.sleep(10)
-    global INIT_SIGN
-    INIT_SIGN = "FINISHED"
-
-def ollama_func(command):
-    if " " in command:
-        c1, c2 = command.split(" ")
-    else:
-        c1 = command
-        c2 = ""
-    function_map = {
-        "/init": init,
-        "/pull": lambda: ollama.pull(c2),
-        "/list": ollama.list,
-        "/bye": terminate,
-    }
-    if c1 in function_map:
-        function_map.get(c1)()
-        return "Running..."
-    else:
-        return "No supported command."
-
-def launch():
-    global OLLAMA_SERVICE_THREAD
-    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
-    OLLAMA_SERVICE_THREAD.start()
-
-
-async def stream_chat(message: str, history: list, model: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
-    print(f"message: {message}")
+
+device = "cuda" # for GPU usage or "cpu" for CPU usage
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL)
+model = OlmoeForCausalLM.from_pretrained(
+    MODEL,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    ignore_mismatched_sizes=True)
+
+@spaces.GPU()
+def stream_chat(
+    message: str,
+    history: list,
+    temperature: float = 0.3,
+    max_new_tokens: int = 1024,
+    top_p: float = 1.0,
+    top_k: int = 20,
+    penalty: float = 1.2,
+):
+    print(f'message: {message}')
+    print(f'history: {history}')
+
     conversation = []
     for prompt, answer in history:
         conversation.extend([
             {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
         ])
+
     conversation.append({"role": "user", "content": message})
-
-    print(f"Conversation is -\n{conversation}")
+
+    input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
-    if message.startswith("/"):
-        resp = ollama_func(message)
-        yield resp
-    else:
-        if not INIT_SIGN:
-            yield "Please initialize Ollama"
-        else:
-            if not process:
-                launch()
-                print("Giving ollama serve a moment")
-                time.sleep(10)
-
-            buffer = ""
-            async for part in await client.chat(
-                model=model,
-                stream=True,
-                messages=conversation,
-                keep_alive="60s",
-                options={
-                    'num_predict': max_new_tokens,
-                    'temperature': temperature,
-                    'top_p': top_p,
-                    'top_k': top_k,
-                    'repeat_penalty': penalty,
-                    'low_vram': True,
-                },
-            ):
-                buffer += part['message']['content']
-                yield buffer
-
-chatbot = gr.Chatbot(height=600, placeholder=DESCRIPTION)
-
-with gr.Blocks(css=CSS, theme="soft") as demo:
+    generate_kwargs = dict(
+        input_ids=inputs,
+        max_new_tokens = max_new_tokens,
+        do_sample = False if temperature == 0 else True,
+        top_p = top_p,
+        top_k = top_k,
+        temperature = temperature,
+        streamer=streamer,
+        repetition_penalty=penalty,
+        pad_token_id = 1,
+        eos_token_id = 50279,
+    )
+
+    with torch.no_grad():
+        thread = Thread(target=model.generate, kwargs=generate_kwargs)
+        thread.start()
+
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        yield buffer
+
+
+chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
+with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
     gr.HTML(TITLE)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
     gr.ChatInterface(
@@ -153,32 +100,27 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Textbox(
-                value="qwen2:0.5b",
-                label="Model",
-                render=False,
-            ),
             gr.Slider(
                 minimum=0,
                 maximum=1,
                 step=0.1,
-                value=0.8,
+                value=0.3,
                 label="Temperature",
                 render=False,
             ),
             gr.Slider(
                 minimum=128,
-                maximum=2048,
+                maximum=8192,
                 step=1,
                 value=1024,
-                label="Max New Tokens",
+                label="Max new tokens",
                 render=False,
             ),
             gr.Slider(
                 minimum=0.0,
                 maximum=1.0,
                 step=0.1,
-                value=0.8,
+                value=1.0,
                 label="top_p",
                 render=False,
             ),
@@ -194,7 +136,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
                 minimum=0.0,
                 maximum=2.0,
                 step=0.1,
-                value=1.0,
+                value=1.2,
                 label="Repetition penalty",
                 render=False,
             ),
@@ -210,4 +152,4 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
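
For quick verification, here is a minimal sketch (not part of the commit) of driving the new stream_chat generator outside the Gradio UI. It assumes app.py is importable, the spaces package is installed, a CUDA device is available, and MODEL_ID points at a model such as the one listed in MODEL_LIST:

# Hypothetical smoke test; not part of app.py.
import os
os.environ.setdefault("MODEL_ID", "allenai/OLMoE-1B-7B-0924-Instruct")  # assumed model id

from app import stream_chat  # importing app loads the tokenizer and model

last = ""
for partial in stream_chat("Explain mixture-of-experts in one sentence.", history=[]):
    last = partial  # each yield is the full text generated so far
print(last)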