"""Display TLS certificate details for an HTTP endpoint.""" from __future__ import annotations import argparse import socket import ssl import sys import urllib.error from contextlib import suppress from typing import Any from urllib.parse import urlparse from cert_chain_resolver.api import CertificateChain, resolve from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from tabulate import tabulate SAN_GROUPING = 4 def get_cert_with_servername(addr: tuple[str, int], servername: str = "") -> bytes: """Get TLS certificate from an address with an explicit servername override. Args: addr (tuple[str, int]): adress in tuple form (address, port) servername (str): SNI servername Returns: bytes: PEM bytes """ context = ssl.create_default_context() context.check_hostname = False with socket.create_connection((addr[0], addr[1]), timeout=10) as sock, context.wrap_socket( sock, server_hostname=servername, ) as sslsock: if der_cert := sslsock.getpeercert(binary_form=True): return ssl.DER_cert_to_PEM_cert(der_cert).encode("utf=8") return b"" def format_fingerprint(fingerprint: bytes | str) -> str: """Print a fingerprint as a colon-separated hex string. Args: fingerprint (bytes | str): fingerprint to format Returns: str: formatted fingerprint """ if isinstance(fingerprint, str): fingerprint = bytearray.fromhex(fingerprint) return ":".join([format(i, "02x") for i in fingerprint]) def display_error(site: str, error: Any = None) -> None: """Print a generic error.""" print(f"ERROR: Could not find a certificate for {site}") if error: print(str(error)) def parseargs() -> argparse.Namespace: """Parse the CLI. Returns: argparse.Namespace: parsed arguments """ parser = argparse.ArgumentParser() parser.add_argument("site", help="site to lookup") parser.add_argument("-a", "--address", help="explicit address to connect to") return parser.parse_args() def main() -> int: """Run the code. Returns: int: return value """ args = parseargs() url = args.site if "://" not in url: url = f"https://{url}" parts = urlparse(args.site, scheme="https") if not parts.netloc: parts = parts._replace(netloc=args.site) if not parts.port: parts = parts._replace(netloc=f"{parts.netloc}:443") if not parts.hostname or not parts.port: display_error(args.site, "Cannot parse hostname") return 1 endpoint = f"{parts.hostname}:{parts.port}" try: if args.address: pem_data = get_cert_with_servername( (args.address, parts.port), servername=parts.hostname, ) else: pem_data = ssl.get_server_certificate( (parts.hostname, parts.port), timeout=10, ).encode("utf-8") cert_chain = CertificateChain() with suppress(urllib.error.URLError): cert_chain = resolve(pem_data) except ( ConnectionRefusedError, ConnectionResetError, socket.gaierror, ssl.CertificateError, ssl.SSLError, TimeoutError, ) as error: display_error(endpoint, error) return 2 if not pem_data: display_error(endpoint, "Cannot fetch PEM data") return 3 cert = x509.load_pem_x509_certificate(pem_data, default_backend()) sans = [ f"DNS:{dns}" for dns in cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value.get_values_for_type( x509.DNSName, ) ] sans.extend( [ f"IP:{ip}" for ip in cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value.get_values_for_type( x509.IPAddress, ) ], ) sangroups = [sans[group : group + SAN_GROUPING] for group in range(0, len(sans), SAN_GROUPING)] table = [ ["Common name", cert.subject.rfc4514_string()], [f"SANs ({len(sans)})", tabulate(sangroups, tablefmt="plain")], ["Valid from", cert.not_valid_before_utc], ["Valid to", cert.not_valid_after_utc], ["Issuer", cert.issuer.rfc4514_string()], [ "Fingerprint", f"{format_fingerprint(cert.fingerprint(hashes.SHA1()))} (SHA1)\n" f"{format_fingerprint(cert.fingerprint(hashes.SHA256()))} (SHA256)\n", ], ] if cert_chain: table.append( [ "CA chain", "\n".join( [ f"{cert.common_name} " f"(Issuer: {cert.issuer})\n" "Fingerprint: \n" f"\t{format_fingerprint(cert.get_fingerprint(hashes.SHA1))} (SHA1)\n" f"\t{format_fingerprint(cert.get_fingerprint(hashes.SHA256))} (SHA256)" for cert in [*list(cert_chain.intermediates), cert_chain.root] if cert ], ), ], ) print(tabulate(table, tablefmt="plain")) return 0 if __name__ == "__main__": sys.exit(main())