Commit 7e05d747 authored by Joshua Tauberer's avatar Joshua Tauberer

run status checks asynchronously so that they finish faster, since many checks...

run status checks asynchronously so that they finish faster, since many checks are waiting on network replies and ought not to block the whole thing
parent 8fd98d7d
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
__ALL__ = ['check_certificate'] __ALL__ = ['check_certificate']
import os, os.path, re, subprocess, datetime import os, os.path, re, subprocess, datetime, multiprocessing.pool
import dns.reversename, dns.resolver import dns.reversename, dns.resolver
import dateutil.parser, dateutil.tz import dateutil.parser, dateutil.tz
...@@ -35,15 +35,17 @@ def run_checks(env, output): ...@@ -35,15 +35,17 @@ def run_checks(env, output):
run_system_checks(env, output) run_system_checks(env, output)
# perform other checks # perform other checks asynchronously
run_network_checks(env, output)
run_domain_checks(env, output) pool = multiprocessing.pool.Pool(processes=1)
r1 = pool.apply_async(run_network_checks, [env])
r2 = run_domain_checks(env)
r1.get().playback(output)
r2.playback(output)
def run_services_checks(env, output): def run_services_checks(env, output):
# Check that system services are running. # Check that system services are running.
import socket
services = [ services = [
{ "name": "Local DNS (bind9)", "port": 53, "public": False, }, { "name": "Local DNS (bind9)", "port": 53, "public": False, },
#{ "name": "NSD Control", "port": 8952, "public": False, }, #{ "name": "NSD Control", "port": 8952, "public": False, },
...@@ -66,33 +68,47 @@ def run_services_checks(env, output): ...@@ -66,33 +68,47 @@ def run_services_checks(env, output):
{ "name": "HTTPS Web (nginx)", "port": 443, "public": True, }, { "name": "HTTPS Web (nginx)", "port": 443, "public": True, },
] ]
ok = True all_running = True
fatal = False
for service in services: pool = multiprocessing.pool.Pool(processes=10)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(services)), chunksize=1)
s.settimeout(.1) for i, running, fatal2, output2 in sorted(ret):
try: all_running = all_running and running
s.connect(( fatal = fatal or fatal2
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'], output2.playback(output)
service["port"]))
except OSError as e: if all_running:
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
# Why is nginx not running?
if service["port"] in (80, 443):
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
# Flag if local DNS is not running.
if service["port"] == 53 and service["public"] == False:
ok = False
finally:
s.close()
if ok:
output.print_ok("All system services are running.") output.print_ok("All system services are running.")
return ok return not fatal
def check_service(i, service, env):
import socket
output = BufferedOutput()
running = False
fatal = False
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(1)
try:
s.connect((
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
service["port"]))
running = True
except OSError as e:
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
# Why is nginx not running?
if service["port"] in (80, 443):
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
# Flag if local DNS is not running.
if service["port"] == 53 and service["public"] == False:
fatal = True
finally:
s.close()
return (i, running, fatal, output)
def run_system_checks(env, output): def run_system_checks(env, output):
check_ssh_password(env, output) check_ssh_password(env, output)
...@@ -146,9 +162,10 @@ def check_free_disk_space(env, output): ...@@ -146,9 +162,10 @@ def check_free_disk_space(env, output):
else: else:
output.print_error(disk_msg) output.print_error(disk_msg)
def run_network_checks(env, output): def run_network_checks(env):
# Also see setup/network-checks.sh. # Also see setup/network-checks.sh.
output = BufferedOutput()
output.add_heading("Network") output.add_heading("Network")
# Stop if we cannot make an outbound connection on port 25. Many residential # Stop if we cannot make an outbound connection on port 25. Many residential
...@@ -176,7 +193,9 @@ def run_network_checks(env, output): ...@@ -176,7 +193,9 @@ def run_network_checks(env, output):
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s.""" which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s."""
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP'])) % (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
def run_domain_checks(env, output): return output
def run_domain_checks(env):
# Get the list of domains we handle mail for. # Get the list of domains we handle mail for.
mail_domains = get_mail_domains(env) mail_domains = get_mail_domains(env)
...@@ -187,24 +206,44 @@ def run_domain_checks(env, output): ...@@ -187,24 +206,44 @@ def run_domain_checks(env, output):
# Get the list of domains we serve HTTPS for. # Get the list of domains we serve HTTPS for.
web_domains = set(get_web_domains(env)) web_domains = set(get_web_domains(env))
# Check the domains. domains_to_check = mail_domains | dns_domains | web_domains
for domain in sort_domains(mail_domains | dns_domains | web_domains, env):
output.add_heading(domain) # Serial version:
#for domain in sort_domains(domains_to_check, env):
# run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
# Parallelize the checks across a worker pool.
args = ((domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
for domain in domains_to_check)
pool = multiprocessing.pool.Pool(processes=10)
ret = pool.starmap(run_domain_checks_on_domain, args, chunksize=1)
ret = dict(ret) # (domain, output) => { domain: output }
output = BufferedOutput()
for domain in sort_domains(ret, env):
ret[domain].playback(output)
return output
def run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains):
output = BufferedOutput()
if domain == env["PRIMARY_HOSTNAME"]: output.add_heading(domain)
check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles)
if domain == env["PRIMARY_HOSTNAME"]:
check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles)
if domain in dns_domains: if domain in dns_domains:
check_dns_zone(domain, env, output, dns_zonefiles) check_dns_zone(domain, env, output, dns_zonefiles)
if domain in mail_domains: if domain in mail_domains:
check_mail_domain(domain, env, output) check_mail_domain(domain, env, output)
if domain in web_domains:
check_web_domain(domain, env, output)
if domain in web_domains: if domain in dns_domains:
check_web_domain(domain, env, output) check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
if domain in dns_domains: return (domain, output)
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles): def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# If a DS record is set on the zone containing this domain, check DNSSEC now. # If a DS record is set on the zone containing this domain, check DNSSEC now.
...@@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True): ...@@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True):
return pkgs return pkgs
try:
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
except:
terminal_columns = 76
class ConsoleOutput: class ConsoleOutput:
try:
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
except:
terminal_columns = 76
def add_heading(self, heading): def add_heading(self, heading):
print() print()
print(heading) print(heading)
...@@ -680,7 +720,7 @@ class ConsoleOutput: ...@@ -680,7 +720,7 @@ class ConsoleOutput:
words = re.split("(\s+)", message) words = re.split("(\s+)", message)
linelen = 0 linelen = 0
for w in words: for w in words:
if linelen + len(w) > terminal_columns-1-len(first_line): if linelen + len(w) > self.terminal_columns-1-len(first_line):
print() print()
print(" ", end="") print(" ", end="")
linelen = 0 linelen = 0
...@@ -693,6 +733,21 @@ class ConsoleOutput: ...@@ -693,6 +733,21 @@ class ConsoleOutput:
for line in message.split("\n"): for line in message.split("\n"):
self.print_block(line) self.print_block(line)
class BufferedOutput:
# Record all of the instance method calls so we can play them back later.
def __init__(self):
self.buf = []
def __getattr__(self, attr):
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"):
raise AttributeError
# Return a function that just records the call & arguments to our buffer.
def w(*args, **kwargs):
self.buf.append((attr, args, kwargs))
return w
def playback(self, output):
for attr, args, kwargs in self.buf:
getattr(output, attr)(*args, **kwargs)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
from utils import load_environment from utils import load_environment
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment