Qifan Zhang committed
Commit dd2409d
1 Parent(s): 613e689

feat: add pooling: cls/mean

Files changed (3):
  1. app.py +9 -3
  2. utils/models.py +56 -1
  3. utils/pipeline.py +8 -8
app.py CHANGED
@@ -21,6 +21,7 @@ def read_data(filepath: str) -> Optional[pd.DataFrame]:
 def process(
     task_name: str,
     model_name: str,
+    pooling: str,
     text: str,
     file=None,
 ) -> (None, pd.DataFrame, str):
@@ -37,9 +38,9 @@ def process(
 
     # process
     if task_name == 'Originality':
-        df = pipeline.p0_originality(df, model_name)
+        df = pipeline.p0_originality(df, model_name, pooling)
     elif task_name == 'Flexibility':
-        df = pipeline.p1_flexibility(df, model_name)
+        df = pipeline.p1_flexibility(df, model_name, pooling)
     else:
         raise Exception('Task not supported')
 
@@ -62,6 +63,11 @@ model_name_dropdown = gr.components.Dropdown(
     value=list_models[0],
     choices=list_models
 )
+pooling_dropdown = gr.components.Dropdown(
+    label='Pooling',
+    value='mean',
+    choices=['mean', 'cls']
+)
 text_input = gr.components.Textbox(
     value=open('data/example_xlm.csv', 'r').read(),
     lines=10,
@@ -75,7 +81,7 @@ file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx
 
 app = gr.Interface(
     fn=process,
-    inputs=[task_name_dropdown, model_name_dropdown, text_input, file_input],
+    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
     outputs=[text_output, dataframe_output, file_output],
     description=open('data/description.txt', 'r').read()
 )
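Gradio forwards the components listed in `inputs` to `fn` positionally, so the new `pooling_dropdown` must occupy the same slot as the new `pooling` parameter of `process` (third in both lists, as in the hunk above). A minimal, self-contained sketch of that wiring pattern; the labels, choices, and the placeholder `process` body below are illustrative, not taken from the app:

import gradio as gr

# Hypothetical stand-in for app.py's process(); only the argument order matters here.
def process(task_name: str, model_name: str, pooling: str, text: str) -> str:
    return f'task={task_name}, model={model_name}, pooling={pooling}, chars={len(text)}'

task_dd = gr.components.Dropdown(label='Task', value='Originality', choices=['Originality', 'Flexibility'])
model_dd = gr.components.Dropdown(label='Model', value='bert-base-chinese', choices=['bert-base-chinese'])
pooling_dd = gr.components.Dropdown(label='Pooling', value='mean', choices=['mean', 'cls'])
text_in = gr.components.Textbox(lines=2)

# Components are passed to fn in list order: (task, model, pooling, text).
demo = gr.Interface(fn=process, inputs=[task_dd, model_dd, pooling_dd, text_in], outputs='text')
# demo.launch()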
utils/models.py CHANGED
@@ -2,6 +2,7 @@ from functools import lru_cache
 
 import torch
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModel
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -10,7 +11,9 @@ list_models = [
     'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
     'sentence-transformers/all-mpnet-base-v2',
     'sentence-transformers/all-MiniLM-L12-v2',
-    'cyclone/simcse-chinese-roberta-wwm-ext'
+    'cyclone/simcse-chinese-roberta-wwm-ext',
+    'bert-base-chinese',
+    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
 ]
 
 
@@ -18,8 +21,60 @@ class SBert:
     def __init__(self, path):
         print(f'Loading model from {path} ...')
         self.model = SentenceTransformer(path, device=DEVICE)
+        # from pprint import pprint
+        # pprint(self.model.__dict__)
 
     @lru_cache(maxsize=10000)
     def __call__(self, x) -> torch.Tensor:
         y = self.model.encode(x, convert_to_tensor=True)
         return y
+
+
+class ModelWithPooling:
+    def __init__(self, path):
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModel.from_pretrained(path)
+
+    @lru_cache(maxsize=10000)
+    @torch.no_grad()
+    def __call__(self, text: str, pooling='mean'):
+        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+        outputs = self.model(**inputs, output_hidden_states=True)
+
+        if pooling == 'cls':
+            o = outputs.last_hidden_state[:, 0]  # [b, h]
+
+        elif pooling == 'pooler':
+            o = outputs.pooler_output  # [b, h]
+
+        elif pooling in ['mean', 'last-avg']:
+            last = outputs.last_hidden_state.transpose(1, 2)  # [b, h, s]
+            o = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [b, h]
+
+        elif pooling == 'first-last-avg':
+            first = outputs.hidden_states[1].transpose(1, 2)  # [b, h, s]
+            last = outputs.hidden_states[-1].transpose(1, 2)  # [b, h, s]
+            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)  # [b, h]
+            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [b, h]
+            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)  # [b, 2, h]
+            o = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [b, h]
+
+        else:
+            raise Exception(f'Unknown pooling {pooling}')
+
+        o = o.squeeze(0)
+        return o
+
+
+def test_sbert():
+    m = SBert('bert-base-chinese')
+    o = m('hello')
+    print(o.size())
+    assert o.size() == (768,)
+
+
+def test_hf_model():
+    m = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
+    o = m('hello', pooling='cls')
+    print(o.size())
+    assert o.size() == (768,)
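A minimal sketch (not part of the commit) of how the new ModelWithPooling wrapper behaves; the checkpoint name and input text are arbitrary. For a single, unpadded input, the 'mean' branch (an avg_pool1d over the token axis) matches a plain mean of last_hidden_state over dim 1:

import torch
from utils.models import ModelWithPooling

m = ModelWithPooling('bert-base-chinese')  # any Hugging Face encoder checkpoint
text = 'hello world'

cls_vec = m(text, pooling='cls')    # embedding of the first ([CLS]) token, shape [768]
mean_vec = m(text, pooling='mean')  # token-averaged last hidden state, shape [768]

# Cross-check the 'mean' branch: for one sequence with no padding it equals a
# plain mean over the token axis of last_hidden_state.
with torch.no_grad():
    inputs = m.tokenizer(text, return_tensors='pt')
    outputs = m.model(**inputs)
    manual_mean = outputs.last_hidden_state.mean(dim=1).squeeze(0)

assert torch.allclose(mean_vec, manual_mean, atol=1e-5)
print(cls_vec.size(), mean_vec.size())  # torch.Size([768]) torch.Size([768])

Worth noting: the 'mean' branch averages over every tokenizer token, special tokens included, and does not use the attention mask; since __call__ takes a single string (and is lru_cached on it), no padding is ever present here, so this does not affect the pipelines in this repo.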
utils/pipeline.py CHANGED
@@ -3,10 +3,10 @@ from typing import List
 
 import pandas as pd
 from sentence_transformers.util import cos_sim
-from utils.models import SBert
+from utils.models import ModelWithPooling
 
 
-def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
+def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
     """
     row-wise
     :param df:
@@ -15,11 +15,11 @@ def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
     """
     assert 'prompt' in df.columns
     assert 'response' in df.columns
-    model = SBert(model_name)
+    model = ModelWithPooling(model_name)
 
     def get_cos_sim(prompt: str, response: str) -> float:
-        prompt_vec = model(prompt)
-        response_vec = model(response)
+        prompt_vec = model(text=prompt, pooling=pooling)
+        response_vec = model(text=response, pooling=pooling)
         score = cos_sim(prompt_vec, response_vec).item()
         return score
 
@@ -27,7 +27,7 @@ def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
     return df
 
 
-def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
+def p1_flexibility(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
     """
     group-wise
     :param df:
@@ -37,10 +37,10 @@ def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
     assert 'prompt' in df.columns
     assert 'response' in df.columns
     assert 'id' in df.columns
-    model = SBert(model_name)
+    model = ModelWithPooling(model_name)
 
     def get_flexibility(responses: List[str]) -> float:
-        responses_vec = [model(_) for _ in responses]
+        responses_vec = [model(text=_, pooling=pooling) for _ in responses]
         score = 0
         for i in range(len(responses_vec) - 1):
             score += 1 - cos_sim(responses_vec[i], responses_vec[i + 1]).item()
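A minimal usage sketch of the updated pipeline functions. The data below is made up; the required column names ('id', 'prompt', 'response') come from the asserts in the functions themselves, and the output columns each function adds are not visible in this diff, so the sketch just prints the returned frame:

import pandas as pd
from utils import pipeline

# Toy input: two responses to the same prompt, sharing one group id.
df = pd.DataFrame({
    'id': [1, 1],
    'prompt': ['brick', 'brick'],
    'response': ['build a house', 'use it as a paperweight'],
})

# Row-wise prompt/response cosine similarity with the chosen pooling strategy.
out0 = pipeline.p0_originality(df, 'bert-base-chinese', pooling='mean')
print(out0)

# Group-wise dissimilarity (1 - cosine) between consecutive responses.
out1 = pipeline.p1_flexibility(df, 'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese', pooling='cls')
print(out1)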