Ron Au commited on
Commit
61d63d5
1 Parent(s): 0731409

Initial Commit

Browse files
Files changed (6) hide show
  1. app.py +30 -8
  2. dataset.py +19 -0
  3. index.html +0 -0
  4. index.js +89 -11
  5. inference.py +5 -4
  6. style.css +39 -5
app.py CHANGED
@@ -4,11 +4,13 @@ import requests
4
  from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
5
  from urllib.parse import parse_qs, urlparse
6
 
7
- from inference import t5_infer
 
8
 
9
  # https://huggingface.co/settings/tokens
10
  # https://huggingface.co/spaces/{username}/{space}/settings
11
- API_TOKEN = os.getenv('BIG_GAN_TOKEN')
 
12
 
13
  class RequestHandler(SimpleHTTPRequestHandler):
14
  def do_GET(self):
@@ -17,14 +19,16 @@ class RequestHandler(SimpleHTTPRequestHandler):
17
 
18
  return SimpleHTTPRequestHandler.do_GET(self)
19
 
20
- if self.path.startswith("/biggan_infer"):
21
- input = parse_qs(urlparse(self.path).query).get("input", None)[0]
 
 
22
 
23
  output = requests.request(
24
  "POST",
25
  "https://api-inference.huggingface.co/models/osanseviero/BigGAN-deep-128",
26
  headers={"Authorization": f"Bearer {API_TOKEN}"},
27
- data=json.dumps(input)
28
  )
29
 
30
  self.send_response(200)
@@ -35,10 +39,28 @@ class RequestHandler(SimpleHTTPRequestHandler):
35
 
36
  return SimpleHTTPRequestHandler
37
 
38
- elif self.path.startswith("/t5_infer"):
39
- input = parse_qs(urlparse(self.path).query).get("input", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- output = t5_infer(input)
42
 
43
  self.send_response(200)
44
  self.send_header("Content-Type", "application/json")
 
4
  from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
5
  from urllib.parse import parse_qs, urlparse
6
 
7
+ from inference import infer_t5
8
+ from dataset import query_emotion
9
 
10
  # https://huggingface.co/settings/tokens
11
  # https://huggingface.co/spaces/{username}/{space}/settings
12
+ API_TOKEN = os.getenv("BIG_GAN_TOKEN")
13
+
14
 
15
  class RequestHandler(SimpleHTTPRequestHandler):
16
  def do_GET(self):
 
19
 
20
  return SimpleHTTPRequestHandler.do_GET(self)
21
 
22
+ if self.path.startswith("/infer_biggan"):
23
+ url = urlparse(self.path)
24
+ query = parse_qs(url.query)
25
+ input = query.get("input", None)[0]
26
 
27
  output = requests.request(
28
  "POST",
29
  "https://api-inference.huggingface.co/models/osanseviero/BigGAN-deep-128",
30
  headers={"Authorization": f"Bearer {API_TOKEN}"},
31
+ data=json.dumps(input),
32
  )
33
 
34
  self.send_response(200)
 
39
 
40
  return SimpleHTTPRequestHandler
41
 
42
+ elif self.path.startswith("/infer_t5"):
43
+ url = urlparse(self.path)
44
+ query = parse_qs(url.query)
45
+ input = query.get("input", None)[0]
46
+
47
+ output = infer_t5(input)
48
+
49
+ self.send_response(200)
50
+ self.send_header("Content-Type", "application/json")
51
+ self.end_headers()
52
+
53
+ self.wfile.write(json.dumps({"output": output}).encode("utf-8"))
54
+
55
+ return SimpleHTTPRequestHandler
56
+
57
+ elif self.path.startswith("/query_emotion"):
58
+ url = urlparse(self.path)
59
+ query = parse_qs(url.query)
60
+ start = int(query.get("start", None)[0])
61
+ end = int(query.get("end", None)[0])
62
 
63
+ output = query_emotion(start, end)
64
 
65
  self.send_response(200)
66
  self.send_header("Content-Type", "application/json")
dataset.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+
3
+ dataset = load_dataset("emotion", split="train")
4
+
5
+ emotions = dataset.info.features["label"].names
6
+
7
+ def query_emotion(start, end):
8
+ rows = dataset[start:end]
9
+ texts, labels = [rows[k] for k in rows.keys()]
10
+
11
+ observations = []
12
+
13
+ for i, text in enumerate(texts):
14
+ observations.append({
15
+ "text": text,
16
+ "emotion": emotions[labels[i]],
17
+ })
18
+
19
+ return observations
index.html CHANGED
The diff for this file is too large to render. See raw diff
 
index.js CHANGED
@@ -1,26 +1,74 @@
1
- if (document.location.search.includes("dark-theme=true")) {
2
- document.body.classList.add("dark-theme");
3
  }
4
 
5
  const textToImage = async (text) => {
6
- const inferenceResponse = await fetch(`biggan_infer?input=${text}`);
7
  const inferenceBlob = await inferenceResponse.blob();
8
 
9
  return URL.createObjectURL(inferenceBlob);
10
  };
11
 
12
  const translateText = async (text) => {
13
- const inferResponse = await fetch(`t5_infer?input=${text}`);
14
  const inferJson = await inferResponse.json();
15
 
16
  return inferJson.output;
17
  };
18
 
19
- const imageGenSelect = document.getElementById("image-gen-input");
20
- const imageGenImage = document.querySelector(".image-gen-output");
21
- const textGenForm = document.querySelector(".text-gen-form");
22
 
23
- imageGenSelect.addEventListener("change", async (event) => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  const value = event.target.value;
25
 
26
  try {
@@ -30,11 +78,11 @@ imageGenSelect.addEventListener("change", async (event) => {
30
  }
31
  });
32
 
33
- textGenForm.addEventListener("submit", async (event) => {
34
  event.preventDefault();
35
 
36
- const textGenInput = document.getElementById("text-gen-input");
37
- const textGenParagraph = document.querySelector(".text-gen-output");
38
 
39
  try {
40
  textGenParagraph.textContent = await translateText(textGenInput.value);
@@ -43,6 +91,36 @@ textGenForm.addEventListener("submit", async (event) => {
43
  }
44
  });
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  textToImage(imageGenSelect.value)
47
  .then((image) => (imageGenImage.src = image))
48
  .catch(console.error);
 
 
 
 
 
1
+ if (document.location.search.includes('dark-theme=true')) {
2
+ document.body.classList.add('dark-theme');
3
  }
4
 
5
  const textToImage = async (text) => {
6
+ const inferenceResponse = await fetch(`/infer_biggan?input=${text}`);
7
  const inferenceBlob = await inferenceResponse.blob();
8
 
9
  return URL.createObjectURL(inferenceBlob);
10
  };
11
 
12
  const translateText = async (text) => {
13
+ const inferResponse = await fetch(`/infer_t5?input=${text}`);
14
  const inferJson = await inferResponse.json();
15
 
16
  return inferJson.output;
17
  };
18
 
19
+ const queryDataset = async (start, end) => {
20
+ const queryResponse = await fetch(`/query_emotion?start=${start}&end=${end}`);
21
+ const queryJson = await queryResponse.json();
22
 
23
+ return queryJson.output;
24
+ };
25
+
26
+ const updateTable = async (cursor, range = 5) => {
27
+ const table = document.querySelector('.dataset-output');
28
+
29
+ const fragment = new DocumentFragment();
30
+
31
+ const observations = await queryDataset(cursor, cursor + range);
32
+
33
+ for (const observation of observations) {
34
+ let row = document.createElement('tr');
35
+ let text = document.createElement('td');
36
+ let emotion = document.createElement('td');
37
+
38
+ text.textContent = observation.text;
39
+ emotion.textContent = observation.emotion;
40
+
41
+ row.appendChild(text);
42
+ row.appendChild(emotion);
43
+ fragment.appendChild(row);
44
+ }
45
+
46
+ table.innerHTML = '';
47
+
48
+ table.appendChild(fragment);
49
+
50
+ table.insertAdjacentHTML(
51
+ 'afterbegin',
52
+ `<thead>
53
+ <tr>
54
+ <td>text</td>
55
+ <td>emotion</td>
56
+ </tr>
57
+ </thead>`
58
+ );
59
+ };
60
+
61
+ const imageGenSelect = document.getElementById('image-gen-input');
62
+ const imageGenImage = document.querySelector('.image-gen-output');
63
+ const textGenForm = document.querySelector('.text-gen-form');
64
+ const tableButtonPrev = document.querySelector('.table-previous');
65
+ const tableButtonNext = document.querySelector('.table-next');
66
+
67
+ let cursor = 0;
68
+ const RANGE = 5;
69
+ const LIMIT = 16_000;
70
+
71
+ imageGenSelect.addEventListener('change', async (event) => {
72
  const value = event.target.value;
73
 
74
  try {
 
78
  }
79
  });
80
 
81
+ textGenForm.addEventListener('submit', async (event) => {
82
  event.preventDefault();
83
 
84
+ const textGenInput = document.getElementById('text-gen-input');
85
+ const textGenParagraph = document.querySelector('.text-gen-output');
86
 
87
  try {
88
  textGenParagraph.textContent = await translateText(textGenInput.value);
 
91
  }
92
  });
93
 
94
+ tableButtonPrev.addEventListener('click', () => {
95
+ cursor = cursor > RANGE ? cursor - RANGE : 0;
96
+
97
+ if (cursor < RANGE) {
98
+ tableButtonPrev.classList.add('hidden');
99
+ }
100
+ if (cursor < LIMIT - RANGE) {
101
+ tableButtonNext.classList.remove('hidden');
102
+ }
103
+
104
+ updateTable(cursor);
105
+ });
106
+
107
+ tableButtonNext.addEventListener('click', () => {
108
+ cursor = cursor < LIMIT - RANGE ? cursor + RANGE : cursor;
109
+
110
+ if (cursor >= RANGE) {
111
+ tableButtonPrev.classList.remove('hidden');
112
+ }
113
+ if (cursor >= LIMIT - RANGE) {
114
+ tableButtonNext.classList.add('hidden');
115
+ }
116
+
117
+ updateTable(cursor);
118
+ });
119
+
120
  textToImage(imageGenSelect.value)
121
  .then((image) => (imageGenImage.src = image))
122
  .catch(console.error);
123
+
124
+ updateTable(cursor)
125
+ .then((image) => (imageGenImage.src = image))
126
+ .catch(console.error);
inference.py CHANGED
@@ -3,8 +3,9 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  tokenizer = T5Tokenizer.from_pretrained("t5-small")
4
  model = T5ForConditionalGeneration.from_pretrained("t5-small")
5
 
6
- def t5_infer(input):
7
- input_ids = tokenizer(input, return_tensors="pt").input_ids
8
- outputs = model.generate(input_ids)
9
 
10
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
3
  tokenizer = T5Tokenizer.from_pretrained("t5-small")
4
  model = T5ForConditionalGeneration.from_pretrained("t5-small")
5
 
 
 
 
6
 
7
+ def infer_t5(input):
8
+ input_ids = tokenizer(input, return_tensors="pt").input_ids
9
+ outputs = model.generate(input_ids)
10
+
11
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
style.css CHANGED
@@ -12,11 +12,14 @@ body.dark-theme {
12
  }
13
 
14
  main {
 
 
 
 
 
15
  display: flex;
16
  flex-direction: column;
17
  align-items: center;
18
- max-width: 80rem;
19
- text-align: center;
20
  }
21
 
22
  a {
@@ -40,9 +43,40 @@ input {
40
  width: 70%;
41
  }
42
 
 
 
 
 
43
  .text-gen-output {
44
- min-height: 1rem;
45
- margin: 0;
46
  align-self: start;
47
- border: 2px solid var(--text);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  }
 
12
  }
13
 
14
  main {
15
+ max-width: 80rem;
16
+ text-align: center;
17
+ }
18
+
19
+ section {
20
  display: flex;
21
  flex-direction: column;
22
  align-items: center;
 
 
23
  }
24
 
25
  a {
 
43
  width: 70%;
44
  }
45
 
46
+ button {
47
+ cursor: pointer;
48
+ }
49
+
50
  .text-gen-output {
51
+ min-height: 1.2rem;
52
+ margin: 1rem;
53
  align-self: start;
54
+ border: 0.5px solid grey;
55
+ }
56
+
57
+ #dataset button {
58
+ width: 6rem;
59
+ margin: 0.5rem;
60
+ }
61
+
62
+ #dataset button.hidden {
63
+ visibility: hidden;
64
+ }
65
+
66
+ table {
67
+ max-width: 40rem;
68
+ text-align: left;
69
+ border-collapse: collapse;
70
+ }
71
+
72
+ thead {
73
+ font-weight: bold;
74
+ }
75
+
76
+ td {
77
+ padding: 0.5rem;
78
+ }
79
+
80
+ td:not(thead td) {
81
+ border: 0.5px solid grey;
82
  }