@@ -488,13 +488,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
488
488
for key in server_keys [server ]: # These are mangled keys
489
489
cmd = self ._encode_cmd ('delete' , key , headers , noreply , b'\r \n ' )
490
490
write (cmd )
491
- try :
491
+ with _socket_guard ( server , ( socket . error ,)) as sg :
492
492
server .send_cmds (b'' .join (bigcmd ))
493
- except socket . error as msg :
493
+ if sg . interrupted :
494
494
rc = 0
495
- if isinstance (msg , tuple ):
496
- msg = msg [1 ]
497
- server .mark_dead (msg )
498
495
dead_servers .append (server )
499
496
500
497
# if noreply, just return
@@ -506,13 +503,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
506
503
del server_keys [server ]
507
504
508
505
for server , keys in six .iteritems (server_keys ):
509
- try :
506
+ with _socket_guard ( server , ( socket . error ,)) as sg :
510
507
for key in keys :
511
508
server .expect (b"DELETED" )
512
- except socket .error as msg :
513
- if isinstance (msg , tuple ):
514
- msg = msg [1 ]
515
- server .mark_dead (msg )
509
+ if sg .interrupted :
516
510
rc = 0
517
511
return rc
518
512
@@ -558,7 +552,7 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
558
552
headers = None
559
553
fullcmd = self ._encode_cmd (cmd , key , headers , noreply )
560
554
561
- try :
555
+ with _socket_guard ( server , ( socket . error ,)) :
562
556
server .send_cmd (fullcmd )
563
557
if noreply :
564
558
return 1
@@ -567,10 +561,6 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
567
561
return 1
568
562
self .debuglog ('%s expected %s, got: %r'
569
563
% (cmd , ' or ' .join (expected ), line ))
570
- except socket .error as msg :
571
- if isinstance (msg , tuple ):
572
- msg = msg [1 ]
573
- server .mark_dead (msg )
574
564
return 0
575
565
576
566
def incr (self , key , delta = 1 , noreply = False ):
@@ -633,19 +623,14 @@ def _incrdecr(self, cmd, key, delta, noreply=False):
633
623
return None
634
624
self ._statlog (cmd )
635
625
fullcmd = self ._encode_cmd (cmd , key , str (delta ), noreply )
636
- try :
626
+ with _socket_guard ( server , ( socket . error ,)) :
637
627
server .send_cmd (fullcmd )
638
628
if noreply :
639
629
return
640
630
line = server .readline ()
641
631
if line is None or line .strip () == b'NOT_FOUND' :
642
632
return None
643
633
return int (line )
644
- except socket .error as msg :
645
- if isinstance (msg , tuple ):
646
- msg = msg [1 ]
647
- server .mark_dead (msg )
648
- return None
649
634
650
635
def add (self , key , val , time = 0 , min_compress_len = 0 , noreply = False ):
651
636
'''Add new key with value.
@@ -902,7 +887,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
902
887
for server in six .iterkeys (server_keys ):
903
888
bigcmd = []
904
889
write = bigcmd .append
905
- try :
890
+ with _socket_guard ( server , ( socket . error ,)) as sg :
906
891
for key in server_keys [server ]: # These are mangled keys
907
892
store_info = self ._val_to_store_info (
908
893
mapping [prefixed_to_orig_key [key ]],
@@ -917,10 +902,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
917
902
else :
918
903
notstored .append (prefixed_to_orig_key [key ])
919
904
server .send_cmds (b'' .join (bigcmd ))
920
- except socket .error as msg :
921
- if isinstance (msg , tuple ):
922
- msg = msg [1 ]
923
- server .mark_dead (msg )
905
+ if sg .interrupted :
924
906
dead_servers .append (server )
925
907
926
908
# if noreply, just return early
@@ -936,17 +918,13 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
936
918
return list (mapping .keys ())
937
919
938
920
for server , keys in six .iteritems (server_keys ):
939
- try :
921
+ with _socket_guard ( server , ( _Error , socket . error )) :
940
922
for key in keys :
941
923
if server .readline () == b'STORED' :
942
924
continue
943
925
else :
944
926
# un-mangle.
945
927
notstored .append (prefixed_to_orig_key [key ])
946
- except (_Error , socket .error ) as msg :
947
- if isinstance (msg , tuple ):
948
- msg = msg [1 ]
949
- server .mark_dead (msg )
950
928
return notstored
951
929
952
930
def _val_to_store_info (self , val , min_compress_len ):
@@ -1032,15 +1010,11 @@ def _unsafe_set():
1032
1010
fullcmd = self ._encode_cmd (cmd , key , headers , noreply ,
1033
1011
b'\r \n ' , encoded_val )
1034
1012
1035
- try :
1013
+ with _socket_guard ( server , ( socket . error ,)) :
1036
1014
server .send_cmd (fullcmd )
1037
1015
if noreply :
1038
1016
return True
1039
1017
return server .expect (b"STORED" , raise_exception = True ) == b"STORED"
1040
- except socket .error as msg :
1041
- if isinstance (msg , tuple ):
1042
- msg = msg [1 ]
1043
- server .mark_dead (msg )
1044
1018
return 0
1045
1019
1046
1020
try :
@@ -1065,7 +1039,7 @@ def _get(self, cmd, key):
1065
1039
def _unsafe_get ():
1066
1040
self ._statlog (cmd )
1067
1041
1068
- try :
1042
+ with _socket_guard ( server , ( _Error , socket . error )) :
1069
1043
cmd_bytes = cmd .encode ('utf-8' ) if six .PY3 else cmd
1070
1044
fullcmd = b'' .join ((cmd_bytes , b' ' , key ))
1071
1045
server .send_cmd (fullcmd )
@@ -1085,16 +1059,9 @@ def _unsafe_get():
1085
1059
if not rkey :
1086
1060
return None
1087
1061
try :
1088
- value = self ._recv_value (server , flags , rlen )
1062
+ return self ._recv_value (server , flags , rlen )
1089
1063
finally :
1090
1064
server .expect (b"END" , raise_exception = True )
1091
- except (_Error , socket .error ) as msg :
1092
- if isinstance (msg , tuple ):
1093
- msg = msg [1 ]
1094
- server .mark_dead (msg )
1095
- return None
1096
-
1097
- return value
1098
1065
1099
1066
try :
1100
1067
return _unsafe_get ()
@@ -1185,13 +1152,10 @@ def get_multi(self, keys, key_prefix=''):
1185
1152
# send out all requests on each server before reading anything
1186
1153
dead_servers = []
1187
1154
for server in six .iterkeys (server_keys ):
1188
- try :
1155
+ with _socket_guard ( server , ( socket . error ,)) as sg :
1189
1156
fullcmd = b"get " + b" " .join (server_keys [server ])
1190
1157
server .send_cmd (fullcmd )
1191
- except socket .error as msg :
1192
- if isinstance (msg , tuple ):
1193
- msg = msg [1 ]
1194
- server .mark_dead (msg )
1158
+ if sg .interrupted :
1195
1159
dead_servers .append (server )
1196
1160
1197
1161
# if any servers died on the way, don't expect them to respond.
@@ -1200,7 +1164,7 @@ def get_multi(self, keys, key_prefix=''):
1200
1164
1201
1165
retvals = {}
1202
1166
for server in six .iterkeys (server_keys ):
1203
- try :
1167
+ with _socket_guard ( server , ( _Error , socket . error )) :
1204
1168
line = server .readline ()
1205
1169
while line and line != b'END' :
1206
1170
rkey , flags , rlen = self ._expectvalue (server , line )
@@ -1210,10 +1174,6 @@ def get_multi(self, keys, key_prefix=''):
1210
1174
# un-prefix returned key.
1211
1175
retvals [prefixed_to_orig_key [rkey ]] = val
1212
1176
line = server .readline ()
1213
- except (_Error , socket .error ) as msg :
1214
- if isinstance (msg , tuple ):
1215
- msg = msg [1 ]
1216
- server .mark_dead (msg )
1217
1177
return retvals
1218
1178
1219
1179
def _expect_cas_value (self , server , line = None , raise_exception = False ):
@@ -1394,15 +1354,10 @@ def _get_socket(self):
1394
1354
s = socket .socket (self .family , socket .SOCK_STREAM )
1395
1355
if hasattr (s , 'settimeout' ):
1396
1356
s .settimeout (self .socket_timeout )
1397
- try :
1357
+ with _socket_guard (self , (socket .error ,),
1358
+ msg_tmpl = 'connect: {}' ) as sg :
1398
1359
s .connect (self .address )
1399
- except socket .timeout as msg :
1400
- self .mark_dead ("connect: %s" % msg )
1401
- return None
1402
- except socket .error as msg :
1403
- if isinstance (msg , tuple ):
1404
- msg = msg [1 ]
1405
- self .mark_dead ("connect: %s" % msg )
1360
+ if sg .interrupted :
1406
1361
return None
1407
1362
self .socket = s
1408
1363
self .buffer = b''
@@ -1497,6 +1452,30 @@ def __str__(self):
1497
1452
return "unix:%s%s" % (self .address , d )
1498
1453
1499
1454
1455
+ class _socket_guard (object ):
1456
+ def __init__ (self , server , exceptions , msg_tmpl = '{}' ):
1457
+ self ._server = server
1458
+ self ._exceptions = exceptions
1459
+ self ._msg_tmpl = msg_tmpl
1460
+ self .interrupted = False
1461
+
1462
+ def __enter__ (self ):
1463
+ return self
1464
+
1465
+ def __exit__ (self , exc_type , exc , exc_tb ):
1466
+ if exc is not None :
1467
+ self .interrupted = True
1468
+
1469
+ if isinstance (exc , self ._exceptions ):
1470
+ msg = self ._msg_tmpl .format (exc )
1471
+ self ._server .mark_dead (msg )
1472
+ return True
1473
+ elif exc is not None :
1474
+ self ._server .close_socket ()
1475
+
1476
+ return False
1477
+
1478
+
1500
1479
def _doctest ():
1501
1480
import doctest
1502
1481
import memcache
0 commit comments