test: error handling, state management, backoff, timeouts #1546

Draft: wants to merge 2 commits into base branch develop
339 changes: 203 additions & 136 deletions testinfra/test_ami_nix.py
@@ -7,9 +7,13 @@
import requests
import socket
import testinfra
import time
from botocore.exceptions import ClientError
from ec2instanceconnectcli.EC2InstanceConnectLogger import EC2InstanceConnectLogger
from ec2instanceconnectcli.EC2InstanceConnectKey import EC2InstanceConnectKey
from time import sleep
from typing import Optional, Dict, Any, List, Callable
from functools import wraps

# if GITHUB_RUN_ID is not set, use a default value that includes the user and hostname
RUN_ID = os.environ.get(
@@ -162,37 +166,72 @@
}}
"""

# Configure logging
logger = logging.getLogger("ami-tests")
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)


# scope='session' uses the same container for all the tests;
# scope='function' uses a new container per test function.
@pytest.fixture(scope="session")
def host():
ec2 = boto3.resource("ec2", region_name="ap-southeast-1")
images = list(
ec2.images.filter(
Filters=[{"Name": "name", "Values": [AMI_NAME]}],
)
)
assert len(images) == 1
image = images[0]

def gzip_then_base64_encode(s: str) -> str:
return base64.b64encode(gzip.compress(s.encode())).decode()

instance = list(
ec2.create_instances(
# Constants
MAX_RETRIES = 5
INITIAL_RETRY_DELAY = 2
MAX_RETRY_DELAY = 32
AWS_REGION = "ap-southeast-1"
INSTANCE_TYPE = "t4g.micro"
SECURITY_GROUPS = ["sg-0a883ca614ebfbae0", "sg-014d326be5a1627dc"]
SSH_PORT = 22
SSH_TIMEOUT = 60
HEALTH_CHECK_TIMEOUT = 300 # 5 minutes
HEALTH_CHECK_INTERVAL = 5

def retry_with_backoff(
max_retries: int = MAX_RETRIES,
initial_delay: int = INITIAL_RETRY_DELAY,
max_delay: int = MAX_RETRY_DELAY,
exceptions: tuple = (Exception,),
):
"""Decorator that implements exponential backoff for retrying operations."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
delay = initial_delay
for i in range(max_retries):
try:
return func(*args, **kwargs)
except exceptions as e:
if i == max_retries - 1:
raise
logger.warning(
f"Attempt {i + 1} failed: {str(e)}. Retrying in {delay} seconds..."
)
sleep(delay)
delay = min(delay * 2, max_delay)
return None
return wrapper
return decorator
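
With the constants above, the decorator sleeps 2, 4, 8, and 16 seconds between its five attempts, so a persistently failing call is abandoned after roughly 30 seconds of accumulated backoff. A minimal usage sketch, not part of the diff (the wrapped helper and the narrowed exception tuple are illustrative assumptions):

@retry_with_backoff(max_retries=3, initial_delay=1, exceptions=(ClientError,))
def describe_instance(ec2_client, instance_id: str) -> dict:
    # Hypothetical helper: a throttled DescribeInstances call is retried after
    # 1s and then 2s of backoff; the third failure propagates to the caller.
    return ec2_client.describe_instances(InstanceIds=[instance_id])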

def validate_aws_resources(ec2_client) -> None:
"""Validate AWS resources before instance creation."""
try:
# Check security groups
for sg in SECURITY_GROUPS:
ec2_client.describe_security_groups(GroupIds=[sg])
logger.info("AWS resources validation successful")
except ClientError as e:
logger.error(f"AWS resource validation failed: {str(e)}")
raise

def create_ec2_instance(ec2_resource, image_id: str, user_data: str) -> Any:
"""Create EC2 instance with proper error handling."""
try:
instances = ec2_resource.create_instances(
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"VolumeSize": 8, # gb
"VolumeSize": 8,
"Encrypted": True,
"DeleteOnTermination": True,
"VolumeType": "gp3",
@@ -203,35 +242,18 @@ def gzip_then_base64_encode(s: str) -> str:
"HttpTokens": "required",
"HttpEndpoint": "enabled",
},
IamInstanceProfile={"Name": "pg-ap-southeast-1"},
InstanceType="t4g.micro",
InstanceType=INSTANCE_TYPE,
MinCount=1,
MaxCount=1,
ImageId=image.id,
ImageId=image_id,
NetworkInterfaces=[
{
"DeviceIndex": 0,
"AssociatePublicIpAddress": True,
"Groups": ["sg-0a883ca614ebfbae0", "sg-014d326be5a1627dc"],
"Groups": SECURITY_GROUPS,
}
],
UserData=f"""#cloud-config
hostname: db-aaaaaaaaaaaaaaaaaaaa
write_files:
- {{path: /etc/postgresql.schema.sql, content: {gzip_then_base64_encode(postgresql_schema_sql_content)}, permissions: '0600', encoding: gz+b64}}
- {{path: /etc/realtime.env, content: {gzip_then_base64_encode(realtime_env_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/adminapi/adminapi.yaml, content: {gzip_then_base64_encode(adminapi_yaml_content)}, permissions: '0600', owner: 'adminapi:root', encoding: gz+b64}}
- {{path: /etc/postgresql-custom/pgsodium_root.key, content: {gzip_then_base64_encode(pgsodium_root_key_content)}, permissions: '0600', owner: 'postgres:postgres', encoding: gz+b64}}
- {{path: /etc/postgrest/base.conf, content: {gzip_then_base64_encode(postgrest_base_conf_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/gotrue.env, content: {gzip_then_base64_encode(gotrue_env_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/wal-g/config.json, content: {gzip_then_base64_encode(walg_config_json_content)}, permissions: '0664', owner: 'wal-g:wal-g', encoding: gz+b64}}
- {{path: /tmp/init.json, content: {gzip_then_base64_encode(init_json_content)}, permissions: '0600', encoding: gz+b64}}
runcmd:
- 'sudo echo \"pgbouncer\" \"postgres\" >> /etc/pgbouncer/userlist.txt'
- 'cd /tmp && aws s3 cp --region ap-southeast-1 s3://init-scripts-staging/project/init.sh .'
- 'bash init.sh "staging"'
- 'rm -rf /tmp/*'
""",
UserData=user_data,
TagSpecifications=[
{
"ResourceType": "instance",
Expand All @@ -243,111 +265,156 @@ def gzip_then_base64_encode(s: str) -> str:
}
],
)
)[0]
instance.wait_until_running()

ec2logger = EC2InstanceConnectLogger(debug=False)
temp_key = EC2InstanceConnectKey(ec2logger.get_logger())
ec2ic = boto3.client("ec2-instance-connect", region_name="ap-southeast-1")
response = ec2ic.send_ssh_public_key(
InstanceId=instance.id,
InstanceOSUser="ubuntu",
SSHPublicKey=temp_key.get_pub_key(),
)
assert response["Success"]

# instance doesn't have public ip yet
return instances[0]
except ClientError as e:
logger.error(f"Failed to create EC2 instance: {str(e)}")
raise

@retry_with_backoff()
def wait_for_instance_running(instance) -> None:
"""Wait for instance to be in running state with retries."""
try:
instance.wait_until_running()
logger.info("Instance is running")
except Exception as e:
logger.error(f"Failed to wait for instance running state: {str(e)}")
raise

@retry_with_backoff()
def wait_for_public_ip(instance) -> str:
"""Wait for instance to have a public IP with retries."""
while not instance.public_ip_address:
logger.warning("waiting for ip to be available")
logger.warning("Waiting for public IP to be available")
sleep(5)
instance.reload()
return instance.public_ip_address

@retry_with_backoff()
def wait_for_ssh(ip_address: str) -> None:
"""Wait for SSH to be available with retries."""
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if sock.connect_ex((instance.public_ip_address, 22)) == 0:
break
else:
logger.warning("waiting for ssh to be available")
sleep(10)

def get_ssh_connection(instance_ip, ssh_identity_file, max_retries=10):
for attempt in range(max_retries):
try:
return testinfra.get_host(
f"paramiko://ubuntu@{instance_ip}?timeout=60",
ssh_identity_file=ssh_identity_file,
)
except Exception as e:
if attempt == max_retries - 1:
raise
logger.warning(
f"Ssh connection failed, retrying: {attempt + 1}/{max_retries} failed, retrying ..."
)
sleep(5)

host = get_ssh_connection(
# paramiko is an ssh backend
instance.public_ip_address,
temp_key.get_priv_key_file(),
try:
if sock.connect_ex((ip_address, SSH_PORT)) == 0:
logger.info("SSH is available")
return
finally:
sock.close()
logger.warning("Waiting for SSH to be available")
sleep(10)

@retry_with_backoff()
def get_ssh_connection(instance_ip: str) -> Any:
"""Get SSH connection with retries."""
return testinfra.get_host(
f"paramiko://ubuntu@{instance_ip}?timeout={SSH_TIMEOUT}"
)
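
The object returned here is a standard testinfra host: its run() result exposes rc, stdout, stderr, and the failed/succeeded flags that the health checks below rely on. A short illustration, with a placeholder IP address:

conn = get_ssh_connection("203.0.113.10")  # placeholder address, for illustration only
result = conn.run("uname -a")
if result.failed:
    logger.warning(f"command failed with rc={result.rc}: {result.stderr}")
else:
    logger.info(result.stdout)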

def is_healthy(host, instance_ip, ssh_identity_file) -> bool:
health_checks = [
(
"postgres",
lambda h: h.run("sudo -u postgres /usr/bin/pg_isready -U postgres"),
),
(
"adminapi",
lambda h: h.run(
f"curl -sf -k --connect-timeout 30 --max-time 60 https://localhost:8085/health -H 'apikey: {supabase_admin_key}'"
),
),
(
"postgrest",
lambda h: h.run(
"curl -sf --connect-timeout 30 --max-time 60 http://localhost:3001/ready"
),
),
(
"gotrue",
lambda h: h.run(
"curl -sf --connect-timeout 30 --max-time 60 http://localhost:8081/health"
),
),
("kong", lambda h: h.run("sudo kong health")),
("fail2ban", lambda h: h.run("sudo fail2ban-client status")),
]

for service, check in health_checks:
try:
cmd = check(host)
if cmd.failed is True:
logger.warning(f"{service} not ready")
return False
except Exception:
logger.warning(
f"Connection failed during {service} check, attempting reconnect..."
)
host = get_ssh_connection(instance_ip, ssh_identity_file)
return False

def check_service_health(host: Any, service: str, check: Callable) -> bool:
"""Check health of a specific service."""
try:
cmd = check(host)
if cmd.failed:
logger.warning(f"{service} not ready")
return False
return True
except Exception as e:
logger.warning(f"Connection failed during {service} check: {str(e)}")
return False

def is_healthy(host: Any, instance_ip: str) -> bool:
"""Check if all services are healthy."""
health_checks = [
("postgres", lambda h: h.run("sudo -u postgres /usr/bin/pg_isready -U postgres")),
("adminapi", lambda h: h.run(
f"curl -sf -k --connect-timeout 30 --max-time 60 https://localhost:8085/health -H 'apikey: {supabase_admin_key}'"
)),
("postgrest", lambda h: h.run(
"curl -sf --connect-timeout 30 --max-time 60 http://localhost:3001/ready"
)),
("gotrue", lambda h: h.run(
"curl -sf --connect-timeout 30 --max-time 60 http://localhost:8081/health"
)),
("kong", lambda h: h.run("sudo kong health")),
("fail2ban", lambda h: h.run("sudo fail2ban-client status")),
]

for service, check in health_checks:
if not check_service_health(host, service, check):
return False
return True

def wait_for_healthy(host: Any, instance_ip: str) -> None:
"""Wait for all services to be healthy with timeout."""
start_time = time.time()
while time.time() - start_time < HEALTH_CHECK_TIMEOUT:
if is_healthy(host, instance_ip):
logger.info("All services are healthy")
return
sleep(HEALTH_CHECK_INTERVAL)
raise TimeoutError("Services did not become healthy within timeout period")
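
Because check_service_health accepts any callable that returns a command result with a failed attribute, covering another service means adding one more (name, check) pair to the health_checks list, or calling the helper directly. A sketch of a one-off check; the service name and endpoint are hypothetical and the host variable is assumed to be an existing testinfra connection:

ok = check_service_health(
    host,
    "example-service",
    lambda h: h.run("curl -sf --connect-timeout 30 --max-time 60 http://localhost:9000/health"),
)
if not ok:
    logger.warning("example-service is not ready yet")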

while True:
if is_healthy(
host=host,
instance_ip=instance.public_ip_address,
ssh_identity_file=temp_key.get_priv_key_file(),
):
break
sleep(1)

# return a testinfra connection to the instance
yield host

# at the end of the test suite, destroy the instance
instance.terminate()
@pytest.fixture(scope="session")
def host():
"""Create and manage an EC2 instance for testing."""
instance = None
try:
# Initialize AWS clients using environment variables
ec2_resource = boto3.resource("ec2", region_name=AWS_REGION)
ec2_client = boto3.client("ec2", region_name=AWS_REGION)

# Validate AWS resources (now only checks security groups)
validate_aws_resources(ec2_client)

# Get AMI
images = list(ec2_resource.images.filter(
Filters=[{"Name": "name", "Values": [AMI_NAME]}],
))
if len(images) != 1:
raise ValueError(f"Expected exactly one AMI, found {len(images)}")
image = images[0]

def gzip_then_base64_encode(s: str) -> str:
return base64.b64encode(gzip.compress(s.encode())).decode()

# Modified user data to remove AWS-specific commands
user_data = f"""#cloud-config
hostname: db-aaaaaaaaaaaaaaaaaaaa
write_files:
- {{path: /etc/postgresql.schema.sql, content: {gzip_then_base64_encode(postgresql_schema_sql_content)}, permissions: '0600', encoding: gz+b64}}
- {{path: /etc/realtime.env, content: {gzip_then_base64_encode(realtime_env_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/adminapi/adminapi.yaml, content: {gzip_then_base64_encode(adminapi_yaml_content)}, permissions: '0600', owner: 'adminapi:root', encoding: gz+b64}}
- {{path: /etc/postgresql-custom/pgsodium_root.key, content: {gzip_then_base64_encode(pgsodium_root_key_content)}, permissions: '0600', owner: 'postgres:postgres', encoding: gz+b64}}
- {{path: /etc/postgrest/base.conf, content: {gzip_then_base64_encode(postgrest_base_conf_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/gotrue.env, content: {gzip_then_base64_encode(gotrue_env_content)}, permissions: '0664', encoding: gz+b64}}
- {{path: /etc/wal-g/config.json, content: {gzip_then_base64_encode(walg_config_json_content)}, permissions: '0664', owner: 'wal-g:wal-g', encoding: gz+b64}}
- {{path: /tmp/init.json, content: {gzip_then_base64_encode(init_json_content)}, permissions: '0600', encoding: gz+b64}}
runcmd:
- 'sudo echo \"pgbouncer\" \"postgres\" >> /etc/pgbouncer/userlist.txt'
- 'bash /tmp/init.sh "staging"'
- 'rm -rf /tmp/*'
"""
instance = create_ec2_instance(ec2_resource, image.id, user_data)
logger.info(f"Created instance {instance.id}")

# Wait for instance to be running
wait_for_instance_running(instance)
instance_ip = wait_for_public_ip(instance)
wait_for_ssh(instance_ip)

# Get SSH connection
host = get_ssh_connection(instance_ip)
wait_for_healthy(host, instance_ip)

yield host

finally:
if instance:
try:
instance.terminate()
logger.info(f"Terminated instance {instance.id}")
except Exception as e:
logger.error(f"Failed to terminate instance: {str(e)}")


def test_postgrest_is_running(host):