'''
Source: https://github.com/mdomi/hosts
Original file has been modified to meet our needs. The primary differences:
  - Removed some of the inline comments
  - Removed the compare_ip function (not compatible with IPv6 addresses)
  - IPs stored as lists (some OSes have multiple localhost definitions)

Copyright (c) 2012 Michael Dominice

Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
'''
import datetime
import os
import re
import socket
import subprocess


def get_created_comment():
    """Build the three-line comment header placed at the top of generated files.

    The last line carries the current local time as rendered by ``str()`` on
    ``datetime.datetime.now()``.
    """
    timestamp = datetime.datetime.now()
    header_lines = (
        '# Autogenerated by hosts.py',
        '# https://github.com/mdomi/hosts',
        '# Updated: %s' % timestamp,
    )
    return '\n'.join(header_lines)


class Hosts(object):
    """In-memory representation of a hosts file.

    ``self.hosts`` maps every host name to the list of IP addresses it is
    bound to (a name may appear on several lines, e.g. ``localhost`` for both
    ``127.0.0.1`` and ``::1``).
    """

    def __init__(self, path):
        """Create the mapping by parsing the hosts file at ``path``."""
        self.hosts = {}
        self.read(path)

    def get_one(self, host_name, raise_on_not_found=False):
        """Return the IP addresses for ``host_name``.

        The hosts mapping is consulted first; if the name is not there it is
        resolved through ``socket.gethostbyname``.

        Returns:
            A list of IP address strings, or the string ``'[Unknown]'`` when
            the name cannot be resolved and ``raise_on_not_found`` is false.

        Raises:
            Exception: the name cannot be resolved and ``raise_on_not_found``
                is true.
        """
        if host_name in self.hosts:
            return self.hosts[host_name]
        try:
            # Wrap in a list so the result has the same shape as a hosts entry.
            return [socket.gethostbyname(host_name)]
        except socket.gaierror:
            if raise_on_not_found:
                raise Exception('Unknown host: %s' % (host_name,))
        return '[Unknown]'

    def print_one(self, host_name):
        """Print ``host_name`` followed by what ``get_one`` returns for it."""
        print(host_name, self.get_one(host_name))

    def print_all(self, host_names=None):
        """Print the raw mapping, then one line per requested host name.

        When ``host_names`` is ``None`` every host in the mapping is printed.
        """
        print(self.hosts)
        if host_names is None:
            host_names = list(self.hosts)
        for host_name in host_names:
            self.print_one(host_name)

    def file_contents(self):
        """Render the mapping as hosts-file text.

        The output starts with the autogenerated comment header, followed by
        one ``<ip>\\t<name> <name> ...`` line per IP address. Addresses and
        the names on each line are sorted; the text ends with a newline.
        """
        # Invert name -> [ips] into ip -> [names], the layout of the file.
        reversed_hosts = {}
        for host_name, ip_addresses in self.hosts.items():
            for ip_address in ip_addresses:
                reversed_hosts.setdefault(ip_address, []).append(host_name)
        parts = [
            '%s\t%s' % (ip_address, ' '.join(sorted(reversed_hosts[ip_address])))
            for ip_address in sorted(reversed_hosts)
        ]
        return '\n'.join([get_created_comment(), '\n'.join(parts), ''])

    def read(self, path):
        """Read the hosts file at the given location and parse the contents.

        Blank lines, comment lines and trailing ``# ...`` comments are
        ignored. Entries are merged into ``self.hosts``; an IP address is
        recorded at most once per host name.
        """
        with open(path, 'r') as hosts_file:
            for line in hosts_file:
                # Everything from the first '#' on is a comment; split() with
                # no argument also copes with leading/trailing whitespace.
                fields = line.split('#', 1)[0].split()
                if len(fields) < 2:
                    continue  # blank, comment-only, or an IP without names
                ip_address = fields[0]
                for host_name in fields[1:]:
                    addresses = self.hosts.setdefault(host_name, [])
                    if ip_address not in addresses:
                        addresses.append(ip_address)

    def remove_one(self, host_name):
        """Remove the mapping for the given host_name.

        Raises:
            KeyError: ``host_name`` has no entry.
        """
        del self.hosts[host_name]

    def remove_all(self, host_names):
        """Remove the mappings for every host name in ``host_names``.

        Raises:
            KeyError: one of the names has no entry; names before it have
                already been removed.
        """
        for host_name in host_names:
            self.remove_one(host_name)

    def write(self, path):
        """Write the contents of this hosts definition to the provided path.

        If ``path`` cannot be written directly, the contents are staged in
        ``/tmp/etc_hosts.tmp`` and moved into place with ``sudo mv``.

        Raises:
            IOError: the staging file could not be written or the ``sudo mv``
                command exited with a non-zero status.
        """
        contents = self.file_contents()
        try:
            # Text mode: ``contents`` is a str, not bytes.
            with open(path, 'w') as hosts_file:
                hosts_file.write(contents)
        except IOError:
            temp_path = '/tmp/etc_hosts.tmp'
            with open(temp_path, 'w') as hosts_file:
                hosts_file.write(contents)
            # The staging file is closed (and therefore flushed) before it is
            # moved. An argument list avoids shell quoting problems in paths.
            status = subprocess.call(['sudo', 'mv', temp_path, path])
            if status != 0:
                raise IOError('Could not move %s to %s (exit status %s)'
                              % (temp_path, path, status))

    def set_one(self, host_name, ip_address):
        """Add ``ip_address`` to the addresses mapped to ``host_name``.

        Existing addresses for the host are kept; an address that is already
        present is not added a second time.

        Raises:
            TypeError: ``host_name`` is not a ``str``.
        """
        if not isinstance(host_name, str):
            raise TypeError('host_name must be a str, got %r' % (host_name,))
        addresses = self.hosts.setdefault(host_name, [])
        if ip_address not in addresses:
            addresses.append(ip_address)

    def set_all(self, host_names, ip_address):
        """Add ``ip_address`` to every host name in the list ``host_names``.

        Raises:
            TypeError: ``host_names`` is not a ``list`` or one of its items
                is not a ``str``.
        """
        if not isinstance(host_names, list):
            raise TypeError('host_names must be a list, got %r'
                            % (host_names,))
        for host_name in host_names:
            self.set_one(host_name, ip_address)

    def alias_all(self, host_names, target):
        """Map ``host_names`` to every IP address that ``target`` maps to.

        Raises:
            Exception: ``target`` is neither in the mapping nor resolvable.
            TypeError: ``host_names`` is not a list of strings.
        """
        # Copy the list: ``target`` may itself be one of ``host_names``, and
        # the aliases must not end up sharing the target's list object.
        ip_addresses = list(self.get_one(target, raise_on_not_found=True))
        for ip_address in ip_addresses:
            self.set_all(host_names, ip_address)

def main(argv=None):
    """Command-line entry point: query or edit the system hosts file.

    ``argv`` is the argument list to parse; ``None`` means ``sys.argv[1:]``.
    Exactly one action is carried out, chosen in this order of precedence:
    ``--get``, ``--alias``, ``--set``, ``--remove``. With ``--dry`` the
    resulting file is printed instead of written.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Manipulate your hosts file')

    parser.add_argument('name', nargs='+')
    parser.add_argument('--set', dest='ip_address')
    parser.add_argument('--alias')
    parser.add_argument('--get', action='store_true', default=False)
    parser.add_argument('--remove', action='store_true', default=False)
    parser.add_argument('--dry', action='store_true', default=False)

    args = parser.parse_args(argv)

    # argparse always defines every option on the namespace, so each one has
    # to be tested by value (not with hasattr) to know whether it was given.
    modifies = (args.alias is not None or args.ip_address is not None
                or args.remove)
    if not args.get and not modifies:
        parser.error('one of --get, --alias, --set or --remove is required')

    if os.name == 'nt':
        hosts_path = os.path.join(os.environ['SYSTEMROOT'],
                                  'system32/drivers/etc/hosts')
    elif os.name == 'posix':
        hosts_path = '/etc/hosts'
    else:
        raise Exception('Unsupported OS: %s' % os.name)

    hosts = Hosts(hosts_path)

    try:
        if args.get:
            hosts.print_all(args.name)
            return
        if args.alias is not None:
            hosts.alias_all(args.name, args.alias)
        elif args.ip_address is not None:
            hosts.set_all(args.name, args.ip_address)
        else:
            hosts.remove_all(args.name)
        # Every modifying action ends the same way: preview or persist.
        if args.dry:
            print(hosts.file_contents())
        else:
            hosts.write(hosts_path)
    except Exception as e:
        print('Error: %s' % (e,))
        sys.exit(1)


if __name__ == '__main__':
    main()
