Skip to content

Commit 206cef5

Browse files
committed
Close socket when any exception was raised
1 parent 2b5c11e commit 206cef5

File tree

1 file changed

+42
-63
lines changed

1 file changed

+42
-63
lines changed

memcache.py

+42-63
Original file line numberDiff line numberDiff line change
@@ -488,13 +488,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
488488
for key in server_keys[server]: # These are mangled keys
489489
cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n')
490490
write(cmd)
491-
try:
491+
with _socket_guard(server, (socket.error,)) as sg:
492492
server.send_cmds(b''.join(bigcmd))
493-
except socket.error as msg:
493+
if sg.interrupted:
494494
rc = 0
495-
if isinstance(msg, tuple):
496-
msg = msg[1]
497-
server.mark_dead(msg)
498495
dead_servers.append(server)
499496

500497
# if noreply, just return
@@ -506,13 +503,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
506503
del server_keys[server]
507504

508505
for server, keys in six.iteritems(server_keys):
509-
try:
506+
with _socket_guard(server, (socket.error,)) as sg:
510507
for key in keys:
511508
server.expect(b"DELETED")
512-
except socket.error as msg:
513-
if isinstance(msg, tuple):
514-
msg = msg[1]
515-
server.mark_dead(msg)
509+
if sg.interrupted:
516510
rc = 0
517511
return rc
518512

@@ -558,7 +552,7 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
558552
headers = None
559553
fullcmd = self._encode_cmd(cmd, key, headers, noreply)
560554

561-
try:
555+
with _socket_guard(server, (socket.error,)):
562556
server.send_cmd(fullcmd)
563557
if noreply:
564558
return 1
@@ -567,10 +561,6 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
567561
return 1
568562
self.debuglog('%s expected %s, got: %r'
569563
% (cmd, ' or '.join(expected), line))
570-
except socket.error as msg:
571-
if isinstance(msg, tuple):
572-
msg = msg[1]
573-
server.mark_dead(msg)
574564
return 0
575565

576566
def incr(self, key, delta=1, noreply=False):
@@ -633,19 +623,14 @@ def _incrdecr(self, cmd, key, delta, noreply=False):
633623
return None
634624
self._statlog(cmd)
635625
fullcmd = self._encode_cmd(cmd, key, str(delta), noreply)
636-
try:
626+
with _socket_guard(server, (socket.error,)):
637627
server.send_cmd(fullcmd)
638628
if noreply:
639629
return
640630
line = server.readline()
641631
if line is None or line.strip() == b'NOT_FOUND':
642632
return None
643633
return int(line)
644-
except socket.error as msg:
645-
if isinstance(msg, tuple):
646-
msg = msg[1]
647-
server.mark_dead(msg)
648-
return None
649634

650635
def add(self, key, val, time=0, min_compress_len=0, noreply=False):
651636
'''Add new key with value.
@@ -902,7 +887,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
902887
for server in six.iterkeys(server_keys):
903888
bigcmd = []
904889
write = bigcmd.append
905-
try:
890+
with _socket_guard(server, (socket.error,)) as sg:
906891
for key in server_keys[server]: # These are mangled keys
907892
store_info = self._val_to_store_info(
908893
mapping[prefixed_to_orig_key[key]],
@@ -917,10 +902,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
917902
else:
918903
notstored.append(prefixed_to_orig_key[key])
919904
server.send_cmds(b''.join(bigcmd))
920-
except socket.error as msg:
921-
if isinstance(msg, tuple):
922-
msg = msg[1]
923-
server.mark_dead(msg)
905+
if sg.interrupted:
924906
dead_servers.append(server)
925907

926908
# if noreply, just return early
@@ -936,17 +918,13 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
936918
return list(mapping.keys())
937919

938920
for server, keys in six.iteritems(server_keys):
939-
try:
921+
with _socket_guard(server, (_Error, socket.error)):
940922
for key in keys:
941923
if server.readline() == b'STORED':
942924
continue
943925
else:
944926
# un-mangle.
945927
notstored.append(prefixed_to_orig_key[key])
946-
except (_Error, socket.error) as msg:
947-
if isinstance(msg, tuple):
948-
msg = msg[1]
949-
server.mark_dead(msg)
950928
return notstored
951929

952930
def _val_to_store_info(self, val, min_compress_len):
@@ -1032,15 +1010,11 @@ def _unsafe_set():
10321010
fullcmd = self._encode_cmd(cmd, key, headers, noreply,
10331011
b'\r\n', encoded_val)
10341012

1035-
try:
1013+
with _socket_guard(server, (socket.error,)):
10361014
server.send_cmd(fullcmd)
10371015
if noreply:
10381016
return True
10391017
return server.expect(b"STORED", raise_exception=True) == b"STORED"
1040-
except socket.error as msg:
1041-
if isinstance(msg, tuple):
1042-
msg = msg[1]
1043-
server.mark_dead(msg)
10441018
return 0
10451019

10461020
try:
@@ -1065,7 +1039,7 @@ def _get(self, cmd, key):
10651039
def _unsafe_get():
10661040
self._statlog(cmd)
10671041

1068-
try:
1042+
with _socket_guard(server, (_Error, socket.error)):
10691043
cmd_bytes = cmd.encode('utf-8') if six.PY3 else cmd
10701044
fullcmd = b''.join((cmd_bytes, b' ', key))
10711045
server.send_cmd(fullcmd)
@@ -1085,16 +1059,9 @@ def _unsafe_get():
10851059
if not rkey:
10861060
return None
10871061
try:
1088-
value = self._recv_value(server, flags, rlen)
1062+
return self._recv_value(server, flags, rlen)
10891063
finally:
10901064
server.expect(b"END", raise_exception=True)
1091-
except (_Error, socket.error) as msg:
1092-
if isinstance(msg, tuple):
1093-
msg = msg[1]
1094-
server.mark_dead(msg)
1095-
return None
1096-
1097-
return value
10981065

10991066
try:
11001067
return _unsafe_get()
@@ -1185,13 +1152,10 @@ def get_multi(self, keys, key_prefix=''):
11851152
# send out all requests on each server before reading anything
11861153
dead_servers = []
11871154
for server in six.iterkeys(server_keys):
1188-
try:
1155+
with _socket_guard(server, (socket.error,)) as sg:
11891156
fullcmd = b"get " + b" ".join(server_keys[server])
11901157
server.send_cmd(fullcmd)
1191-
except socket.error as msg:
1192-
if isinstance(msg, tuple):
1193-
msg = msg[1]
1194-
server.mark_dead(msg)
1158+
if sg.interrupted:
11951159
dead_servers.append(server)
11961160

11971161
# if any servers died on the way, don't expect them to respond.
@@ -1200,7 +1164,7 @@ def get_multi(self, keys, key_prefix=''):
12001164

12011165
retvals = {}
12021166
for server in six.iterkeys(server_keys):
1203-
try:
1167+
with _socket_guard(server, (_Error, socket.error)):
12041168
line = server.readline()
12051169
while line and line != b'END':
12061170
rkey, flags, rlen = self._expectvalue(server, line)
@@ -1210,10 +1174,6 @@ def get_multi(self, keys, key_prefix=''):
12101174
# un-prefix returned key.
12111175
retvals[prefixed_to_orig_key[rkey]] = val
12121176
line = server.readline()
1213-
except (_Error, socket.error) as msg:
1214-
if isinstance(msg, tuple):
1215-
msg = msg[1]
1216-
server.mark_dead(msg)
12171177
return retvals
12181178

12191179
def _expect_cas_value(self, server, line=None, raise_exception=False):
@@ -1394,15 +1354,10 @@ def _get_socket(self):
13941354
s = socket.socket(self.family, socket.SOCK_STREAM)
13951355
if hasattr(s, 'settimeout'):
13961356
s.settimeout(self.socket_timeout)
1397-
try:
1357+
with _socket_guard(self, (socket.error,),
1358+
msg_tmpl='connect: {}') as sg:
13981359
s.connect(self.address)
1399-
except socket.timeout as msg:
1400-
self.mark_dead("connect: %s" % msg)
1401-
return None
1402-
except socket.error as msg:
1403-
if isinstance(msg, tuple):
1404-
msg = msg[1]
1405-
self.mark_dead("connect: %s" % msg)
1360+
if sg.interrupted:
14061361
return None
14071362
self.socket = s
14081363
self.buffer = b''
@@ -1497,6 +1452,30 @@ def __str__(self):
14971452
return "unix:%s%s" % (self.address, d)
14981453

14991454

1455+
class _socket_guard(object):
1456+
def __init__(self, server, exceptions, msg_tmpl='{}'):
1457+
self._server = server
1458+
self._exceptions = exceptions
1459+
self._msg_tmpl = msg_tmpl
1460+
self.interrupted = False
1461+
1462+
def __enter__(self):
1463+
return self
1464+
1465+
def __exit__(self, exc_type, exc, exc_tb):
1466+
if exc is not None:
1467+
self.interrupted = True
1468+
1469+
if isinstance(exc, self._exceptions):
1470+
msg = self._msg_tmpl.format(exc)
1471+
self._server.mark_dead(msg)
1472+
return True
1473+
elif exc is not None:
1474+
self._server.close_socket()
1475+
1476+
return False
1477+
1478+
15001479
def _doctest():
15011480
import doctest
15021481
import memcache

0 commit comments

Comments
 (0)