Annalyn Ng commited on
Commit
1d09c47
1 Parent(s): 449ebd8

test barplot

Browse files
Files changed (3) hide show
  1. .vscode/settings.json +6 -0
  2. app.py +30 -12
  3. requirements.txt +2 -1
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter"
4
+ },
5
+ "python.formatting.provider": "none"
6
+ }
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
4
 
@@ -9,12 +10,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
9
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
10
  mask_token = tokenizer.mask_token
11
 
 
12
  def add_mask(target_word, text):
13
- text_mask = text.replace(target_word, mask_token)
14
- return text_mask
15
 
16
- def eval_prob(target_word, text):
17
 
 
18
  text_mask = add_mask(target_word, text)
19
  # Get index of target_word
20
  idx = tokenizer.encode(target_word)[2]
@@ -22,11 +24,11 @@ def eval_prob(target_word, text):
22
  # Get logits
23
  inputs = tokenizer(text_mask, return_tensors="pt")
24
  token_logits = model(**inputs).logits
25
-
26
  # Find the location of the MASK and extract its logits
27
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
28
  mask_token_logits = token_logits[0, mask_token_index, :]
29
-
30
  # Convert logits to softmax probability
31
  logits = mask_token_logits[0].tolist()
32
  probs = torch.nn.functional.softmax(torch.tensor([logits]), dim=1)[0]
@@ -36,18 +38,34 @@ def eval_prob(target_word, text):
36
 
37
  return result
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  gr.Interface(
40
  fn=eval_prob,
41
  inputs=[
42
- gr.Textbox(
43
- label="词语",
44
- placeholder="夸大"),
45
- gr.Textbox(
46
- label="造句",
47
- placeholder=f"我们使用生成式人工智能已经很长时间了,所以他们最近的媒体报道可能被夸大了。"),
48
  ],
49
  examples=[
50
- ["夸大", "我们使用生成式人工智能已经很长时间了,所以他们最近的媒体报道可能被夸大了。"],
51
  ],
52
  outputs="number",
53
  title="Chinese Sentence Grading",
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
5
 
 
10
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
11
  mask_token = tokenizer.mask_token
12
 
13
+
14
  def add_mask(target_word, text):
15
+ text_mask = text.replace(target_word, mask_token)
16
+ return text_mask
17
 
 
18
 
19
+ def eval_prob(target_word, text):
20
  text_mask = add_mask(target_word, text)
21
  # Get index of target_word
22
  idx = tokenizer.encode(target_word)[2]
 
24
  # Get logits
25
  inputs = tokenizer(text_mask, return_tensors="pt")
26
  token_logits = model(**inputs).logits
27
+
28
  # Find the location of the MASK and extract its logits
29
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
30
  mask_token_logits = token_logits[0, mask_token_index, :]
31
+
32
  # Convert logits to softmax probability
33
  logits = mask_token_logits[0].tolist()
34
  probs = torch.nn.functional.softmax(torch.tensor([logits]), dim=1)[0]
 
38
 
39
  return result
40
 
41
+
42
+ # test barplot
43
+ simple = pd.DataFrame(
44
+ {
45
+ "item": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
46
+ "inventory": [28, 55, 43, 91, 81, 53, 19, 87, 52],
47
+ }
48
+ )
49
+
50
+ css = (
51
+ "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
52
+ )
53
+
54
+ with gr.Blocks(css=css) as demo:
55
+ gr.BarPlot(value=simple, x="item", y="inventory", title="Simple Bar Plot").style(
56
+ container=False,
57
+ )
58
+ demo.launch(share=True)
59
+
60
+
61
  gr.Interface(
62
  fn=eval_prob,
63
  inputs=[
64
+ gr.Textbox(label="词语", placeholder="夸大"),
65
+ gr.Textbox(label="造句", placeholder=f"我们使用生成式人工智能已经很长时间了,所以最近的媒体报道可能被夸大了。"),
 
 
 
 
66
  ],
67
  examples=[
68
+ ["夸大", "我们使用生成式人工智能已经很长时间了,所以最近的媒体报道可能被夸大了。"],
69
  ],
70
  outputs="number",
71
  title="Chinese Sentence Grading",
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch
3
- transformers
 
 
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch
3
+ transformers
4
+ pandas