Update app.py
app.py CHANGED

@@ -6,7 +6,7 @@ import sys
 
 # Install required packages
 subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
@@ -19,10 +19,10 @@ try:
     model = OlmoeForCausalLM.from_pretrained(
         model_name,
         trust_remote_code=True,
-        torch_dtype=torch.
+        torch_dtype=torch.bfloat16,  # Using float16 for lower precision
         low_cpu_mem_usage=True,
         device_map="auto",
-        _attn_implementation="flash_attention_2"  # Enable Flash Attention 2
+        #_attn_implementation="flash_attention_2"  # Enable Flash Attention 2
     ).to(DEVICE)
     model.gradient_checkpointing_enable()  # Enable gradient checkpointing
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -88,7 +88,7 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (
+    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (CPU experiment)")
     chatbot = gr.Chatbot(elem_id="output")
     msg = gr.Textbox(label="Meow")
     with gr.Row():
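
For context, a minimal sketch of the model-loading path this commit settles on: bfloat16 weights and the default attention implementation, which is what a CPU-only Space can actually run (flash-attn needs a CUDA toolchain, so both its install line and _attn_implementation="flash_attention_2" are commented out). The checkpoint name and device handling below are assumptions for illustration, not values taken from the diff; the real ones live elsewhere in app.py.

# Sketch only: assumes model_name points at an OLMoE checkpoint and that the
# Space runs on CPU, as the "(CPU experiment)" title suggests.
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMoE-1B-7B-0924"  # assumed checkpoint; not stated in the diff

model = OlmoeForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # half the memory of float32, keeps float32's exponent range
    low_cpu_mem_usage=True,      # stream weights in instead of materializing a full float32 copy
    device_map="auto",           # let accelerate place the weights on the available device
).to(DEVICE)                     # mirrors app.py; effectively a no-op when everything is already on CPU
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_name)

The design point: bfloat16 is typically the safer low-precision choice on CPU, since PyTorch's CPU kernels support it and it avoids float16's narrow exponent range, while FlashAttention 2 is a CUDA-only extension. Dropping the flash-attn install and the _attn_implementation flag, so the model falls back to the default attention kernels, is what lets the Space start on CPU hardware at all.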