# utils.py
import os.path

# DO NOT import non-standard modules at the top level of this module. It is
# imported by migrate.py, which runs on fresh machines before anything besides
# Python is installed. Functions that need a third-party package import it
# locally instead.

# THE ENVIRONMENT FILE AT /etc/mailinabox.conf

def load_environment():
    # Read the KEY=VALUE settings stored in /etc/mailinabox.conf.
    conf_path = "/etc/mailinabox.conf"
    return load_env_vars_from_file(conf_path)

def load_env_vars_from_file(fn):
    """Load settings from a KEY=VALUE file.

    Returns a collections.OrderedDict that preserves the order in which keys
    first appear. If a key occurs more than once, the first value wins. A
    line without "=" is stored as a key mapped to None. Blank lines are
    skipped.
    """
    import collections
    env = collections.OrderedDict()
    # Use a context manager so the file handle is closed promptly.
    with open(fn) as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline) would otherwise add a
                # bogus "" key, which save_environment writes back as "=None".
                continue
            env.setdefault(*line.split("=", 1))
    return env

def save_environment(env):
    # Overwrite /etc/mailinabox.conf with one KEY=VALUE line per setting,
    # in the dict's iteration order.
    with open("/etc/mailinabox.conf", "w") as conf:
        conf.write("".join("%s=%s\n" % item for item in env.items()))

# THE SETTINGS FILE AT STORAGE_ROOT/settings.yaml.

def write_settings(config, env):
    """Write the settings dict `config` to STORAGE_ROOT/settings.yaml as YAML.

    The config is serialized BEFORE the file is opened. Opening with "w"
    truncates the existing file immediately, so serializing afterwards meant
    a dump error left an empty settings.yaml behind. load_settings would
    presumably read that back as {}, silently dropping every setting.
    """
    # rtyaml is a third-party package, so it is imported here rather than at
    # module level: this module must stay importable with only Python installed.
    import rtyaml
    fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
    data = rtyaml.dump(config)
    with open(fn, "w") as f:
        f.write(data)

def load_settings(env):
    """Load STORAGE_ROOT/settings.yaml and return its contents as a dict.

    Returns an empty dict if the file is missing, unreadable, cannot be
    parsed, or does not hold a mapping at its top level.
    """
    # rtyaml is third-party; import lazily so this module stays importable
    # with only Python installed.
    import rtyaml
    fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
    try:
        # `with` closes the handle; the old code leaked it.
        with open(fn, "r") as f:
            config = rtyaml.load(f)
    except Exception:
        # Best-effort: any I/O or parse error means "no settings". Unlike a
        # bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
        return { }
    if not isinstance(config, dict):
        return { }
    return config

# UTILITIES

def safe_domain_name(name):
    # Percent-encode the name with no "safe" characters, so that "/" (and
    # anything else outside the unreserved set) is escaped and the result
    # can be used as a single file name on disk.
    from urllib.parse import quote
    return quote(name, safe='')

def sort_domains(domain_names, env):
    # Put domain names in a nice sorted order.
    #
    # domain_names: any iterable of domain name strings (it is materialized
    #   once, so generators work too).
    # env: the environment dict; only env['PRIMARY_HOSTNAME'] is read.
    # Returns a new list.
    domain_names = list(domain_names)  # iterated twice below
    primary = env['PRIMARY_HOSTNAME']

    # The nice order will group domain names by DNS zone, i.e. the top-most
    # domain name that we serve that encompasses a set of subdomains. Map
    # each of the domain names to the zone that contains them. Walk the
    # domains from shortest to longest since zones are always shorter than
    # their subdomains, so a domain's zone has already been seen when we
    # reach it.
    zones = { }
    known_zones = set()
    for domain in sorted(domain_names, key=len):
        # Test each proper parent suffix ("a.b.c" -> "b.c", "c") against the
        # zones found so far. Zones never nest, so at most one can match.
        # This replaces a scan over every previously seen domain.
        zone = domain
        for i, ch in enumerate(domain):
            if ch == "." and domain[i + 1:] in known_zones:
                # We found a parent domain already in the list.
                zone = domain[i + 1:]
                break
        else:
            # No parent domain: it is its own zone.
            known_zones.add(domain)
        zones[domain] = zone

    # Sort the distinct zones and remember each one's rank, instead of
    # calling list.index() on a list full of duplicates for every domain.
    ordered_zones = sorted(set(zones.values()),
      key = lambda z : (
        # PRIMARY_HOSTNAME or the zone that contains it is always first.
        not (z == primary or primary.endswith("." + z)),

        # Then just dumb lexicographically.
        z,
      ))
    zone_rank = { z: i for i, z in enumerate(ordered_zones) }

    # Now sort the domain names that fall within each zone.
    return sorted(domain_names,
      key = lambda d : (
        # First by zone.
        zone_rank[zones[d]],

        # PRIMARY_HOSTNAME is always first within the zone that contains it.
        d != primary,

        # Followed by any of its subdomains.
        not d.endswith("." + primary),

        # Then in right-to-left lexicographic order of the .-separated parts of the name.
        list(reversed(d.split("."))),
      ))

def sort_email_addresses(email_addresses, env):
    # Group email addresses by domain, with the domains in sort_domains()
    # order and the addresses sorted alphabetically within each domain.
    # Whatever remains afterwards (e.g. strings without an "@") is appended
    # at the end in sorted order.
    pending = set(email_addresses)
    domains = {addr.split("@", 1)[1] for addr in pending if "@" in addr}
    result = []
    for dom in sort_domains(domains, env):
        suffix = "@" + dom
        in_domain = {addr for addr in pending if addr.endswith(suffix)}
        result += sorted(in_domain)
        pending -= in_domain
    result += sorted(pending)
    return result

def shell(method, cmd_args, env=None, capture_stderr=False, return_bytes=False, trap=False, input=None):
    """Run a command through the named subprocess function (no shell involved).

    method: name of a subprocess function, e.g. "check_call" or "check_output".
    cmd_args: the argument list handed to that function.
    env: extra environment variables for the child. The caller's dict is
        left untouched. PATH is always forced to a sane system value, because
        some processes like apt-get require being given one.
    capture_stderr: if true, stderr is merged into stdout.
    return_bytes: if false, a bytes result is decoded as UTF-8.
    trap: if true, a CalledProcessError is caught and (returncode, output) is
        returned instead of raising; on success (0, result) is returned.
    input: sent to the child's stdin; only honored for "check_output".
    """
    import subprocess

    # Build a fresh dict. The old `env={}` default combined with
    # env.update() mutated both the shared default and the caller's dict
    # (even os.environ, if that was passed in).
    child_env = dict(env) if env is not None else {}
    child_env["PATH"] = "/sbin:/bin:/usr/sbin:/usr/bin"

    kwargs = {
        'env': child_env,
        'stderr': subprocess.STDOUT if capture_stderr else None,
    }
    if method == "check_output" and input is not None:
        kwargs['input'] = input

    run = getattr(subprocess, method)
    code = 0
    try:
        ret = run(cmd_args, **kwargs)
    except subprocess.CalledProcessError as e:
        if not trap:
            raise
        ret = e.output
        code = e.returncode

    if not return_bytes and isinstance(ret, bytes):
        ret = ret.decode("utf8")
    return (code, ret) if trap else ret

def create_syslog_handler():
    # Build a logging handler that sends records of level WARNING and above
    # to syslog through the /dev/log socket path.
    from logging import WARNING
    from logging.handlers import SysLogHandler
    syslog = SysLogHandler(address='/dev/log')
    syslog.setLevel(WARNING)
    return syslog

def du(path):
    """Return the total apparent size (st_size, in bytes) of the files under path.

    Similar in spirit to the `du` command; based on
    http://stackoverflow.com/a/17936789. Symbolic links are not followed:
    os.walk does not descend into them by default, and each entry is
    lstat()ed, so a link counts as the size of the link itself. Hard links
    to the same file are counted once. Entries that cannot be stat()ed
    (e.g. removed during the walk) are skipped. Sizes of directory entries
    themselves are not included.
    """
    total_size = 0
    seen = set()
    for dirpath, _dirnames, filenames in os.walk(path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            try:
                st = os.lstat(fp)
            except OSError:
                continue
            # Identify a file by (device, inode). Inode numbers are only
            # unique within one filesystem, so st_ino alone could wrongly
            # treat files on different mounts under `path` as hard links.
            file_id = (st.st_dev, st.st_ino)
            if file_id in seen:
                continue
            seen.add(file_id)
            total_size += st.st_size
    return total_size

def wait_for_service(port, public, env, timeout):
    """Block until a service on a given port is taking TCP connections.

    port: the TCP port to probe.
    public: if true, connect to env['PUBLIC_IP']; otherwise to 127.0.0.1.
    env: the environment dict (only read when public is true).
    timeout: roughly how many seconds to keep retrying.

    Returns True as soon as a connection succeeds, or False once more than
    `timeout` seconds have passed since the first attempt.
    """
    import socket, time
    host = env['PUBLIC_IP'] if public else "127.0.0.1"
    start = time.perf_counter()
    while True:
        # One fresh socket per attempt, always closed by the `with`.
        # Previously every attempt leaked a socket, including the successful one.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(timeout/3)
            try:
                s.connect((host, port))
                return True
            except OSError:
                if time.perf_counter() > start + timeout:
                    return False
        time.sleep(min(timeout/4, 1))

def fix_boto():
    # Google Compute Engine images ship Python-2-only boto plugins that
    # conflict with boto under Python 3. Pointing BOTO_CONFIG at another
    # file before boto is imported stops boto from loading its default
    # configuration file, so GCE's plugin is never loaded.
    from os import environ
    environ["BOTO_CONFIG"] = "/etc/boto3.cfg"


if __name__ == "__main__":
    # When run directly, print each domain that web_update.get_web_domains()
    # returns for the configured environment, one per line.
    from web_update import get_web_domains
    env = load_environment()
    for domain in get_web_domains(env):
        print(domain)