iofu728 committed on
Commit 24083d5
1 Parent(s): 1c322b1

Feature(MInference): update the pycuda

app.py CHANGED
@@ -5,18 +5,13 @@ subprocess.run(
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
-subprocess.run(
-    "pip install pycuda==2023.1",
-    shell=True,
-)
 
 import gradio as gr
 import os
 import spaces
-from transformers import GemmaTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
-from minference import MInference
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -63,8 +58,6 @@ h1 {
 model_name = "gradientai/Llama-3-8B-Instruct-262k"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") # to("cuda:0")
-minference_patch = MInference("minference", model_name)
-model = minference_patch(model)
 
 terminators = [
     tokenizer.eos_token_id,
@@ -87,6 +80,15 @@ def chat_llama3_8b(message: str,
     Returns:
         str: The generated response.
     """
+    if "has_patch" not in model.__dict__:
+        from minference import MInference
+        global model
+        subprocess.run(
+            "pip install pycuda==2023.1",
+            shell=True,
+        )
+        minference_patch = MInference("minference", model_name)
+        model = minference_patch(model)
     conversation = []
     for user, assistant in history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
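Net effect of the app.py change: the pycuda install and the MInference patch no longer run at import time but are deferred to the first call of chat_llama3_8b, guarded by the has_patch flag that minference/patch.py now sets on the model instance. A minimal sketch of the same lazy-patch pattern, with the inline logic factored into a hypothetical ensure_patched helper (the helper name and the return-value shape are not part of the commit):

# Hypothetical helper; in the committed app.py this logic is inlined in
# chat_llama3_8b with a `global model` statement instead of a return value.
import subprocess

def ensure_patched(model, model_name):
    # minference_patch() sets `has_patch` on the instance (see patch.py below),
    # so the pip install and the re-patch are skipped on every later request.
    if "has_patch" in model.__dict__:
        return model
    subprocess.run("pip install pycuda==2023.1", shell=True)
    from minference import MInference  # imported lazily, once pycuda is present
    return MInference("minference", model_name)(model)

Returning the patched model sidesteps the global declaration, which Python only accepts before the first use of the name inside a function.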
minference/modules/minference_forward.py CHANGED
@@ -1,10 +1,16 @@
+# Copyright (c) 2024 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
 import inspect
 import json
 import os
 from importlib import import_module
 
 from transformers.models.llama.modeling_llama import *
-from vllm.attention.backends.flash_attn import *
+from transformers.utils.import_utils import _is_package_available
+
+if _is_package_available("vllm"):
+    from vllm.attention.backends.flash_attn import *
 
 from ..ops.block_sparse_flash_attention import block_sparse_attention
 from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
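The forward module now guards its vllm import, so it can be imported in environments where vllm is not installed (such as this Space). A self-contained sketch of the guard, assuming only that transformers is available; _is_package_available is a private transformers helper that returns True when the named package is importable, without raising ImportError:

from transformers.utils.import_utils import _is_package_available

HAS_VLLM = _is_package_available("vllm")  # plain bool; the flag name is illustrative

if HAS_VLLM:
    # Only reached when the optional dependency is actually installed.
    from vllm.attention.backends.flash_attn import *  # noqa: F401,F403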
minference/patch.py CHANGED
@@ -780,6 +780,7 @@ def minference_patch(model, config):
         model.model, model.model.__class__
     )
     model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)
+    model.has_patch = True
 
     print("Patched model for minference..")
     return model
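The new has_patch attribute is what the guard in app.py tests through model.__dict__. A short illustration of why the instance __dict__ is checked rather than hasattr(); FakeModel is a placeholder class, not part of the commit:

class FakeModel:
    has_patch = False  # a class-level default

m = FakeModel()
print(hasattr(m, "has_patch"))    # True  -> hasattr() would wrongly skip patching
print("has_patch" in m.__dict__)  # False -> the __dict__ check still patches

m.has_patch = True                # what minference_patch now does after patching
print("has_patch" in m.__dict__)  # True  -> later calls skip the patch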