ν˜•κ·œ 솑 commited on
Commit
6bd388c
β€’
1 Parent(s): e0a78a8

add Google Perspective API

Browse files

(`d9de3ce` in https://bitbucket.org/maum-system/cvpr22-demo-gradio)

.gitignore CHANGED
@@ -10,4 +10,5 @@ output_file/*
10
  *.png
11
  !background_image/*
12
  *.mkv
13
- gradio_queue.db*
 
10
  *.png
11
  !background_image/*
12
  *.mkv
13
+ gradio_queue.db*
14
+ !vacant.mp4
app.py CHANGED
@@ -7,55 +7,19 @@ TRANSLATION_APIKEY_URL = os.environ['TRANSLATION_APIKEY_URL']
7
  GOOGLE_APPLICATION_CREDENTIALS = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
8
  subprocess.call(f"wget --no-check-certificate -O {GOOGLE_APPLICATION_CREDENTIALS} {TRANSLATION_APIKEY_URL}", shell=True)
9
 
 
 
10
  import gradio as gr
 
 
11
  from client_rest import RestAPIApplication
12
  from pathlib import Path
13
  import argparse
14
  import threading
15
- from translator import GoogleAuthTranslation
16
  import yaml
17
 
18
  TITLE = Path("docs/title.txt").read_text()
19
- DESCRIPTION = Path("docs/description.txt").read_text()
20
-
21
- class Translator:
22
- def __init__(self, yaml_path='lang.yaml'):
23
- self.google_translation = GoogleAuthTranslation(project_id="cvpr-2022-demonstration")
24
- with open(yaml_path) as f:
25
- self.supporting_languages = yaml.load(f, Loader=yaml.FullLoader)
26
-
27
- def _get_text_with_lang(self, text, lang):
28
- lang_detected = self.google_translation.detect(text)
29
- print(lang_detected, lang)
30
- if lang is None:
31
- lang = lang_detected
32
-
33
- if lang != lang_detected:
34
- target_text = self.google_translation.translate(text, lang=lang)
35
- else:
36
- target_text = text
37
-
38
- return target_text, lang
39
-
40
- def _convert_lang_from_index(self, lang):
41
- lang_finder = [name for name in self.supporting_languages
42
- if self.supporting_languages[name]['language'] == lang]
43
- if len(lang_finder) == 1:
44
- lang = lang_finder[0]
45
- else:
46
- raise AssertionError(f"Given language index can't be understood! | lang: {lang}")
47
-
48
- return lang
49
-
50
- def get_translation(self, text, lang, use_translation=True):
51
- lang_ = self._convert_lang_from_index(lang)
52
-
53
- if use_translation:
54
- target_text, _ = self._get_text_with_lang(text, lang_)
55
- else:
56
- target_text = text
57
-
58
- return target_text, lang
59
 
60
 
61
  class GradioApplication:
@@ -72,6 +36,7 @@ class GradioApplication:
72
  "background_image/river.mp4",
73
  "background_image/sky.mp4"]
74
 
 
75
  self.translator = Translator()
76
  self.rest_application = RestAPIApplication(rest_ip, rest_port)
77
  self.output_dir = Path("output_file")
@@ -118,24 +83,49 @@ class GradioApplication:
118
  is_video_background = False
119
 
120
  return background_data, is_video_background
 
 
 
 
121
 
122
  def infer(self, text, lang, duration_rate, action, background_index):
123
  self._counter_file_seed()
124
  print(f"File Seed: {self._file_seed}")
125
- target_text, lang_dest = self.translator.get_translation(text, lang)
126
- lang_rpc_code = self.get_lang_code(lang_dest)
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  background_data, is_video_background = self.get_background_data(background_index)
129
 
130
  video_data = self.rest_application.get_video(target_text, lang_rpc_code, duration_rate, action.lower(),
131
  background_data, is_video_background)
132
- print(len(video_data))
133
 
134
  video_filename = self.output_dir / f"{self._file_seed:02d}.mkv"
135
  with open(video_filename, "wb") as video_file:
136
  video_file.write(video_data)
137
 
138
- return f"Language: {lang_dest}\nText: \n{target_text}", str(video_filename)
139
 
140
  def run(self, server_port=7860, share=False):
141
  try:
@@ -176,11 +166,10 @@ def prepare_input():
176
 
177
 
178
  def prepare_output():
179
- translation_result_otuput = gr.Textbox(type="str",
180
- label="Translation Result")
181
-
182
  video_output = gr.Video(format='mp4')
183
- return [translation_result_otuput, video_output]
184
 
185
 
186
  def parse_args():
7
  GOOGLE_APPLICATION_CREDENTIALS = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
8
  subprocess.call(f"wget --no-check-certificate -O {GOOGLE_APPLICATION_CREDENTIALS} {TRANSLATION_APIKEY_URL}", shell=True)
9
 
10
+ TOXICITY_THRESHOLD = float(os.getenv('TOXICITY_THRESHOLD', 0.7))
11
+
12
  import gradio as gr
13
+ from toxicity_estimator import PerspectiveAPI
14
+ from translator import Translator
15
  from client_rest import RestAPIApplication
16
  from pathlib import Path
17
  import argparse
18
  import threading
 
19
  import yaml
20
 
21
  TITLE = Path("docs/title.txt").read_text()
22
+ DESCRIPTION = Path("docs/description.md").read_text()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class GradioApplication:
36
  "background_image/river.mp4",
37
  "background_image/sky.mp4"]
38
 
39
+ self.perspective_api = PerspectiveAPI()
40
  self.translator = Translator()
41
  self.rest_application = RestAPIApplication(rest_ip, rest_port)
42
  self.output_dir = Path("output_file")
83
  is_video_background = False
84
 
85
  return background_data, is_video_background
86
+
87
+ @staticmethod
88
+ def return_format(toxicity_prob, target_text, lang_dest, video_filename):
89
+ return {'Toxicity': toxicity_prob}, f"Language: {lang_dest}\nText: \n{target_text}", str(video_filename)
90
 
91
  def infer(self, text, lang, duration_rate, action, background_index):
92
  self._counter_file_seed()
93
  print(f"File Seed: {self._file_seed}")
94
+ toxicity_prob = 0.0
95
+ target_text = "(Sorry, it seems that the input text is too toxic.)"
96
+ lang_dest = ""
97
+ video_filename = "vacant.mp4"
98
+
99
+ # Toxicity estimation
100
+ try:
101
+ toxicity_prob = self.perspective_api.get_score(text)
102
+ except Exception as e: # when Perspective API doesn't work
103
+ pass
104
+
105
+ if toxicity_prob > TOXICITY_THRESHOLD:
106
+ return self.return_format(toxicity_prob, target_text, lang_dest, video_filename)
107
+
108
+ # Google Translate API
109
+ try:
110
+ target_text, lang_dest = self.translator.get_translation(text, lang)
111
+ lang_rpc_code = self.get_lang_code(lang_dest)
112
+ except Exception as e:
113
+ target_text = f"Error from language translation: ({e})"
114
+ lang_dest = ""
115
+ return self.return_format(toxicity_prob, target_text, lang_dest, video_filename)
116
+
117
+ # Video Inference
118
  background_data, is_video_background = self.get_background_data(background_index)
119
 
120
  video_data = self.rest_application.get_video(target_text, lang_rpc_code, duration_rate, action.lower(),
121
  background_data, is_video_background)
122
+ print(f"Video data size: {len(video_data)}")
123
 
124
  video_filename = self.output_dir / f"{self._file_seed:02d}.mkv"
125
  with open(video_filename, "wb") as video_file:
126
  video_file.write(video_data)
127
 
128
+ return {'Toxicity': toxicity_prob}, f"Language: {lang_dest}\nText: \n{target_text}", str(video_filename)
129
 
130
  def run(self, server_port=7860, share=False):
131
  try:
166
 
167
 
168
  def prepare_output():
169
+ toxicity_output = gr.Label(num_top_classes=1, label="Toxicity (from Perspective API)")
170
+ translation_result_otuput = gr.Textbox(type="str", label="Translation Result")
 
171
  video_output = gr.Video(format='mp4')
172
+ return [toxicity_output, translation_result_otuput, video_output]
173
 
174
 
175
  def parse_args():
docs/{description.txt β†’ description.md} RENAMED
@@ -3,4 +3,8 @@ You can provide the input text in one of the four languages: Chinese (Mandarin),
3
  You may also select the target language, the language of the output speech.
4
  If the input text language and the target language are different, the input text will be translated to the target language using Google Translate API.
5
 
 
 
 
 
6
  (2022.06.05.) Due to the latency from HuggingFace Spaces and video rendering, it takes 15 ~ 30 seconds to get a video result.
3
  You may also select the target language, the language of the output speech.
4
  If the input text language and the target language are different, the input text will be translated to the target language using Google Translate API.
5
 
6
+ ### Updates
7
+
8
+ (2022.06.17.) We were originally planning to support any input text. However, when checking the logs recently, we found that there were a lot of inappropriate input texts. So, we decided to filter the inputs based on toxicity using [Perspective API @Google](https://developers.perspectiveapi.com/s/). Now, if you enter a possibily toxic text, the video generation will fail. We hope you understand.
9
+
10
  (2022.06.05.) Due to the latency from HuggingFace Spaces and video rendering, it takes 15 ~ 30 seconds to get a video result.
requirements.txt CHANGED
@@ -3,4 +3,5 @@ jinja2
3
  googletrans==4.0.0-rc1
4
  PyYAML
5
  opencv-python
6
- google-cloud-translate
 
3
  googletrans==4.0.0-rc1
4
  PyYAML
5
  opencv-python
6
+ google-cloud-translate
7
+ google-api-python-client
toxicity_estimator/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .module import PerspectiveAPI
toxicity_estimator/module.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from googleapiclient import discovery
2
+ import argparse
3
+ import json
4
+ import os
5
+
6
+ API_KEY = os.environ['PERSPECTIVE_API_KEY']
7
+
8
+ class PerspectiveAPI:
9
+ def __init__(self):
10
+ self.client = discovery.build(
11
+ "commentanalyzer",
12
+ "v1alpha1",
13
+ developerKey=API_KEY,
14
+ discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
15
+ static_discovery=False,
16
+ )
17
+ @staticmethod
18
+ def _get_request(text):
19
+ return {
20
+ 'comment': {'text': text},
21
+ 'requestedAttributes': {'TOXICITY': {}}
22
+ }
23
+
24
+ def _infer(self, text):
25
+ request = self._get_request(text)
26
+ response = self.client.comments().analyze(body=request).execute()
27
+ return response
28
+
29
+ def infer(self, text):
30
+ return self._infer(text)
31
+
32
+ def get_score(self, text, label='TOXICITY'):
33
+ response = self._infer(text)
34
+ return response['attributeScores'][label]['spanScores'][0]['score']['value']
35
+
36
+
37
+ def parse_args():
38
+ parser = argparse.ArgumentParser(
39
+ description='Perspective API Test.')
40
+ parser.add_argument('-i', '--input-text', type=str, required=True)
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ if __name__ == '__main__':
46
+ args = parse_args()
47
+
48
+ perspective_api = PerspectiveAPI()
49
+ score = perspective_api.get_score(args.input_text)
50
+
51
+ print(score)
translator/__init__.py CHANGED
@@ -1 +1 @@
1
- from .v3 import GoogleAuthTranslation
1
+ from .module import Translator
translator/module.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .v3 import GoogleAuthTranslation
2
+ from pathlib import Path
3
+ import yaml
4
+
5
+
6
+ class Translator:
7
+ def __init__(self, yaml_path='./lang.yaml'):
8
+ self.google_translation = GoogleAuthTranslation(project_id="cvpr-2022-demonstration")
9
+ with open(yaml_path) as f:
10
+ self.supporting_languages = yaml.load(f, Loader=yaml.FullLoader)
11
+
12
+ def _get_text_with_lang(self, text, lang):
13
+ lang_detected = self.google_translation.detect(text)
14
+ print(lang_detected, lang)
15
+ if lang is None:
16
+ lang = lang_detected
17
+
18
+ if lang != lang_detected:
19
+ target_text = self.google_translation.translate(text, lang=lang)
20
+ else:
21
+ target_text = text
22
+
23
+ return target_text, lang
24
+
25
+ def _convert_lang_from_index(self, lang):
26
+ try:
27
+ lang_finder = [name for name in self.supporting_languages
28
+ if self.supporting_languages[name]['language'] == lang]
29
+ except Exception as e:
30
+ raise RuntimeError(e)
31
+
32
+ if len(lang_finder) == 1:
33
+ lang = lang_finder[0]
34
+ else:
35
+ raise AssertionError("Given language index can't be understood!"
36
+ f"Only one of ['Korean', 'English', 'Japanese', 'Chinese'] can be supported. | lang: {lang}")
37
+
38
+ return lang
39
+
40
+ def get_translation(self, text, lang, use_translation=True):
41
+ lang_ = self._convert_lang_from_index(lang)
42
+
43
+ if use_translation:
44
+ target_text, _ = self._get_text_with_lang(text, lang_)
45
+ else:
46
+ target_text = text
47
+
48
+ return target_text, lang
translator/v3.py CHANGED
@@ -36,7 +36,7 @@ class GoogleAuthTranslation:
36
  if self.supporting_languages[key]['google_dest'] == dest:
37
  return key
38
 
39
- raise RuntimeError(f"Detected langauge {dest} is not supported for TTS.")
40
 
41
  def translate(self, query, lang):
42
 
36
  if self.supporting_languages[key]['google_dest'] == dest:
37
  return key
38
 
39
+ raise RuntimeError(f"Detected langauge is not supported in our multilingual TTS. |\n Code: {dest} | See https://cloud.google.com/translate/docs/languages")
40
 
41
  def translate(self, query, lang):
42
 
vacant.mp4 ADDED
File without changes