Spaces:
Runtime error
Runtime error
modified files to test remote POST
Browse files
- app.py +3 -2
- app_test.py +2 -6
- utils.py +1 -48
app.py
CHANGED
@@ -12,7 +12,7 @@ os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
|
12 |
|
13 |
app = Flask(__name__)
|
14 |
|
15 |
-
path_for_model = "./output/checkpoint-
|
16 |
|
17 |
args = {
|
18 |
"model_type": "gpt2",
|
@@ -20,6 +20,7 @@ args = {
|
|
20 |
"length": 80,
|
21 |
"stop_token": None,
|
22 |
"temperature": 1.0,
|
|
|
23 |
"repetition_penalty": 1.2,
|
24 |
"k": 3,
|
25 |
"p": 0.9,
|
@@ -94,4 +95,4 @@ def chat():
|
|
94 |
|
95 |
if __name__ == '__main__':
|
96 |
load_model_and_components()
|
97 |
-
app.run(host='0.0.0.0', port=
|
|
|
12 |
|
13 |
app = Flask(__name__)
|
14 |
|
15 |
+
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"
|
16 |
|
17 |
args = {
|
18 |
"model_type": "gpt2",
|
|
|
20 |
"length": 80,
|
21 |
"stop_token": None,
|
22 |
"temperature": 1.0,
|
23 |
+
"length_penalty": 1.2,
|
24 |
"repetition_penalty": 1.2,
|
25 |
"k": 3,
|
26 |
"p": 0.9,
|
|
|
95 |
|
96 |
if __name__ == '__main__':
|
97 |
load_model_and_components()
|
98 |
+
app.run(host='0.0.0.0', port=10008, debug=False)
|
app_test.py
CHANGED
@@ -1,18 +1,14 @@
|
|
1 |
import requests
|
2 |
import json
|
3 |
|
4 |
-
url = 'http://localhost:
|
5 |
|
6 |
-
# 构造请求数据
|
7 |
data = {
|
8 |
-
'phrase': 'a spiece 和一只狼'
|
9 |
}
|
10 |
|
11 |
-
# 发送 POST 请求
|
12 |
response = requests.post(url, json=data)
|
13 |
|
14 |
-
# 解析响应
|
15 |
response_data = response.json()
|
16 |
|
17 |
-
# 打印响应结果
|
18 |
print(json.dumps(response_data, indent=4))
|
|
|
1 |
import requests
|
2 |
import json
|
3 |
|
4 |
+
url = 'http://localhost:10008/chat'
|
5 |
|
|
|
6 |
data = {
|
7 |
+
'phrase': 'a spiece 和一只狼'
|
8 |
}
|
9 |
|
|
|
10 |
response = requests.post(url, json=data)
|
11 |
|
|
|
12 |
response_data = response.json()
|
13 |
|
|
|
14 |
print(json.dumps(response_data, indent=4))
|
utils.py
CHANGED
@@ -1,25 +1,7 @@
|
|
1 |
import os
|
2 |
-
import json
|
3 |
|
4 |
-
from typing import Dict
|
5 |
-
from torch.utils.data import Dataset
|
6 |
-
from datasets import Dataset as AdvancedDataset
|
7 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
8 |
|
9 |
-
|
10 |
-
DEFAULT_TRAIN_DATA_NAME = "test_openprompt.json"
|
11 |
-
DEFAULT_TEST_DATA_NAME = "train_openprompt.json"
|
12 |
-
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"
|
13 |
-
|
14 |
-
def get_open_prompt_data(path_for_data):
|
15 |
-
with open(os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME)) as f:
|
16 |
-
train_data = json.load(f)
|
17 |
-
|
18 |
-
with open(os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME)) as f:
|
19 |
-
test_data = json.load(f)
|
20 |
-
|
21 |
-
return train_data, test_data
|
22 |
-
|
23 |
def get_tok_and_model(path_for_model):
|
24 |
if not os.path.exists(path_for_model):
|
25 |
raise RuntimeError("no cached model.")
|
@@ -27,33 +9,4 @@ def get_tok_and_model(path_for_model):
|
|
27 |
tok.pad_token_id = 50256
|
28 |
# default for open-ended generation
|
29 |
model = AutoModelForCausalLM.from_pretrained(path_for_model)
|
30 |
-
return tok, model
|
31 |
-
|
32 |
-
|
33 |
-
class OpenPromptDataset(Dataset):
|
34 |
-
def __init__(self, data) -> None:
|
35 |
-
super().__init__()
|
36 |
-
self.data = data
|
37 |
-
|
38 |
-
def __len__(self):
|
39 |
-
return len(self.data)
|
40 |
-
|
41 |
-
def __getitem__(self, index):
|
42 |
-
return self.data[index]
|
43 |
-
|
44 |
-
def get_dataset(train_data, test_data):
|
45 |
-
train_dataset = OpenPromptDataset(train_data)
|
46 |
-
test_dataset = OpenPromptDataset(test_data)
|
47 |
-
return train_dataset, test_dataset
|
48 |
-
|
49 |
-
def get_dict_dataset(path_for_data):
|
50 |
-
with open(os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME)) as f:
|
51 |
-
dict_data = json.load(f)
|
52 |
-
return dict_data
|
53 |
-
|
54 |
-
def get_advance_dataset(dict_data):
|
55 |
-
if not isinstance(dict_data, Dict):
|
56 |
-
raise RuntimeError("dict_data is not a dict.")
|
57 |
-
dataset = AdvancedDataset.from_dict(dict_data)
|
58 |
-
|
59 |
-
return dataset
|
|
|
1 |
import os
|
|
|
2 |
|
|
|
|
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_tok_and_model(path_for_model):
|
6 |
if not os.path.exists(path_for_model):
|
7 |
raise RuntimeError("no cached model.")
|
|
|
9 |
tok.pad_token_id = 50256
|
10 |
# default for open-ended generation
|
11 |
model = AutoModelForCausalLM.from_pretrained(path_for_model)
|
12 |
+
return tok, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|