nisten commited on
Commit
ee12bf1
1 Parent(s): a622fef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
 
7
  # Install required packages
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
  from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
12
  from threading import Thread
@@ -19,10 +19,10 @@ try:
19
  model = OlmoeForCausalLM.from_pretrained(
20
  model_name,
21
  trust_remote_code=True,
22
- torch_dtype=torch.float16, # Using float16 for lower precision
23
  low_cpu_mem_usage=True,
24
  device_map="auto",
25
- _attn_implementation="flash_attention_2" # Enable Flash Attention 2
26
  ).to(DEVICE)
27
  model.gradient_checkpointing_enable() # Enable gradient checkpointing
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -88,7 +88,7 @@ css = """
88
  """
89
 
90
  with gr.Blocks(css=css) as demo:
91
- gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2!)")
92
  chatbot = gr.Chatbot(elem_id="output")
93
  msg = gr.Textbox(label="Meow")
94
  with gr.Row():
 
6
 
7
  # Install required packages
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
+ #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
  from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
12
  from threading import Thread
 
19
  model = OlmoeForCausalLM.from_pretrained(
20
  model_name,
21
  trust_remote_code=True,
22
+ torch_dtype=torch.bfloat16, # Using float16 for lower precision
23
  low_cpu_mem_usage=True,
24
  device_map="auto",
25
+ #_attn_implementation="flash_attention_2" # Enable Flash Attention 2
26
  ).to(DEVICE)
27
  model.gradient_checkpointing_enable() # Enable gradient checkpointing
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
88
  """
89
 
90
  with gr.Blocks(css=css) as demo:
91
+ gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (CPU experiment)")
92
  chatbot = gr.Chatbot(elem_id="output")
93
  msg = gr.Textbox(label="Meow")
94
  with gr.Row():