Spaces:
Runtime error
Runtime error
Update qa_model.py
Browse files- qa_model.py +45 -0
qa_model.py
CHANGED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from transformers import AutoModel, AutoConfig
|
5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
6 |
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
7 |
+
|
8 |
+
class QAModel():
    """Question-answering wrapper around a seq2seq HF checkpoint.

    The checkpoint is downloaded once, re-saved as small shards on disk,
    and re-loaded through accelerate's empty-weights + dispatch flow so
    large models fit in limited host memory.

    NOTE(review): relies on the module-level `transformers` / `accelerate`
    imports at the top of this file.
    """

    def __init__(self, checkpoint="google/flan-t5-xl"):
        """Remember the checkpoint id and derive the local shard directory.

        Args:
            checkpoint: HF hub id (e.g. "google/flan-t5-xl"); the part
                after the last '/' names the shard dir "<name>-sharded".
        """
        self.checkpoint = checkpoint
        self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"

    def store_sharded_model(self):
        """Download the full model once and re-save it as small shards.

        No-op when the shard directory already exists.
        """
        tmpdir = self.tmpdir
        checkpoint = self.checkpoint

        if not os.path.exists(tmpdir):
            # makedirs(exist_ok=True) handles nested paths and is safe
            # against the exists/mkdir race; plain os.mkdir is neither.
            os.makedirs(tmpdir, exist_ok=True)
            print(f"Directory created - {tmpdir}")
            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
            print(f"Model loaded - {checkpoint}")
            # Small shards keep peak host memory low when re-loading.
            model.save_pretrained(tmpdir, max_shard_size="200MB")

    def load_sharded_model(self):
        """Return (model, tokenizer), materializing shards on demand.

        A weight-less model skeleton is built from the config, then the
        on-disk shards are dispatched across available devices
        (device_map="auto") without ever holding the full model in RAM.
        """
        tmpdir = self.tmpdir
        if not os.path.exists(tmpdir):
            self.store_sharded_model()

        checkpoint = self.checkpoint
        config = AutoConfig.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        with init_empty_weights():
            model = AutoModelForSeq2SeqLM.from_config(config)

        model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
        return model, tokenizer

    def query_model(self, model, tokenizer, query, **generate_kwargs):
        """Generate an answer string for `query`.

        Args:
            model: the (dispatched) seq2seq model returned by
                load_sharded_model.
            tokenizer: the matching tokenizer.
            query: prompt text to answer.
            **generate_kwargs: forwarded to `model.generate`
                (e.g. max_new_tokens); omitting them keeps the
                previous default generation behavior.

        Returns:
            The decoded first generated sequence, without special tokens.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        inputs = tokenizer(query, return_tensors='pt').to(device)
        output_ids = model.generate(**inputs, **generate_kwargs)
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|