|
1 | 1 | from __future__ import print_function
|
2 | 2 |
|
3 |
| -from itertools import chain |
4 |
| -import multiprocessing |
5 |
| -import os |
6 |
| -import signal |
7 | 3 | import socket
|
8 |
| -import sys |
9 |
| -import traceback |
10 | 4 | import unittest
|
11 | 5 |
|
12 | 6 | import six
|
@@ -212,70 +206,19 @@ def test_socket_error(self):
|
212 | 206 |
|
213 | 207 | def test_exception_handling(self):
|
214 | 208 | """Tests closing socket when custom exception raised"""
|
215 |
| - queue = multiprocessing.Queue() |
216 |
| - process = multiprocessing.Process(target=worker, args=(self.mc, queue)) |
217 |
| - process.start() |
218 |
| - if queue.get() != 'loop started': |
219 |
| - raise ValueError( |
220 |
| - 'Expected "loop started" message from the child process' |
221 |
| - ) |
| 209 | + class CustomException(Exception): |
| 210 | + pass |
222 | 211 |
|
223 |
| - # maximum test duration is 0.5 second |
224 |
| - num_iters = 50 |
225 |
| - timeout = 0.01 |
226 |
| - for i in range(num_iters): |
227 |
| - os.kill(process.pid, signal.SIGUSR1) |
| 212 | + self.mc.set('error', 1) |
| 213 | + with patch.object(self.mc, '_recv_value', |
| 214 | + Mock(side_effect=CustomException('custom error'))): |
228 | 215 | try:
|
229 |
| - exc = WorkerError(*queue.get(timeout=timeout)) |
230 |
| - raise exc |
231 |
| - except six.moves.queue.Empty: |
| 216 | + self.mc.get('error') |
| 217 | + except CustomException: |
232 | 218 | pass
|
233 |
| - if not process.is_alive(): |
234 |
| - break |
235 |
| - |
236 |
| - if process.is_alive(): |
237 |
| - os.kill(process.pid, signal.SIGTERM) |
238 |
| - process.join() |
239 |
| - |
240 |
| - |
241 |
| -class SignalException(Exception): |
242 |
| - pass |
243 |
| - |
244 |
| - |
245 |
| -def sighandler(signum, frame): |
246 |
| - raise SignalException() |
247 |
| - |
248 |
| - |
249 |
| -class WorkerError(Exception): |
250 |
| - def __init__(self, exc, assert_tb, signal_tb=None): |
251 |
| - super(WorkerError, self).__init__( |
252 |
| - ''.join(chain(assert_tb, signal_tb or [])) |
253 |
| - ) |
254 |
| - self.cause = exc |
255 |
| - |
256 |
| - |
257 |
| -def worker(mc, queue): |
258 |
| - signal.signal(signal.SIGUSR1, sighandler) |
259 |
| - |
260 |
| - signal_tb = None |
261 |
| - for i in range(100000): |
262 |
| - if i == 0: |
263 |
| - queue.put('loop started') |
264 |
| - try: |
265 |
| - k = str(i) |
266 |
| - mc.set(k, i) |
267 |
| - # This loop is just to increase chance to get previous value |
268 |
| - # for clarity |
269 |
| - for j in range(10): |
270 |
| - mc.get(str(i-1)) |
271 |
| - res = mc.get(k) |
272 |
| - assert res == i, 'Expected {} but was {}'.format(i, res) |
273 |
| - except AssertionError as e: |
274 |
| - assert_tb = traceback.format_exception(*sys.exc_info()) |
275 |
| - queue.put((e, assert_tb, signal_tb)) |
276 |
| - break |
277 |
| - except SignalException as e: |
278 |
| - signal_tb = traceback.format_exception(*sys.exc_info()) |
| 219 | + self.assertIs(self.mc.servers[0].socket, None) |
| 220 | + self.assertEqual(self.mc.set('error', 2), True) |
| 221 | + self.assertEqual(self.mc.get('error'), 2) |
279 | 222 |
|
280 | 223 |
|
281 | 224 | if __name__ == '__main__':
|
|
0 commit comments