maubot-shame/shameotron.py

194 lines
5.7 KiB
Python
Raw Normal View History

2020-07-14 17:34:39 +01:00
"""
[Maubot](https://mau.dev/maubot/maubot) plugin to shame room members into
upgrading their Matrix homeservers to the latest version.
"""
import json
2020-07-20 18:37:52 +01:00
from datetime import datetime
import socket
import ssl
2020-07-14 17:34:39 +01:00
from typing import Dict, List, Type
import requests
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
2021-06-01 17:08:43 +01:00
from mautrix.types import (
TextMessageEventContent,
MessageType,
Format,
EventID,
RoomID,
UserID,
)
2020-07-14 17:34:39 +01:00
from mautrix.util import markdown
from maubot import Plugin, MessageEvent
from maubot.handlers import command
class Config(BaseProxyConfig):
    """
    Plugin configuration, proxied through maubot's config machinery.
    """

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        """
        Carry the plugin's known option keys over from the base config.
        """
        for option in ('federation_tester', 'dead_servers'):
            helper.copy(option)
class ShameOTron(Plugin):
    """
    Main class for the Shame-o-Tron.

    Handles ``!shame`` by looking up either a single server given as an
    argument or every homeserver with members in the room.  Each server is
    looked up through the configured Federation Tester URL.  The reply lists
    its version plus a warning about its TLS certificate when relevant.
    """

    async def start(self) -> None:
        """
        Plugin startup hook: load the external config.
        """
        self.on_external_config_update()

    @classmethod
    def get_config_class(cls) -> Type[Config]:
        """
        Return the config class maubot should use for this plugin.
        """
        return Config

    async def _edit(self, room_id: RoomID, event_id: EventID, text: str) -> None:
        """
        Replace the content of an existing message with ``text``.

        ``text`` is sent as the plain-text body and, rendered from Markdown,
        as the HTML body of a notice that edits ``event_id``.
        """
        content = TextMessageEventContent(
            msgtype=MessageType.NOTICE,
            body=text,
            format=Format.HTML,
            formatted_body=markdown.render(text),
        )
        content.set_edit(event_id)
        await self.client.send_message(room_id, content)

    async def _load_members(self, room_id: RoomID) -> Dict[str, List[UserID]]:
        """
        Group the joined members of ``room_id`` by homeserver.

        Returns: (dict) server name -> list of member user IDs on that server
        """
        users = await self.client.get_joined_members(room_id)
        servers: Dict[str, List[UserID]] = {}
        for user in users:
            _, server = self.client.parse_user_id(user)
            servers.setdefault(server, []).append(user)
        return servers

    @staticmethod
    def _read_ssl_expiry(addr: str, host: str) -> datetime:
        """
        Blocking helper: open a TLS connection to ``addr`` and return the
        certificate's ``notAfter`` time as a naive UTC datetime.

        addr: (str) "host:port" as reported by the Federation Tester; IPv6
              literals may be bracketed, e.g. "[2001:db8::1]:8448"
        host: (str) Server name used for SNI and certificate verification
        Raises: OSError (including ssl.SSLError, timeouts and refused
                connections), ValueError for a malformed ``addr``, KeyError
                if the certificate carries no expiry
        """
        from datetime import timezone

        hostname, sep, port = addr.rpartition(':')
        if not sep or not hostname:
            raise ValueError(f'malformed address: {addr!r}')
        # rpartition keeps IPv6 colons in the host part; drop the brackets.
        hostname = hostname.strip('[]')
        # The default context already requires a valid certificate whose
        # name matches ``host`` (check_hostname=True, CERT_REQUIRED).
        context = ssl.create_default_context()
        # The context managers close the socket even if the handshake fails.
        with socket.create_connection((hostname, int(port)), timeout=10.0) as sock:
            with context.wrap_socket(sock, server_hostname=host) as conn:
                cert = conn.getpeercert()
        # cert_time_to_seconds parses the "notAfter" format independently of
        # the process locale (strptime's %b is locale-dependent).
        expiry = ssl.cert_time_to_seconds(cert['notAfter'])
        return datetime.fromtimestamp(expiry, timezone.utc).replace(tzinfo=None)

    async def get_ssl_expiry(self, addr: str, host: str) -> datetime:
        """
        Return the TLS certificate expiry of a specific instance as a naive
        UTC datetime.

        The blocking socket work runs in the default executor so that a slow
        server does not stall the bot's event loop.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._read_ssl_expiry, addr, host)

    async def query_homeserver(self, host: str) -> dict:
        """
        Query the Federation Tester for a server's version and TLS expiry.

        host: (str) Server to get version for
        Returns: (dict) {'version': str, 'ssl_expiry': datetime or None}.
                 'version' is '[TIMEOUT]', '[OFFLINE]' or '[ERROR]' when the
                 real version cannot be determined; 'ssl_expiry' is None when
                 the certificate could not be checked.
        """
        import asyncio

        url = self.config["federation_tester"].format(server=host)
        loop = asyncio.get_running_loop()
        try:
            # requests is blocking, so run it off the event loop.  Its timeout
            # is in seconds.
            req = await loop.run_in_executor(
                None, lambda: requests.get(url, timeout=10)
            )
            data = json.loads(req.text)
        except requests.exceptions.Timeout:
            # Must be caught before RequestException, its base class.
            return {'version': '[TIMEOUT]', 'ssl_expiry': None}
        except (requests.exceptions.RequestException, ValueError) as error:
            # ValueError covers a body that is not valid JSON.
            self.log.warning('Federation Tester query failed for %s: %s', host, error)
            return {'version': '[ERROR]', 'ssl_expiry': None}

        version = None if data.get('FederationOK') else '[OFFLINE]'

        try:
            addr = list(data.get('ConnectionReports') or {})[0]
            ssl_expiry = await self.get_ssl_expiry(addr, host)
        except (OSError, ValueError, KeyError, IndexError) as error:
            # OSError includes ssl.SSLError, socket timeouts, DNS failures and
            # refused connections; one bad server must not abort the command.
            self.log.warning('SSL error for: %s: %s', host, error)
            ssl_expiry = None

        if version is None:
            try:
                version = data['Version']['version']
            except (TypeError, KeyError) as error:
                self.log.error('Unexpected Federation Tester response for %s: %s', host, error)
                version = '[ERROR]'
        return {'version': version, 'ssl_expiry': ssl_expiry}

    @command.new('shame', help='Show versions of all homeservers in the room')
    @command.argument("candidate", pass_raw=True, required=False)
    async def shame_handler(self, evt: MessageEvent, candidate: str = None) -> None:
        """
        Handle the `!shame` command.

        With an argument, report only that server.  Otherwise, report every
        homeserver with members in the room, except those listed in the
        ``dead_servers`` config option.
        """
        from datetime import timezone

        event_id = await evt.reply('Loading member list...')
        candidate = (candidate or '').strip()
        if candidate:
            # An explicitly requested server is always queried.
            hosts = [candidate]
        else:
            members = await self._load_members(evt.room_id)
            # Filter out the "dead servers".
            dead_servers = set(self.config['dead_servers'] or ())
            hosts = [server for server in members if server not in dead_servers]

        await self._edit(
            evt.room_id,
            event_id,
            'Member list loaded, fetching versions... please wait...',
        )

        # Certificate expiry times are naive UTC, so compare against UTC now.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        versions = []
        for host in sorted(hosts):
            data = await self.query_homeserver(host)
            if data['ssl_expiry'] is None:
                warning = '(SSL error)'
            else:
                expiry_days = (data['ssl_expiry'] - now).days
                warning = f'(cert expiry in {expiry_days} days!)' if expiry_days < 30 else ''
            versions.append((host, f"{data['version']} {warning}".rstrip()))

        tester = self.config["federation_tester"]
        report = '\n'.join(
            f'* {host}: [{version}]({tester.format(server=host)})'
            for host, version in versions
        )
        await self._edit(evt.room_id, event_id, '#### Homeserver versions\n' + report)