Skip to content

Commit 6a8cd8c

Browse files
woutdenolfdvora-h
authored andcommitted
Add unit tests for the connect method of all Redis connection classes (redis#2631)
* tests: move certificate discovery to a separate module * tests: add 'connect' tests for all Redis connection classes --------- Co-authored-by: dvora-h <[email protected]>
1 parent 63e06dd commit 6a8cd8c

File tree

5 files changed

+350
-24
lines changed

5 files changed

+350
-24
lines changed

tests/ssl_utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import os
2+
3+
4+
def get_ssl_filename(name):
5+
root = os.path.join(os.path.dirname(__file__), "..")
6+
cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys"))
7+
if not os.path.isdir(cert_dir): # github actions package validation case
8+
cert_dir = os.path.abspath(
9+
os.path.join(root, "..", "docker", "stunnel", "keys")
10+
)
11+
if not os.path.isdir(cert_dir):
12+
raise IOError(f"No SSL certificates found. They should be in {cert_dir}")
13+
14+
return os.path.join(cert_dir, name)

tests/test_asyncio/test_cluster.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import binascii
33
import datetime
4-
import os
54
import warnings
65
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
76
from urllib.parse import urlparse
@@ -36,6 +35,7 @@
3635
skip_unless_arch_bits,
3736
)
3837

38+
from ..ssl_utils import get_ssl_filename
3939
from .compat import mock
4040

4141
pytestmark = pytest.mark.onlycluster
@@ -2744,17 +2744,8 @@ class TestSSL:
27442744
appropriate port.
27452745
"""
27462746

2747-
ROOT = os.path.join(os.path.dirname(__file__), "../..")
2748-
CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys"))
2749-
if not os.path.isdir(CERT_DIR): # github actions package validation case
2750-
CERT_DIR = os.path.abspath(
2751-
os.path.join(ROOT, "..", "docker", "stunnel", "keys")
2752-
)
2753-
if not os.path.isdir(CERT_DIR):
2754-
raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}")
2755-
2756-
SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem")
2757-
SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem")
2747+
SERVER_CERT = get_ssl_filename("server-cert.pem")
2748+
SERVER_KEY = get_ssl_filename("server-key.pem")
27582749

27592750
@pytest_asyncio.fixture()
27602751
def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]:

tests/test_asyncio/test_connect.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import asyncio
2+
import logging
3+
import re
4+
import socket
5+
import ssl
6+
7+
import pytest
8+
9+
from redis.asyncio.connection import (
10+
Connection,
11+
SSLConnection,
12+
UnixDomainSocketConnection,
13+
)
14+
15+
from ..ssl_utils import get_ssl_filename
16+
17+
_logger = logging.getLogger(__name__)
18+
19+
20+
_CLIENT_NAME = "test-suite-client"
21+
_CMD_SEP = b"\r\n"
22+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
23+
_ERROR_RESP = b"-ERR" + _CMD_SEP
24+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
25+
26+
27+
@pytest.fixture
28+
def tcp_address():
29+
with socket.socket() as sock:
30+
sock.bind(("127.0.0.1", 0))
31+
return sock.getsockname()
32+
33+
34+
@pytest.fixture
35+
def uds_address(tmpdir):
36+
return tmpdir / "uds.sock"
37+
38+
39+
async def test_tcp_connect(tcp_address):
40+
host, port = tcp_address
41+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
42+
await _assert_connect(conn, tcp_address)
43+
44+
45+
async def test_uds_connect(uds_address):
46+
path = str(uds_address)
47+
conn = UnixDomainSocketConnection(
48+
path=path, client_name=_CLIENT_NAME, socket_timeout=10
49+
)
50+
await _assert_connect(conn, path)
51+
52+
53+
@pytest.mark.ssl
54+
async def test_tcp_ssl_connect(tcp_address):
55+
host, port = tcp_address
56+
certfile = get_ssl_filename("server-cert.pem")
57+
keyfile = get_ssl_filename("server-key.pem")
58+
conn = SSLConnection(
59+
host=host,
60+
port=port,
61+
client_name=_CLIENT_NAME,
62+
ssl_ca_certs=certfile,
63+
socket_timeout=10,
64+
)
65+
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
66+
67+
68+
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
69+
stop_event = asyncio.Event()
70+
finished = asyncio.Event()
71+
72+
async def _handler(reader, writer):
73+
try:
74+
return await _redis_request_handler(reader, writer, stop_event)
75+
finally:
76+
finished.set()
77+
78+
if isinstance(server_address, str):
79+
server = await asyncio.start_unix_server(_handler, path=server_address)
80+
elif certfile:
81+
host, port = server_address
82+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
83+
context.minimum_version = ssl.TLSVersion.TLSv1_2
84+
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
85+
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
86+
else:
87+
host, port = server_address
88+
server = await asyncio.start_server(_handler, host=host, port=port)
89+
90+
async with server as aserver:
91+
await aserver.start_serving()
92+
try:
93+
await conn.connect()
94+
await conn.disconnect()
95+
finally:
96+
stop_event.set()
97+
aserver.close()
98+
await aserver.wait_closed()
99+
await finished.wait()
100+
101+
102+
async def _redis_request_handler(reader, writer, stop_event):
103+
buffer = b""
104+
command = None
105+
command_ptr = None
106+
fragment_length = None
107+
while not stop_event.is_set() or buffer:
108+
_logger.info(str(stop_event.is_set()))
109+
try:
110+
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5)
111+
except TimeoutError:
112+
continue
113+
if not buffer:
114+
continue
115+
parts = re.split(_CMD_SEP, buffer)
116+
buffer = parts[-1]
117+
for fragment in parts[:-1]:
118+
fragment = fragment.decode()
119+
_logger.info("Command fragment: %s", fragment)
120+
121+
if fragment.startswith("*") and command is None:
122+
command = [None for _ in range(int(fragment[1:]))]
123+
command_ptr = 0
124+
fragment_length = None
125+
continue
126+
127+
if fragment.startswith("$") and command[command_ptr] is None:
128+
fragment_length = int(fragment[1:])
129+
continue
130+
131+
assert len(fragment) == fragment_length
132+
command[command_ptr] = fragment
133+
command_ptr += 1
134+
135+
if command_ptr < len(command):
136+
continue
137+
138+
command = " ".join(command)
139+
_logger.info("Command %s", command)
140+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
141+
_logger.info("Response from %s", resp)
142+
writer.write(resp)
143+
await writer.drain()
144+
command = None
145+
_logger.info("Exit handler")

tests/test_connect.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import logging
2+
import re
3+
import socket
4+
import socketserver
5+
import ssl
6+
import threading
7+
8+
import pytest
9+
10+
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
11+
12+
from .ssl_utils import get_ssl_filename
13+
14+
_logger = logging.getLogger(__name__)
15+
16+
17+
_CLIENT_NAME = "test-suite-client"
18+
_CMD_SEP = b"\r\n"
19+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
20+
_ERROR_RESP = b"-ERR" + _CMD_SEP
21+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
22+
23+
24+
@pytest.fixture
25+
def tcp_address():
26+
with socket.socket() as sock:
27+
sock.bind(("127.0.0.1", 0))
28+
return sock.getsockname()
29+
30+
31+
@pytest.fixture
32+
def uds_address(tmpdir):
33+
return tmpdir / "uds.sock"
34+
35+
36+
def test_tcp_connect(tcp_address):
37+
host, port = tcp_address
38+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
39+
_assert_connect(conn, tcp_address)
40+
41+
42+
def test_uds_connect(uds_address):
43+
path = str(uds_address)
44+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
45+
_assert_connect(conn, path)
46+
47+
48+
@pytest.mark.ssl
49+
def test_tcp_ssl_connect(tcp_address):
50+
host, port = tcp_address
51+
certfile = get_ssl_filename("server-cert.pem")
52+
keyfile = get_ssl_filename("server-key.pem")
53+
conn = SSLConnection(
54+
host=host,
55+
port=port,
56+
client_name=_CLIENT_NAME,
57+
ssl_ca_certs=certfile,
58+
socket_timeout=10,
59+
)
60+
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
61+
62+
63+
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
64+
if isinstance(server_address, str):
65+
server = _RedisUDSServer(server_address, _RedisRequestHandler)
66+
else:
67+
server = _RedisTCPServer(
68+
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile
69+
)
70+
with server as aserver:
71+
t = threading.Thread(target=aserver.serve_forever)
72+
t.start()
73+
try:
74+
aserver.wait_online()
75+
conn.connect()
76+
conn.disconnect()
77+
finally:
78+
aserver.stop()
79+
t.join(timeout=5)
80+
81+
82+
class _RedisTCPServer(socketserver.TCPServer):
83+
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None:
84+
self._ready_event = threading.Event()
85+
self._stop_requested = False
86+
self._certfile = certfile
87+
self._keyfile = keyfile
88+
super().__init__(*args, **kw)
89+
90+
def service_actions(self):
91+
self._ready_event.set()
92+
93+
def wait_online(self):
94+
self._ready_event.wait()
95+
96+
def stop(self):
97+
self._stop_requested = True
98+
self.shutdown()
99+
100+
def is_serving(self):
101+
return not self._stop_requested
102+
103+
def get_request(self):
104+
if self._certfile is None:
105+
return super().get_request()
106+
newsocket, fromaddr = self.socket.accept()
107+
connstream = ssl.wrap_socket(
108+
newsocket,
109+
server_side=True,
110+
certfile=self._certfile,
111+
keyfile=self._keyfile,
112+
ssl_version=ssl.PROTOCOL_TLSv1_2,
113+
)
114+
return connstream, fromaddr
115+
116+
117+
class _RedisUDSServer(socketserver.UnixStreamServer):
118+
def __init__(self, *args, **kw) -> None:
119+
self._ready_event = threading.Event()
120+
self._stop_requested = False
121+
super().__init__(*args, **kw)
122+
123+
def service_actions(self):
124+
self._ready_event.set()
125+
126+
def wait_online(self):
127+
self._ready_event.wait()
128+
129+
def stop(self):
130+
self._stop_requested = True
131+
self.shutdown()
132+
133+
def is_serving(self):
134+
return not self._stop_requested
135+
136+
137+
class _RedisRequestHandler(socketserver.StreamRequestHandler):
138+
def setup(self):
139+
_logger.info("%s connected", self.client_address)
140+
141+
def finish(self):
142+
_logger.info("%s disconnected", self.client_address)
143+
144+
def handle(self):
145+
buffer = b""
146+
command = None
147+
command_ptr = None
148+
fragment_length = None
149+
while self.server.is_serving() or buffer:
150+
try:
151+
buffer += self.request.recv(1024)
152+
except socket.timeout:
153+
continue
154+
if not buffer:
155+
continue
156+
parts = re.split(_CMD_SEP, buffer)
157+
buffer = parts[-1]
158+
for fragment in parts[:-1]:
159+
fragment = fragment.decode()
160+
_logger.info("Command fragment: %s", fragment)
161+
162+
if fragment.startswith("*") and command is None:
163+
command = [None for _ in range(int(fragment[1:]))]
164+
command_ptr = 0
165+
fragment_length = None
166+
continue
167+
168+
if fragment.startswith("$") and command[command_ptr] is None:
169+
fragment_length = int(fragment[1:])
170+
continue
171+
172+
assert len(fragment) == fragment_length
173+
command[command_ptr] = fragment
174+
command_ptr += 1
175+
176+
if command_ptr < len(command):
177+
continue
178+
179+
command = " ".join(command)
180+
_logger.info("Command %s", command)
181+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
182+
_logger.info("Response %s", resp)
183+
self.request.sendall(resp)
184+
command = None
185+
_logger.info("Exit handler")

0 commit comments

Comments
 (0)