Qifan Zhang committed
Commit 3f6f474
Parent: b8d9710

update optional models, add text input

Files changed (2):
  1. app.py +34 -14
  2. utils/models.py +3 -1
app.py CHANGED
@@ -1,3 +1,4 @@
+from io import StringIO
 from typing import Optional
 
 import gradio as gr
@@ -7,8 +8,6 @@ from utils.similarity import batch_cos_sim
 
 
 def read_data(filepath: str) -> Optional[pd.DataFrame]:
-    if not filepath:
-        return None
     if filepath.endswith('.xlsx'):
         df = pd.read_csv(filepath)
     elif filepath.endswith('.csv'):
@@ -19,35 +18,56 @@ def read_data(filepath: str) -> Optional[pd.DataFrame]:
 
 
 def process(model_name: str,
-            prompt: str,
+            text: str,
             file=None,
             ):
-    df = read_data(file.name)
+    if file:
+        df = read_data(file.name)
+    elif text:
+        string_io = StringIO(text)
+        df = pd.read_csv(string_io)
+    else:
+        raise Exception('No input provided')
     df = batch_cos_sim(df, model_name)
     path = 'output.csv'
    df.to_csv(path, index=False, encoding='utf-8-sig')
-    return df.to_markdown(), path
+    return str(df), path
 
 
 model_name_input = gr.components.Textbox(
     value='paraphrase-multilingual-MiniLM-L12-v2',
     lines=1,
-    type="text"
+    type='text'
+)
+
+model_name_option = gr.components.Dropdown(
+    label='Model Name',
+    value='paraphrase-multilingual-MiniLM-L12-v2',
+    choices=[
+        'paraphrase-multilingual-MiniLM-L12-v2',
+        'paraphrase-multilingual-mpnet-base-v2',
+        'cyclone/simcse-chinese-roberta-wwm-ext'
+    ]
 )
 
-prompt_input = gr.components.Textbox(
-    value='prompt,response',
+text_input = gr.components.Textbox(
+    value='prompt,response\n',
     lines=10,
-    type="text"
+    type='text'
+)
+
+text_output = gr.components.Textbox(
+    label='Output',
+    type='text'
 )
 
-file_output = gr.components.File(label="Output File",
-                                 file_count="single",
-                                 file_types=["", ".", ".csv", ".xls", ".xlsx"])
+file_output = gr.components.File(label='Output File',
+                                 file_count='single',
+                                 file_types=['', '.', '.csv', '.xls', '.xlsx'])
 
 app = gr.Interface(
     fn=process,
-    inputs=[model_name_input, prompt_input, "file" ],
-    outputs=["text", file_output]
+    inputs=[model_name_option, text_input, 'file'],
+    outputs=[text_output, file_output]
 )
 app.launch()
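
The new process() branch accepts either an uploaded file or raw CSV text pasted into the textbox. A minimal sketch of what the text path does outside Gradio, using a made-up two-row input (the 'prompt,response' header follows the textbox default; batch_cos_sim is omitted since it lives in utils/similarity.py):

from io import StringIO

import pandas as pd

# Pasted textbox content is parsed like a tiny CSV file.
raw_text = 'prompt,response\nhello,hi there\nhow are you,fine thanks\n'
df = pd.read_csv(StringIO(raw_text))
print(df.columns.tolist())  # ['prompt', 'response']
print(len(df))              # 2

Note that read_data still routes '.xlsx' filenames through pd.read_csv; pd.read_excel is the usual pandas call for Excel files, but that code path is unchanged by this commit.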
utils/models.py CHANGED
@@ -1,13 +1,15 @@
 from functools import lru_cache
 
+import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
-import numpy as np
+
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 class SBert:
     def __init__(self, path):
+        print(f'Loading model from {path} ...')
         self.model = SentenceTransformer(path, device=DEVICE)
 
     @lru_cache(maxsize=10000)
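
The hunk ends at the @lru_cache decorator, so the cached method itself falls outside the diff. A minimal sketch of how such a cached encoder usually looks, assuming the decorated method takes a single string; the method name encode and its body are illustrative, not taken from the repository:

from functools import lru_cache

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class SBert:
    def __init__(self, path):
        print(f'Loading model from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)

    @lru_cache(maxsize=10000)
    def encode(self, text: str) -> np.ndarray:
        # Hypothetical continuation of the hunk: caching works because
        # `text` is a hashable string, so a repeated sentence is only
        # embedded once.
        return self.model.encode(text)

Since lru_cache decorates the function object on the class, the 10000-entry cache is keyed by (self, text) and shared across SBert instances.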