masanorihirano committed
Commit 7a42c18
1 Parent(s): bed8c52
Files changed (2)
  1. app.py +7 -29
  2. pyproject.toml +1 -1
app.py CHANGED
@@ -9,16 +9,12 @@ from typing import Union
 import gradio as gr
 import requests
 import torch
+import transformers
 from fastchat.conversation import Conversation
-from fastchat.conversation import SeparatorStyle
-from fastchat.conversation import get_conv_template
-from fastchat.conversation import register_conv_template
-from fastchat.model.model_adapter import BaseAdapter
-from fastchat.model.model_adapter import load_model
-from fastchat.model.model_adapter import model_adapters
+from fastchat.conversation import get_default_conv_template
 from fastchat.serve.cli import SimpleChatIO
-from fastchat.serve.inference import compress_module
 from fastchat.serve.inference import generate_stream
+from fastchat.serve.inference import load_model
 from huggingface_hub import Repository
 from huggingface_hub import snapshot_download
 from peft import LoraConfig
@@ -30,24 +26,8 @@ from transformers import LlamaTokenizer
 from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizerBase
 
-
-class LLaMAdapter(BaseAdapter):
-    "Model adapater for vicuna-v1.1"
-
-    def match(self, model_path: str):
-        return "llama" in model_path
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
-        model = LlamaForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **from_pretrained_kwargs,
-        )
-        return model, tokenizer
-
-
-model_adapters.insert(-1, LLaMAdapter())
+transformers.AutoTokenizer.from_pretrained = LlamaTokenizer.from_pretrained
+transformers.AutoModelForCausalLM.from_pretrained = LlamaForCausalLM.from_pretrained
 
 
 def load_lora_model(
@@ -67,12 +47,10 @@ def load_lora_model(
         device=device,
         num_gpus=num_gpus,
         max_gpu_memory=max_gpu_memory,
-        load_8bit=False,
+        load_8bit=True,
         cpu_offloading=cpu_offloading,
         debug=debug,
     )
-    if load_8bit:
-        compress_module(model)
     if lora_weight is not None:
         # model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
         config = LoraConfig.from_pretrained(lora_weight)
@@ -217,7 +195,7 @@ def evaluate(
         gr.update(interactive=True),
     )
 
-    conv = get_conv_template()
+    conv = get_default_conv_template(BASE_MODEL).copy()
 
     conv.append_message(conv.roles[0], instruction)
     conv.append_message(conv.roles[1], None)
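Review note on the loading path: fschat 0.2.3 predates the model-adapter registry, so the LLaMAdapter subclass is dropped and the transformers Auto classes are monkey-patched instead, steering load_model's internal AutoTokenizer/AutoModelForCausalLM lookups to the Llama implementations; passing load_8bit=True also lets load_model apply the 8-bit compression that the explicit compress_module call used to do. A minimal sketch of the resulting call, assuming fschat==0.2.3 and a hypothetical checkpoint path:

import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
from fastchat.serve.inference import load_model

# Route fastchat's Auto-class lookups to the concrete Llama classes.
transformers.AutoTokenizer.from_pretrained = LlamaTokenizer.from_pretrained
transformers.AutoModelForCausalLM.from_pretrained = LlamaForCausalLM.from_pretrained

model, tokenizer = load_model(
    "path/to/llama-checkpoint",  # hypothetical path, not the Space's real one
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
    load_8bit=True,  # fastchat compresses the weights itself here
    cpu_offloading=False,
    debug=False,
)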
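The conversation change tracks the same downgrade: 0.2.8's get_conv_template is replaced by 0.2.3's get_default_conv_template, which picks a template from the model name and returns a shared Conversation object, hence the .copy() before appending turns. A minimal sketch of the prompt flow, with a hypothetical model name standing in for BASE_MODEL:

from fastchat.conversation import get_default_conv_template

conv = get_default_conv_template("llama-7b").copy()  # hypothetical name
conv.append_message(conv.roles[0], "Summarize LoRA in one sentence.")
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
prompt = conv.get_prompt()  # string handed on to generate_stream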
pyproject.toml CHANGED
@@ -15,7 +15,7 @@ huggingface-hub = "^0.14.1"
 sentencepiece = "^0.1.99"
 bitsandbytes = "^0.38.1"
 accelerate = "^0.19.0"
-fschat = "0.2.8"
+fschat = "0.2.3"
 transformers = "4.28.1"
 
 
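On the pyproject side, the exact pin is what keeps the new imports valid: get_conv_template and fastchat.model.model_adapter exist only in the newer fschat, while get_default_conv_template and fastchat.serve.inference.load_model are the 0.2.3 locations. A purely illustrative guard that would fail fast on a mismatched install:

try:
    # 0.2.3 API surface; later releases rename or move both of these.
    from fastchat.conversation import get_default_conv_template
    from fastchat.serve.inference import load_model
except ImportError as exc:
    raise RuntimeError("This app requires fschat==0.2.3") from exc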