fl399 committed on
Commit
2c3ce4f
1 Parent(s): 95ff74d

Update app.py

Files changed (1)
  1. app.py +61 -60
app.py CHANGED
@@ -51,65 +51,66 @@ A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on R
 
 ## alpaca-lora
 
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
-
-tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-
-BASE_MODEL = "decapoda-research/llama-7b-hf"
-LORA_WEIGHTS = "tloen/alpaca-lora-7b"
-
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-
-try:
-    if torch.backends.mps.is_available():
-        device = "mps"
-except:
-    pass
-
-if device == "cuda":
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL,
-        load_in_8bit=False,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
-    model = PeftModel.from_pretrained(
-        model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
-    )
-elif device == "mps":
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        torch_dtype=torch.float16,
-    )
-    model = PeftModel.from_pretrained(
-        model,
-        LORA_WEIGHTS,
-        device_map={"": device},
-        torch_dtype=torch.float16,
-    )
-else:
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
-    )
-    model = PeftModel.from_pretrained(
-        model,
-        LORA_WEIGHTS,
-        device_map={"": device},
-    )
-
-
-if device != "cpu":
-    model.half()
-model.eval()
-if torch.__version__ >= "2":
-    model = torch.compile(model)
+# debugging...
+# assert (
+#     "LlamaTokenizer" in transformers._import_structure["models.llama"]
+# ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+# from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+
+# tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
+# BASE_MODEL = "decapoda-research/llama-7b-hf"
+# LORA_WEIGHTS = "tloen/alpaca-lora-7b"
+
+# if torch.cuda.is_available():
+#     device = "cuda"
+# else:
+#     device = "cpu"
+
+# try:
+#     if torch.backends.mps.is_available():
+#         device = "mps"
+# except:
+#     pass
+
+# if device == "cuda":
+#     model = LlamaForCausalLM.from_pretrained(
+#         BASE_MODEL,
+#         load_in_8bit=False,
+#         torch_dtype=torch.float16,
+#         device_map="auto",
+#     )
+#     model = PeftModel.from_pretrained(
+#         model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
+#     )
+# elif device == "mps":
+#     model = LlamaForCausalLM.from_pretrained(
+#         BASE_MODEL,
+#         device_map={"": device},
+#         torch_dtype=torch.float16,
+#     )
+#     model = PeftModel.from_pretrained(
+#         model,
+#         LORA_WEIGHTS,
+#         device_map={"": device},
+#         torch_dtype=torch.float16,
+#     )
+# else:
+#     model = LlamaForCausalLM.from_pretrained(
+#         BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
+#     )
+#     model = PeftModel.from_pretrained(
+#         model,
+#         LORA_WEIGHTS,
+#         device_map={"": device},
+#     )
+
+
+# if device != "cpu":
+#     model.half()
+# model.eval()
+# if torch.__version__ >= "2":
+#     model = torch.compile(model)
 
 
 ## FLAN-UL2
@@ -156,7 +157,7 @@ def evaluate(
     elif llm == "flan-ul2":
        output = query({
            "inputs": prompt
-        })[0]["generated_text"]
+        })
 
     else:
        RuntimeError(f"No such LLM: {llm}")
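For reference, the block the first hunk comments out follows the usual alpaca-lora loading recipe: pick a device (CUDA, then Apple MPS, then CPU), load the LLaMA base weights, then wrap them with the LoRA adapter via peft's PeftModel. A condensed, self-contained sketch of that recipe follows; the torch and peft imports are assumptions (the hunk never shows them), the model IDs are the ones in the diff, device_map requires the accelerate package, and the decapoda-research checkpoint may no longer be hosted on the Hub.

    import torch
    from peft import PeftModel  # assumed import; not visible in the hunk
    from transformers import LlamaForCausalLM, LlamaTokenizer

    BASE_MODEL = "decapoda-research/llama-7b-hf"  # model IDs taken from the diff
    LORA_WEIGHTS = "tloen/alpaca-lora-7b"

    # Prefer CUDA, then MPS, then CPU. The original wraps the MPS probe in
    # try/except because torch.backends.mps does not exist on older builds.
    if torch.cuda.is_available():
        device = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

    # fp16 on accelerators, fp32 on CPU; load the base model onto the chosen
    # device, then attach the LoRA adapter weights on top of it.
    dtype = torch.float16 if device != "cpu" else torch.float32
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map={"": device},
    )
    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=dtype)
    model.eval()

One caveat carried over from the disabled block: if torch.__version__ >= "2" compares version strings lexicographically, which happens to work for current releases but is not a robust version check.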
 
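The second hunk makes the flan-ul2 branch return the raw query(...) response instead of unpacking it. The query helper is not shown in this diff; for a hosted model it typically wraps the Hugging Face Inference API along the lines of the sketch below, in which the URL, model ID, and token handling are assumptions. Text-generation endpoints usually answer with a list of {"generated_text": ...} dicts, which is exactly what the removed [0]["generated_text"] indexing extracted, so after this change the caller receives the full JSON payload. Note also that the unchanged else branch constructs RuntimeError(f"No such LLM: {llm}") without raising it; a raise statement is presumably intended.

    import os
    import requests

    # Hypothetical endpoint and token handling; the diff never shows query's body.
    API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
    HEADERS = {"Authorization": f"Bearer {os.environ['HF_API_TOKEN']}"}

    def query(payload: dict):
        """POST the payload to the hosted model and return the decoded JSON."""
        response = requests.post(API_URL, headers=HEADERS, json=payload)
        response.raise_for_status()
        return response.json()

    # Typical text-generation response shape: [{"generated_text": "..."}]
    result = query({"inputs": "Q: What is 24 * 7? A:"})
    output = result[0]["generated_text"]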