Tuchuanhuhuhu committed
Commit 53518f7 · 1 Parent(s): 67d913f

Optimize the LLaMA model

Files changed (2):
  1. modules/models.py +17 -9
  2. requirements_advanced.txt +1 -1
modules/models.py CHANGED
@@ -340,9 +340,14 @@ class LLaMA_Client(BaseLLMModel):
         self.end_string = "\n\n"
 
     def _get_llama_style_input(self):
-        history = [x["content"] for x in self.history]
-        context = "\n".join(history)
-        context += "\nOutput:"
+        history = []
+        for x in self.history:
+            if x["role"] == "user":
+                history.append(f"Input: {x['content']}")
+            else:
+                history.append(f"Output: {x['content']}")
+        context = "\n\n".join(history)
+        context += "\n\nOutput: "
         return context
 
     def get_answer_at_once(self):
@@ -365,14 +370,15 @@ class LLaMA_Client(BaseLLMModel):
     def get_answer_stream_iter(self):
         context = self._get_llama_style_input()
         partial_text = ""
-        for i in range(self.max_generation_token):
+        step = 1
+        for _ in range(0, self.max_generation_token, step):
             input_dataset = self.dataset.from_dict(
                 {"type": "text_only", "instances": [{"text": context+partial_text}]}
             )
             output_dataset = self.inferencer.inference(
                 model=self.model,
                 dataset=input_dataset,
-                max_new_tokens=1,
+                max_new_tokens=step,
                 temperature=self.temperature,
             )
             response = output_dataset.to_dict()["instances"][0]["text"]
@@ -402,9 +408,11 @@ class ModelManager:
         dont_change_lora_selector = False
         if model_type != ModelType.OpenAI:
             config.local_embedding = True
+        self.model = None
         model = None
         try:
             if model_type == ModelType.OpenAI:
+                logging.info(f"正在加载OpenAI模型: {model_name}")
                 model = OpenAIClient(
                     model_name=model_name,
                     api_key=access_key,
@@ -413,15 +421,17 @@
                     top_p=top_p,
                 )
             elif model_type == ModelType.ChatGLM:
+                logging.info(f"正在加载ChatGLM模型: {model_name}")
                 model = ChatGLM_Client(model_name)
             elif model_type == ModelType.LLaMA and lora_model_path == "":
-                msg = "现在请选择LoRA模型"
+                msg = f"现在请为 {model_name} 选择LoRA模型"
                 logging.info(msg)
                 lora_selector_visibility = True
                 if os.path.isdir("lora"):
                     lora_choices = get_file_names("lora", plain=True, filetypes=[""])
                 lora_choices = ["No LoRA"] + lora_choices
             elif model_type == ModelType.LLaMA and lora_model_path != "":
+                logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
                 dont_change_lora_selector = True
                 if lora_model_path == "No LoRA":
                     lora_model_path = None
@@ -429,15 +439,13 @@
                 else:
                     msg += f" + {lora_model_path}"
                 model = LLaMA_Client(model_name, lora_model_path)
-                pass
             elif model_type == ModelType.Unknown:
                 raise ValueError(f"未知模型: {model_name}")
             logging.info(msg)
         except Exception as e:
             logging.error(e)
             msg = f"{STANDARD_ERROR_MSG}: {e}"
-        if model is not None:
-            self.model = model
+        self.model = model
         if dont_change_lora_selector:
             return msg
         else:
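
For reference, the rewritten _get_llama_style_input formats the chat history as an Input/Output transcript instead of a bare concatenation of message contents. A minimal sketch of the new prompt shape (the standalone function and the sample history below are illustrative, not code from this repository):

# Sketch of the patched prompt builder: user turns become "Input: ...",
# assistant turns become "Output: ...", and the trailing "Output: " cues
# the model to continue as the assistant.
def llama_style_input(history):
    lines = []
    for x in history:  # items look like {"role": ..., "content": ...}
        if x["role"] == "user":
            lines.append(f"Input: {x['content']}")
        else:
            lines.append(f"Output: {x['content']}")
    return "\n\n".join(lines) + "\n\nOutput: "

print(llama_style_input([
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And of Italy?"},
]))
# Input: What is the capital of France?
#
# Output: Paris.
#
# Input: And of Italy?
#
# Output: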
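
The streaming change in get_answer_stream_iter is small but makes the generation stride explicit: the hard-coded max_new_tokens=1 becomes max_new_tokens=step with step = 1, so the loop bound and the per-call token budget share one name. A toy sketch of that feed-back-the-prefix streaming pattern, with a stand-in generate function replacing the LMFlow inferencer (which this diff shows only in part):

def stream_answer(prompt, generate, max_generation_token=16, step=1):
    """Yield the reply as it grows, `step` tokens per inference call."""
    partial_text = ""
    # Mirrors the patched loop: each pass feeds prompt + text-so-far back
    # in and asks for `step` new tokens, so callers can render partials.
    for _ in range(0, max_generation_token, step):
        new_text = generate(prompt + partial_text, max_new_tokens=step)
        if not new_text:
            break  # assumption: an empty completion means generation is done
        partial_text += new_text
        yield partial_text

# Toy generator standing in for inferencer.inference(): it "decodes"
# one character of a canned answer per call.
canned = iter("42")
def fake_generate(text, max_new_tokens):
    return next(canned, "")

for partial in stream_answer("Input: 6*7=?\n\nOutput: ", fake_generate):
    print(repr(partial))
# '4'
# '42'

Note that re-running the full prompt for every one-token step is quadratic in output length; naming the stride is presumably a first step toward raising it.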
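
The ModelManager edits also change failure behavior, not just logging: self.model is cleared before loading and assigned unconditionally after the try/except, so a failed load leaves the manager with no model instead of silently keeping the previous one. A reduced sketch of that pattern (class and names simplified; not the repository's API):

class Manager:
    def __init__(self):
        self.model = None

    def load(self, build):
        # Clear first: if build() raises, we end up holding None rather
        # than answering with the previously loaded model.
        self.model = None
        model = None
        try:
            model = build()
        except Exception as e:
            print(f"load failed: {e}")
        self.model = model  # the new client on success, None on failure

mgr = Manager()
mgr.load(lambda: "stub-client")
assert mgr.model == "stub-client"

def broken():
    raise RuntimeError("no weights")

mgr.load(broken)
assert mgr.model is None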
requirements_advanced.txt CHANGED
@@ -2,6 +2,6 @@ transformers
 torch
 icetk
 protobuf==3.19.0
-git+https://github.com/OptimalScale/LMFlow.git#egg=lmflow
+git+https://github.com/OptimalScale/LMFlow.git
 cpm-kernels
 sentence_transformers