Code tidy up

This commit is contained in:
Scott Wallace 2025-03-05 17:29:06 +00:00
parent ec2f5ce902
commit d207312a2e
Signed by: scott
SSH key fingerprint: SHA256:+LJug6Dj01Jdg86CILGng9r0lJseUrpI0xfRqdW9Uws

View file

@ -1,12 +1,13 @@
#!python3 """Display TLS certificate details for an HTTP endpoint."""
"""
Return the Akamai property and version for a given site from __future__ import annotations
"""
import argparse import argparse
import socket import socket
import ssl import ssl
import sys import sys
import urllib.error import urllib.error
from contextlib import suppress
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
@ -20,8 +21,7 @@ SAN_GROUPING = 4
def get_cert_with_servername(addr: tuple[str, int], servername: str = "") -> bytes: def get_cert_with_servername(addr: tuple[str, int], servername: str = "") -> bytes:
""" """Get TLS certificate from an address with an explicit servername override.
Get TLS certificate from an address with an explicit servername override
Args: Args:
addr (tuple[str, int]): adress in tuple form (address, port) addr (tuple[str, int]): adress in tuple form (address, port)
@ -29,51 +29,49 @@ def get_cert_with_servername(addr: tuple[str, int], servername: str = "") -> byt
Returns: Returns:
bytes: PEM bytes bytes: PEM bytes
""" """
context = ssl.create_default_context() context = ssl.create_default_context()
context.check_hostname = False context.check_hostname = False
with socket.create_connection((addr[0], addr[1]), timeout=10) as sock: with socket.create_connection((addr[0], addr[1]), timeout=10) as sock, context.wrap_socket(
with context.wrap_socket(sock, server_hostname=servername) as sslsock: sock,
if der_cert := sslsock.getpeercert(True): server_hostname=servername,
return ssl.DER_cert_to_PEM_cert(der_cert).encode("utf=8") ) as sslsock:
if der_cert := sslsock.getpeercert(binary_form=True):
return ssl.DER_cert_to_PEM_cert(der_cert).encode("utf=8")
return bytes() return b""
def format_fingerprint(fingerprint: bytes | str) -> str: def format_fingerprint(fingerprint: bytes | str) -> str:
""" """Print a fingerprint as a colon-separated hex string.
Print a fingerprint as a colon-separated hex string
Args: Args:
fingerprint (bytes | str): fingerprint to format fingerprint (bytes | str): fingerprint to format
Returns: Returns:
str: formatted fingerprint str: formatted fingerprint
""" """
if isinstance(fingerprint, str): if isinstance(fingerprint, str):
fingerprint = bytearray.fromhex(fingerprint) fingerprint = bytearray.fromhex(fingerprint)
return ":".join([format(i, "02x") for i in fingerprint]) return ":".join([format(i, "02x") for i in fingerprint])
def display_error( def display_error(site: str, error: Any = None) -> None:
site: str, """Print a generic error."""
error: Any = None,
) -> None:
"""
Print a generic error
"""
print(f"ERROR: Could not find a certificate for {site}") print(f"ERROR: Could not find a certificate for {site}")
if error: if error:
print(str(error)) print(str(error))
def parseargs() -> argparse.Namespace: def parseargs() -> argparse.Namespace:
""" """Parse the CLI.
Parse the CLI
Returns: Returns:
argparse.Namespace: parsed arguments argparse.Namespace: parsed arguments
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -84,11 +82,11 @@ def parseargs() -> argparse.Namespace:
def main() -> int: def main() -> int:
""" """Run the code.
Main entrypoint
Returns: Returns:
int: return value int: return value
""" """
args = parseargs() args = parseargs()
@ -123,10 +121,8 @@ def main() -> int:
).encode("utf-8") ).encode("utf-8")
cert_chain = CertificateChain() cert_chain = CertificateChain()
try: with suppress(urllib.error.URLError):
cert_chain = resolve(pem_data) cert_chain = resolve(pem_data)
except urllib.error.URLError:
pass
except ( except (
ConnectionRefusedError, ConnectionRefusedError,
ConnectionResetError, ConnectionResetError,
@ -146,23 +142,20 @@ def main() -> int:
sans = [ sans = [
f"DNS:{dns}" f"DNS:{dns}"
for dns in cert.extensions.get_extension_for_class( for dns in cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value.get_values_for_type(
x509.SubjectAlternativeName x509.DNSName,
).value.get_values_for_type(x509.DNSName) )
] ]
sans.extend( sans.extend(
[ [
f"IP:{ip}" f"IP:{ip}"
for ip in cert.extensions.get_extension_for_class( for ip in cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value.get_values_for_type(
x509.SubjectAlternativeName x509.IPAddress,
).value.get_values_for_type(x509.IPAddress) )
] ],
) )
sangroups = [ sangroups = [sans[group : group + SAN_GROUPING] for group in range(0, len(sans), SAN_GROUPING)]
sans[group : group + SAN_GROUPING]
for group in range(0, len(sans), SAN_GROUPING)
]
table = [ table = [
["Common name", cert.subject.rfc4514_string()], ["Common name", cert.subject.rfc4514_string()],
@ -192,7 +185,7 @@ def main() -> int:
if cert if cert
], ],
), ),
] ],
) )
print(tabulate(table, tablefmt="plain")) print(tabulate(table, tablefmt="plain"))