iofu728 committed on
Commit
ad9d4f6
1 Parent(s): 00e9e5c

Feature(MInference): add minference

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -81,20 +81,21 @@ def chat_llama3_8b(message: str,
81
  str: The generated response.
82
  """
83
  global model
84
- subprocess.run(
85
- "pip install pycuda==2023.1",
86
- shell=True,
87
- )
88
- if "has_patch" not in model.__dict__:
89
- from minference import MInference
90
- minference_patch = MInference("minference", model_name)
91
- model = minference_patch(model)
92
  conversation = []
93
  for user, assistant in history:
94
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
95
  conversation.append({"role": "user", "content": message})
96
 
97
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
98
 
99
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
100
 
 
81
  str: The generated response.
82
  """
83
  global model
 
 
 
 
 
 
 
 
84
  conversation = []
85
  for user, assistant in history:
86
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
87
  conversation.append({"role": "user", "content": message})
88
 
89
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
90
+ print(model.device)
91
+ # subprocess.run(
92
+ # "pip install pycuda==2023.1",
93
+ # shell=True,
94
+ # )
95
+ # if "has_patch" not in model.__dict__:
96
+ # from minference import MInference
97
+ # minference_patch = MInference("minference", model_name)
98
+ # model = minference_patch(model)
99
 
100
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
101