Jae-Won Chung commited on
Commit
7aacedb
1 Parent(s): 8cf1cb0

Separate out tests

Browse files
setup.py CHANGED
@@ -12,6 +12,7 @@ extras_require = {
12
  "text_generation @ git+https://github.com/ml-energy/text_generation_energy@master",
13
  ],
14
  "benchmark": ["zeus-ml", "fschat==0.2.23", "tyro", "rich"],
 
15
  }
16
 
17
  extras_require["all"] = list(set(sum(extras_require.values(), [])))
 
12
  "text_generation @ git+https://github.com/ml-energy/text_generation_energy@master",
13
  ],
14
  "benchmark": ["zeus-ml", "fschat==0.2.23", "tyro", "rich"],
15
+ "dev": ["pytest"],
16
  }
17
 
18
  extras_require["all"] = list(set(sum(extras_require.values(), [])))
spitfight/colosseum/client.py CHANGED
@@ -1,10 +1,8 @@
1
  from __future__ import annotations
2
 
3
  import json
4
- import unittest
5
  import contextlib
6
  from uuid import uuid4, UUID
7
- from copy import deepcopy
8
  from typing import Generator, Literal
9
 
10
  import requests
@@ -92,15 +90,3 @@ def _check_response(response: requests.Response) -> None:
92
  raise gr.Error(response.json()["detail"])
93
  elif response.status_code >= 500:
94
  raise gr.Error("Failed to talk to our backend server. Please try again later.")
95
-
96
-
97
- class TestControllerClient(unittest.TestCase):
98
- def test_new_uuid_on_deepcopy(self):
99
- client = ControllerClient("http://localhost:8000")
100
- clients = [client.fork() for _ in range(50)]
101
- request_ids = [client.request_id for client in clients]
102
- assert len(set(request_ids)) == len(request_ids)
103
-
104
-
105
- if __name__ == "__main__":
106
- unittest.main()
 
1
  from __future__ import annotations
2
 
3
  import json
 
4
  import contextlib
5
  from uuid import uuid4, UUID
 
6
  from typing import Generator, Literal
7
 
8
  import requests
 
90
  raise gr.Error(response.json()["detail"])
91
  elif response.status_code >= 500:
92
  raise gr.Error("Failed to talk to our backend server. Please try again later.")
 
 
 
 
 
 
 
 
 
 
 
 
spitfight/utils.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
3
  import time
4
  import heapq
5
  import asyncio
6
- import unittest
7
  from typing import TypeVar, Generic, AsyncGenerator, Any, Coroutine
8
 
9
  from fastapi.logger import logger
@@ -178,128 +177,3 @@ class TokenGenerationBuffer:
178
  break
179
 
180
  return return_buffer or None
181
-
182
-
183
-
184
- class TestTokenGenerationBuffer(unittest.TestCase):
185
- def test_basic1(self):
186
- buffer = TokenGenerationBuffer(stop_str="stop")
187
-
188
- buffer.append("hello")
189
- self.assertEqual(buffer.pop(), "hello")
190
- self.assertEqual(buffer.pop(), None)
191
- self.assertFalse(buffer.matched_stop_str)
192
-
193
- buffer.append("world")
194
- self.assertEqual(buffer.pop(), "world")
195
- self.assertFalse(buffer.matched_stop_str)
196
-
197
- buffer.append("stop")
198
- self.assertEqual(buffer.pop(), None)
199
- self.assertTrue(buffer.matched_stop_str)
200
- self.assertEqual(buffer.pop(), None)
201
- self.assertTrue(buffer.matched_stop_str)
202
- self.assertEqual(buffer.pop(), None)
203
- self.assertTrue(buffer.matched_stop_str)
204
- self.assertEqual(buffer.pop(), None)
205
- self.assertTrue(buffer.matched_stop_str)
206
-
207
- def test_basic2(self):
208
- buffer = TokenGenerationBuffer(stop_str="stop")
209
-
210
- buffer.append("hi")
211
- self.assertEqual(buffer.pop(), "hi")
212
- self.assertFalse(buffer.matched_stop_str)
213
-
214
- buffer.append("stole")
215
- self.assertEqual(buffer.pop(), "stole")
216
- self.assertFalse(buffer.matched_stop_str)
217
-
218
- buffer.append("sto")
219
- self.assertEqual(buffer.pop(), None)
220
- self.assertFalse(buffer.matched_stop_str)
221
-
222
- buffer.append("ic")
223
- self.assertEqual(buffer.pop(), "stoic")
224
- self.assertFalse(buffer.matched_stop_str)
225
-
226
- buffer.append("st")
227
- self.assertEqual(buffer.pop(), None)
228
- self.assertFalse(buffer.matched_stop_str)
229
-
230
- buffer.append("opper")
231
- self.assertEqual(buffer.pop(), "stopper")
232
- self.assertFalse(buffer.matched_stop_str)
233
-
234
- buffer.append("sto")
235
- self.assertEqual(buffer.pop(), None)
236
- self.assertFalse(buffer.matched_stop_str)
237
-
238
- buffer.append("p")
239
- self.assertEqual(buffer.pop(), None)
240
- self.assertTrue(buffer.matched_stop_str)
241
-
242
- def test_falcon1(self):
243
- buffer = TokenGenerationBuffer(stop_str="\nUser")
244
-
245
- buffer.append("Hi")
246
- self.assertEqual(buffer.pop(), "Hi")
247
- self.assertFalse(buffer.matched_stop_str)
248
-
249
- buffer.append("!")
250
- self.assertEqual(buffer.pop(), "!")
251
- self.assertFalse(buffer.matched_stop_str)
252
-
253
- buffer.append("\n")
254
- self.assertEqual(buffer.pop(), None)
255
- self.assertFalse(buffer.matched_stop_str)
256
-
257
- buffer.append("User")
258
- self.assertEqual(buffer.pop(), None)
259
- self.assertTrue(buffer.matched_stop_str)
260
-
261
- def test_falcon2(self):
262
- buffer = TokenGenerationBuffer(stop_str="\nUser")
263
-
264
- buffer.append("\n")
265
- self.assertEqual(buffer.pop(), None)
266
- self.assertFalse(buffer.matched_stop_str)
267
-
268
- buffer.append("\n")
269
- self.assertEqual(buffer.pop(), "\n")
270
- self.assertFalse(buffer.matched_stop_str)
271
-
272
- buffer.append("\n")
273
- self.assertEqual(buffer.pop(), "\n")
274
- self.assertFalse(buffer.matched_stop_str)
275
-
276
- buffer.append("\n")
277
- self.assertEqual(buffer.pop(), "\n")
278
- self.assertFalse(buffer.matched_stop_str)
279
-
280
- buffer.append("User")
281
- self.assertEqual(buffer.pop(), None)
282
- self.assertEqual(buffer.pop(), None)
283
- self.assertTrue(buffer.matched_stop_str)
284
-
285
- def test_no_stop_str(self):
286
- buffer = TokenGenerationBuffer(stop_str=None)
287
-
288
- buffer.append("hello")
289
- self.assertEqual(buffer.pop(), "hello")
290
- self.assertEqual(buffer.pop(), None)
291
- self.assertFalse(buffer.matched_stop_str)
292
-
293
- buffer.append("world")
294
- self.assertEqual(buffer.pop(), "world")
295
- self.assertEqual(buffer.pop(), None)
296
- self.assertFalse(buffer.matched_stop_str)
297
-
298
- buffer.append("\n")
299
- self.assertEqual(buffer.pop(), "\n")
300
- self.assertEqual(buffer.pop(), None)
301
- self.assertFalse(buffer.matched_stop_str)
302
-
303
-
304
- if __name__ == "__main__":
305
- unittest.main()
 
3
  import time
4
  import heapq
5
  import asyncio
 
6
  from typing import TypeVar, Generic, AsyncGenerator, Any, Coroutine
7
 
8
  from fastapi.logger import logger
 
177
  break
178
 
179
  return return_buffer or None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/colosseum/test_client.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from spitfight.colosseum.client import ControllerClient
4
+
5
+
6
+ def test_new_uuid_on_deepcopy():
7
+ client = ControllerClient("http://localhost:8000")
8
+ clients = [client.fork() for _ in range(50)]
9
+ request_ids = [client.request_id for client in clients]
10
+ assert len(set(request_ids)) == len(request_ids)
tests/test_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from spitfight.utils import TokenGenerationBuffer
4
+
5
+
6
+ def test_basic1():
7
+ buffer = TokenGenerationBuffer(stop_str="stop")
8
+
9
+ buffer.append("hello")
10
+ assert buffer.pop() == "hello"
11
+ assert buffer.pop() == None
12
+ assert not buffer.matched_stop_str
13
+
14
+ buffer.append("world")
15
+ assert buffer.pop() == "world"
16
+ assert not buffer.matched_stop_str
17
+
18
+ buffer.append("stop")
19
+ assert buffer.pop() == None
20
+ assert buffer.matched_stop_str
21
+ assert buffer.pop() == None
22
+ assert buffer.matched_stop_str
23
+ assert buffer.pop() == None
24
+ assert buffer.matched_stop_str
25
+ assert buffer.pop() == None
26
+ assert buffer.matched_stop_str
27
+
28
+ def test_basic2():
29
+ buffer = TokenGenerationBuffer(stop_str="stop")
30
+
31
+ buffer.append("hi")
32
+ assert buffer.pop() == "hi"
33
+ assert not buffer.matched_stop_str
34
+
35
+ buffer.append("stole")
36
+ assert buffer.pop() == "stole"
37
+ assert not buffer.matched_stop_str
38
+
39
+ buffer.append("sto")
40
+ assert buffer.pop() == None
41
+ assert not buffer.matched_stop_str
42
+
43
+ buffer.append("ic")
44
+ assert buffer.pop() == "stoic"
45
+ assert not buffer.matched_stop_str
46
+
47
+ buffer.append("st")
48
+ assert buffer.pop() == None
49
+ assert not buffer.matched_stop_str
50
+
51
+ buffer.append("opper")
52
+ assert buffer.pop() == "stopper"
53
+ assert not buffer.matched_stop_str
54
+
55
+ buffer.append("sto")
56
+ assert buffer.pop() == None
57
+ assert not buffer.matched_stop_str
58
+
59
+ buffer.append("p")
60
+ assert buffer.pop() == None
61
+ assert buffer.matched_stop_str
62
+
63
+ def test_falcon1():
64
+ buffer = TokenGenerationBuffer(stop_str="\nUser")
65
+
66
+ buffer.append("Hi")
67
+ assert buffer.pop() == "Hi"
68
+ assert not buffer.matched_stop_str
69
+
70
+ buffer.append("!")
71
+ assert buffer.pop() == "!"
72
+ assert not buffer.matched_stop_str
73
+
74
+ buffer.append("\n")
75
+ assert buffer.pop() == None
76
+ assert not buffer.matched_stop_str
77
+
78
+ buffer.append("User")
79
+ assert buffer.pop() == None
80
+ assert buffer.matched_stop_str
81
+
82
+ def test_falcon2():
83
+ buffer = TokenGenerationBuffer(stop_str="\nUser")
84
+
85
+ buffer.append("\n")
86
+ assert buffer.pop() == None
87
+ assert not buffer.matched_stop_str
88
+
89
+ buffer.append("\n")
90
+ assert buffer.pop() == "\n"
91
+ assert not buffer.matched_stop_str
92
+
93
+ buffer.append("\n")
94
+ assert buffer.pop() == "\n"
95
+ assert not buffer.matched_stop_str
96
+
97
+ buffer.append("\n")
98
+ assert buffer.pop() == "\n"
99
+ assert not buffer.matched_stop_str
100
+
101
+ buffer.append("User")
102
+ assert buffer.pop() == None
103
+ assert buffer.pop() == None
104
+ assert buffer.matched_stop_str
105
+
106
+ def test_no_stop_str():
107
+ buffer = TokenGenerationBuffer(stop_str=None)
108
+
109
+ buffer.append("hello")
110
+ assert buffer.pop() == "hello"
111
+ assert buffer.pop() == None
112
+ assert not buffer.matched_stop_str
113
+
114
+ buffer.append("world")
115
+ assert buffer.pop() == "world"
116
+ assert buffer.pop() == None
117
+ assert not buffer.matched_stop_str
118
+
119
+ buffer.append("\n")
120
+ assert buffer.pop() == "\n"
121
+ assert buffer.pop() == None
122
+ assert not buffer.matched_stop_str