Qifan Zhang
committed on
Commit
•
dd2409d
1
Parent(s):
613e689
feat: add pooling: cls/mean
Browse files
- app.py +9 -3
- utils/models.py +56 -1
- utils/pipeline.py +8 -8
app.py
CHANGED
@@ -21,6 +21,7 @@ def read_data(filepath: str) -> Optional[pd.DataFrame]:
|
|
21 |
def process(
|
22 |
task_name: str,
|
23 |
model_name: str,
|
|
|
24 |
text: str,
|
25 |
file=None,
|
26 |
) -> (None, pd.DataFrame, str):
|
@@ -37,9 +38,9 @@ def process(
|
|
37 |
|
38 |
# process
|
39 |
if task_name == 'Originality':
|
40 |
-
df = pipeline.p0_originality(df, model_name)
|
41 |
elif task_name == 'Flexibility':
|
42 |
-
df = pipeline.p1_flexibility(df, model_name)
|
43 |
else:
|
44 |
raise Exception('Task not supported')
|
45 |
|
@@ -62,6 +63,11 @@ model_name_dropdown = gr.components.Dropdown(
|
|
62 |
value=list_models[0],
|
63 |
choices=list_models
|
64 |
)
|
|
|
|
|
|
|
|
|
|
|
65 |
text_input = gr.components.Textbox(
|
66 |
value=open('data/example_xlm.csv', 'r').read(),
|
67 |
lines=10,
|
@@ -75,7 +81,7 @@ file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx
|
|
75 |
|
76 |
app = gr.Interface(
|
77 |
fn=process,
|
78 |
-
inputs=[task_name_dropdown, model_name_dropdown, text_input, file_input],
|
79 |
outputs=[text_output, dataframe_output, file_output],
|
80 |
description=open('data/description.txt', 'r').read()
|
81 |
)
|
|
|
21 |
def process(
|
22 |
task_name: str,
|
23 |
model_name: str,
|
24 |
+
pooling: str,
|
25 |
text: str,
|
26 |
file=None,
|
27 |
) -> (None, pd.DataFrame, str):
|
|
|
38 |
|
39 |
# process
|
40 |
if task_name == 'Originality':
|
41 |
+
df = pipeline.p0_originality(df, model_name, pooling)
|
42 |
elif task_name == 'Flexibility':
|
43 |
+
df = pipeline.p1_flexibility(df, model_name, pooling)
|
44 |
else:
|
45 |
raise Exception('Task not supported')
|
46 |
|
|
|
63 |
value=list_models[0],
|
64 |
choices=list_models
|
65 |
)
|
66 |
+
pooling_dropdown = gr.components.Dropdown(
|
67 |
+
label='Pooling',
|
68 |
+
value='mean',
|
69 |
+
choices=['mean', 'cls']
|
70 |
+
)
|
71 |
text_input = gr.components.Textbox(
|
72 |
value=open('data/example_xlm.csv', 'r').read(),
|
73 |
lines=10,
|
|
|
81 |
|
82 |
app = gr.Interface(
|
83 |
fn=process,
|
84 |
+
inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
|
85 |
outputs=[text_output, dataframe_output, file_output],
|
86 |
description=open('data/description.txt', 'r').read()
|
87 |
)
|
utils/models.py
CHANGED
@@ -2,6 +2,7 @@ from functools import lru_cache
|
|
2 |
|
3 |
import torch
|
4 |
from sentence_transformers import SentenceTransformer
|
|
|
5 |
|
6 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
7 |
|
@@ -10,7 +11,9 @@ list_models = [
|
|
10 |
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
11 |
'sentence-transformers/all-mpnet-base-v2',
|
12 |
'sentence-transformers/all-MiniLM-L12-v2',
|
13 |
-
'cyclone/simcse-chinese-roberta-wwm-ext'
|
|
|
|
|
14 |
]
|
15 |
|
16 |
|
@@ -18,8 +21,60 @@ class SBert:
|
|
18 |
def __init__(self, path):
|
19 |
print(f'Loading model from {path} ...')
|
20 |
self.model = SentenceTransformer(path, device=DEVICE)
|
|
|
|
|
21 |
|
22 |
@lru_cache(maxsize=10000)
|
23 |
def __call__(self, x) -> torch.Tensor:
|
24 |
y = self.model.encode(x, convert_to_tensor=True)
|
25 |
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import torch
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
+
from transformers import AutoTokenizer, AutoModel
|
6 |
|
7 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
8 |
|
|
|
11 |
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
12 |
'sentence-transformers/all-mpnet-base-v2',
|
13 |
'sentence-transformers/all-MiniLM-L12-v2',
|
14 |
+
'cyclone/simcse-chinese-roberta-wwm-ext',
|
15 |
+
'bert-base-chinese',
|
16 |
+
'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
|
17 |
]
|
18 |
|
19 |
|
|
|
21 |
def __init__(self, path):
|
22 |
print(f'Loading model from {path} ...')
|
23 |
self.model = SentenceTransformer(path, device=DEVICE)
|
24 |
+
# from pprint import pprint
|
25 |
+
# pprint(self.model.__dict__)
|
26 |
|
27 |
@lru_cache(maxsize=10000)
|
28 |
def __call__(self, x) -> torch.Tensor:
|
29 |
y = self.model.encode(x, convert_to_tensor=True)
|
30 |
return y
|
31 |
+
|
32 |
+
|
33 |
+
class ModelWithPooling:
    """Hugging Face encoder that reduces one text to a single pooled vector.

    Supported ``pooling`` strategies:

    * ``'cls'``            -- last-layer hidden state of the first token.
    * ``'pooler'``         -- the model's ``pooler_output``.
    * ``'mean'`` / ``'last-avg'`` -- mean of the last layer over the sequence.
    * ``'first-last-avg'`` -- average of the sequence means of
      ``hidden_states[1]`` and ``hidden_states[-1]``.

    Results are memoised per instance (LRU, ``_CACHE_SIZE`` entries) on the
    ``(text, pooling)`` pair.  The cache is deliberately NOT a class-level
    ``functools.lru_cache`` on the method: that form is keyed on ``self`` and
    keeps every instance (and its loaded model) alive for the life of the
    process.
    """

    _CACHE_SIZE = 10000
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load the tokenizer and the model weights from ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        :param text: the text to encode; must be hashable because it is part
            of the cache key.
        :param pooling: one of ``_POOLINGS``.
        :return: the pooled tensor with its leading batch dimension squeezed
            away.  The same tensor object is returned on a cache hit, so
            callers should not modify it in place.
        :raises ValueError: if ``pooling`` is not a known strategy, or if
            ``'pooler'`` is requested and the model returns no pooler output.
        """
        # Validate first so that a typo does not cost a full forward pass.
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        # Build the per-instance cache lazily so it also exists on instances
        # that were created without running __init__.
        cached_encode = self.__dict__.get('_cached_encode')
        if cached_encode is None:
            cached_encode = lru_cache(maxsize=self._CACHE_SIZE)(self._encode)
            self._cached_encode = cached_encode

        # Always pass positionally so m('x', 'cls') and
        # m(text='x', pooling='cls') share one cache entry.
        return cached_encode(text, pooling)

    @torch.no_grad()
    def _encode(self, text, pooling):
        """Run the model on ``text`` and apply ``pooling`` (uncached)."""
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        # Per-layer hidden states are only needed by 'first-last-avg'.
        outputs = self.model(**inputs, output_hidden_states=(pooling == 'first-last-avg'))

        if pooling == 'cls':
            pooled = outputs.last_hidden_state[:, 0]  # [b, h]

        elif pooling == 'pooler':
            pooled = outputs.pooler_output  # [b, h]
            if pooled is None:
                raise ValueError('Model does not provide a pooler output')

        elif pooling == 'first-last-avg':
            # Index 1 rather than 0: in the Hugging Face output convention
            # hidden_states[0] is the embedding output.
            first_avg = outputs.hidden_states[1].mean(dim=1)  # [b, h]
            last_avg = outputs.hidden_states[-1].mean(dim=1)  # [b, h]
            pooled = (first_avg + last_avg) / 2  # [b, h]

        else:  # 'mean' / 'last-avg'
            pooled = outputs.last_hidden_state.mean(dim=1)  # [b, h]

        # Drop the batch dimension for the single-text case.
        return pooled.squeeze(0)
|
67 |
+
|
68 |
+
|
69 |
+
def test_sbert():
    """Smoke test: encode one word with ``SBert`` and check the output size.

    Loads ``'bert-base-chinese'`` through ``SBert``, prints the size of the
    embedding returned for ``'hello'`` and asserts that it is ``(768,)``.
    """
    encoder = SBert('bert-base-chinese')
    embedding_size = encoder('hello').size()
    print(embedding_size)
    assert embedding_size == (768,)
|
74 |
+
|
75 |
+
|
76 |
+
def test_hf_model():
    """Smoke test: CLS-pool one word with ``ModelWithPooling``.

    Loads ``'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese'``, prints the size of
    the ``'cls'``-pooled embedding of ``'hello'`` and asserts that it is
    ``(768,)``.
    """
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    embedding_size = encoder('hello', pooling='cls').size()
    print(embedding_size)
    assert embedding_size == (768,)
|
utils/pipeline.py
CHANGED
@@ -3,10 +3,10 @@ from typing import List
|
|
3 |
import pandas as pd
|
4 |
from sentence_transformers.util import cos_sim
|
5 |
|
6 |
-
from utils.models import
|
7 |
|
8 |
|
9 |
-
def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
10 |
"""
|
11 |
row-wise
|
12 |
:param df:
|
@@ -15,11 +15,11 @@ def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
15 |
"""
|
16 |
assert 'prompt' in df.columns
|
17 |
assert 'response' in df.columns
|
18 |
-
model =
|
19 |
|
20 |
def get_cos_sim(prompt: str, response: str) -> float:
|
21 |
-
prompt_vec = model(prompt)
|
22 |
-
response_vec = model(response)
|
23 |
score = cos_sim(prompt_vec, response_vec).item()
|
24 |
return score
|
25 |
|
@@ -27,7 +27,7 @@ def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
27 |
return df
|
28 |
|
29 |
|
30 |
-
def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
31 |
"""
|
32 |
group-wise
|
33 |
:param df:
|
@@ -37,10 +37,10 @@ def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
37 |
assert 'prompt' in df.columns
|
38 |
assert 'response' in df.columns
|
39 |
assert 'id' in df.columns
|
40 |
-
model =
|
41 |
|
42 |
def get_flexibility(responses: List[str]) -> float:
|
43 |
-
responses_vec = [model(_) for _ in responses]
|
44 |
score = 0
|
45 |
for i in range(len(responses_vec) - 1):
|
46 |
score += 1 - cos_sim(responses_vec[i], responses_vec[i + 1]).item()
|
|
|
3 |
import pandas as pd
|
4 |
from sentence_transformers.util import cos_sim
|
5 |
|
6 |
+
from utils.models import ModelWithPooling
|
7 |
|
8 |
|
9 |
+
def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
|
10 |
"""
|
11 |
row-wise
|
12 |
:param df:
|
|
|
15 |
"""
|
16 |
assert 'prompt' in df.columns
|
17 |
assert 'response' in df.columns
|
18 |
+
model = ModelWithPooling(model_name)
|
19 |
|
20 |
def get_cos_sim(prompt: str, response: str) -> float:
|
21 |
+
prompt_vec = model(text=prompt, pooling=pooling)
|
22 |
+
response_vec = model(text=response, pooling=pooling)
|
23 |
score = cos_sim(prompt_vec, response_vec).item()
|
24 |
return score
|
25 |
|
|
|
27 |
return df
|
28 |
|
29 |
|
30 |
+
def p1_flexibility(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
|
31 |
"""
|
32 |
group-wise
|
33 |
:param df:
|
|
|
37 |
assert 'prompt' in df.columns
|
38 |
assert 'response' in df.columns
|
39 |
assert 'id' in df.columns
|
40 |
+
model = ModelWithPooling(model_name)
|
41 |
|
42 |
def get_flexibility(responses: List[str]) -> float:
|
43 |
+
responses_vec = [model(text=_, pooling=pooling) for _ in responses]
|
44 |
score = 0
|
45 |
for i in range(len(responses_vec) - 1):
|
46 |
score += 1 - cos_sim(responses_vec[i], responses_vec[i + 1]).item()
|