fl399 committed
Commit 184f5ef
1 Parent(s): 58d8b0b

Update app.py

Files changed (1):
  app.py  +60 −59
app.py CHANGED
@@ -52,68 +52,69 @@ A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on R
 ## alpaca-lora
 
 # debugging...
-# assert (
-#     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-# ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-# from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
-
-# tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-
-# BASE_MODEL = "decapoda-research/llama-7b-hf"
-# LORA_WEIGHTS = "tloen/alpaca-lora-7b"
-
-# if torch.cuda.is_available():
-#     device = "cuda"
-# else:
-#     device = "cpu"
-
-# try:
-#     if torch.backends.mps.is_available():
-#         device = "mps"
-# except:
-#     pass
-
-# if device == "cuda":
-#     model = LlamaForCausalLM.from_pretrained(
-#         BASE_MODEL,
-#         load_in_8bit=False,
-#         torch_dtype=torch.float16,
-#         device_map="auto",
-#     )
-#     model = PeftModel.from_pretrained(
-#         model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
-#     )
-# elif device == "mps":
-#     model = LlamaForCausalLM.from_pretrained(
-#         BASE_MODEL,
-#         device_map={"": device},
-#         torch_dtype=torch.float16,
-#     )
-#     model = PeftModel.from_pretrained(
-#         model,
-#         LORA_WEIGHTS,
-#         device_map={"": device},
-#         torch_dtype=torch.float16,
-#     )
-# else:
-#     model = LlamaForCausalLM.from_pretrained(
-#         BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
-#     )
-#     model = PeftModel.from_pretrained(
-#         model,
-#         LORA_WEIGHTS,
-#         device_map={"": device},
-#     )
-
-
-# if device != "cpu":
-#     model.half()
-# model.eval()
-# if torch.__version__ >= "2":
-#     model = torch.compile(model)
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
+BASE_MODEL = "decapoda-research/llama-7b-hf"
+LORA_WEIGHTS = "tloen/alpaca-lora-7b"
+
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except:
+    pass
+
+if device == "cuda":
+    model = LlamaForCausalLM.from_pretrained(
+        BASE_MODEL,
+        load_in_8bit=False,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
+    model = PeftModel.from_pretrained(
+        model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
+    )
+elif device == "mps":
+    model = LlamaForCausalLM.from_pretrained(
+        BASE_MODEL,
+        device_map={"": device},
+        torch_dtype=torch.float16,
+    )
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        device_map={"": device},
+        torch_dtype=torch.float16,
+    )
+else:
+    model = LlamaForCausalLM.from_pretrained(
+        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
+    )
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        device_map={"": device},
+    )
+
+
+if device != "cpu":
+    model.half()
+model.eval()
+if torch.__version__ >= "2":
+    model = torch.compile(model)
 
 
 ## FLAN-UL2
+# in dev...
 TOKEN = os.environ.get("API_TOKEN", None)
 API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
 headers = {"Authorization": f"Bearer {TOKEN}"}
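
Note: the block this commit uncomments leaves `model` and `tokenizer` ready for local inference on the selected `device`. A minimal sketch of how a prompt could be run through them; the helper name and the generation parameters are illustrative assumptions, not part of this commit:

import torch
from transformers import GenerationConfig

def generate_response(prompt: str, max_new_tokens: int = 128) -> str:
    # Hypothetical helper: tokenize the prompt and move it to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Illustrative decoding settings; the commit itself fixes none.
    generation_config = GenerationConfig(temperature=0.1, top_p=0.75, num_beams=4)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            generation_config=generation_config,
            max_new_tokens=max_new_tokens,
        )
    # The decoded sequence includes the prompt; callers would strip it as needed.
    return tokenizer.decode(output[0], skip_special_tokens=True)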
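
The FLAN-UL2 path, flagged "# in dev...", goes through the hosted Inference API rather than a local model. Given the `API_URL` and `headers` defined above, a query is typically a POST with an `inputs` payload; the helper below is a sketch under that assumption, not code from this repository:

import requests

def query_flan_ul2(prompt: str) -> str:
    # POST the prompt to the hosted endpoint; auth comes from the Bearer header.
    response = requests.post(
        API_URL,
        headers=headers,
        json={"inputs": prompt, "parameters": {"max_new_tokens": 128}},
    )
    response.raise_for_status()
    # text2text-generation responses arrive as a list of dicts.
    return response.json()[0]["generated_text"]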