#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# This file is part of BridgeDB, a Tor bridge distribution system.
#
# :authors: Isis Lovecruft 0xA3ADB67A2CDB8B35 <isis@torproject.org>
#           please also see AUTHORS file
# :copyright: (c) 2013 Isis Lovecruft
#             (c) 2007-2013, The Tor Project, Inc.
#             (c) 2007-2013, all entities within the AUTHORS file
# :license: 3-clause BSD, see included LICENSE for information

"""get-tor-exits -- Download the current list of Tor exit relays."""

from __future__ import print_function

import os.path
import socket
import sys

from ipaddr import IPAddress

from OpenSSL import SSL

from twisted.python import log
from twisted.python import usage
from twisted.names import client as dnsclient
from twisted.names import error as dnserror
from twisted.web import client
from twisted.internet import defer
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.internet import ssl
from twisted.internet.error import ConnectionRefusedError
from twisted.internet.error import DNSLookupError
from twisted.internet.error import TimeoutError


# Start Twisted's logging as soon as this module is loaded, with log output
# written to stderr.
# NOTE(review): startLogging() is called with its defaults here. Per the
# Twisted documentation those defaults also redirect sys.stdout/sys.stderr
# into the log (setStdout) -- confirm that this is intended for the --stdout
# option, which writes the exit list to sys.stdout.
log.startLogging(sys.stderr)


def backupFile(filename):
    """Rename an existing exit list file out of the way.

    If **filename** is an existing regular file, it is renamed to
    **filename** with ``.bak`` appended, so that a fresh download does not
    get mixed up with old contents. Otherwise nothing happens.

    :param str filename: Path of the exit list file to back up.
    """
    if not os.path.isfile(filename):
        return

    backup = filename + '.bak'
    notice = "get-tor-exits: Moving old exit list file to %s" % backup
    log.msg(notice)
    os.renames(filename, backup)

def getSelfIPAddress():
    """Get our external IP address, to ask which relays can exit to here.

    A datagram socket is connected to ``bridges.torproject.org:443`` and the
    local address of that socket is taken as our own address. If that address
    is link-local, private, or reserved, the socket's peer address is used
    instead.

    Failures are logged rather than raised.

    TODO: This is blocking. Make it use Twisted.

    :rtype: str or None
    :returns: The compressed string form of the address, or ``None`` if no
        address could be determined.
    """
    ip = s = None
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(('bridges.torproject.org', 443))
        name = s.getsockname()[0]
        ip = IPAddress(name)
        if ip.is_link_local or ip.is_private or ip.is_reserved:
            name = s.getpeername()[0]
            ip = IPAddress(name)
    except ValueError as error:
        log.err("get-tor-exits: A socket gave us something that wasn't an IP: %s"
                % error)
    except Exception as error:
        # Log the exception's class name; not every exception carries a
        # useful ``message`` attribute.
        log.err("get-tor-exits: Unhandled Exception: %s\n%s\n"
                % (type(error).__name__, error))
    finally:
        if s is not None:
            s.close()

    # If we failed before any address was parsed, there is nothing to
    # compress; report that to the caller instead of raising an unrelated
    # AttributeError on None.
    if ip is None:
        return None
    return ip.compressed

def handle(failure):
    """Log an explanation for a known download **failure**, then trap it.

    A message is logged only when the failure's type is exactly one of
    ``ConnectionRefusedError``, ``DNSLookupError``, or ``TimeoutError``.
    The failure is then passed through ``trap()`` with those same three
    types, which re-raises any failure that is not one of them.

    :param failure: The failure handed to this errback.
    """
    explanations = {
        ConnectionRefusedError:
            "get-tor-exits: Could not download exitlist; connection was refused.",
        DNSLookupError:
            "get-tor-exits: Could not download exitlist; domain resolution failed.",
        TimeoutError:
            "get-tor-exits: Could not download exitlist; connection timed out.",
    }

    explanation = explanations.get(failure.type)
    if explanation is not None:
        log.msg(explanation)

    failure.trap(ConnectionRefusedError, DNSLookupError, TimeoutError)

def writeToFile(response, filename):
    """Deliver the body of **response** to a :class:`FileWriter`.

    :param response: An object with a ``deliverBody()`` method, which is
        given the :class:`FileWriter` protocol.
    :param filename: Handed to :class:`FileWriter` as the object to write
        the body to.
    :returns: A new :class:`~twisted.internet.defer.Deferred`, which the
        :class:`FileWriter` fires when its connection is lost.
    """
    log.msg("get-tor-exits: Downloading list of Tor exit relays.")
    done = defer.Deferred()
    writer = FileWriter(done, filename)
    response.deliverBody(writer)
    return done


class GetTorExitsOptions(usage.Options):
    """Command-line options accepted by this script."""

    # Each flag: [long name, short name, description].
    optFlags = [
        ['stdout', 's', "Write results to stdout instead of file"],
    ]

    # Each parameter: [long name, short name, default value, description].
    optParameters = [
        ['file', 'f', 'exit-list', "File to write results to"],
        ['address', 'a', '1.1.1.1', "Only exits which can reach this address"],
        ['port', 'p', 443, "Only exits which can reach this port"],
    ]


class FileWriter(protocol.Protocol):
    """Write an incrementally received download to a file-like object."""

    def __init__(self, finished, file):
        """Create a FileWriter.

        .. warning:: Only the first 2MB of a download are written. Anything
            received after that is silently discarded.

        :param finished: A :class:`~twisted.internet.defer.Deferred` which
            is fired with ``None`` when the connection is lost.
        :param file: The object to write received data to; its ``write()``
            and ``flush()`` methods are used.
        """
        self.fh = file
        self.remaining = 2 * 1024 * 1024
        self.finished = finished

    def dataReceived(self, bytes):
        """Write as much of ``bytes`` as still fits under the size limit.

        Each portion written is flushed immediately.
        """
        if not self.remaining:
            return

        chunk = bytes[:self.remaining]
        self.fh.write(chunk)
        self.fh.flush()
        self.remaining = self.remaining - len(chunk)

    def connectionLost(self, reason):
        """Log why the connection ended and fire the ``finished`` deferred.

        The deferred is fired with ``None`` whatever **reason** is.
        """
        errorMessage = reason.getErrorMessage()
        log.msg('get-tor-exits: Finished receiving exit list: %s'
                % errorMessage)
        self.finished.callback(None)


class WebClientContextFactory(ssl.ClientContextFactory):
    """An OpenSSL context factory for HTTPS clients, without SSLv2/SSLv3."""

    def getContext(self, hostname, port):
        """Get this connection's OpenSSL context.

        By default, :api:`twisted.internet.ssl.ClientContextFactory` uses
        ``OpenSSL.SSL.SSLv23_METHOD``, which allows SSLv2, SSLv3, and TLSv1,
        then they disable SSLv2 in the
        :api:`twisted.internet.ssl.ClientContextFactory.getContext` method.

        We disable SSLv3 also.

        NOTE(review): only protocol options are set on the context here;
        nothing in this method configures verification of the server's
        certificate -- confirm whether that is acceptable.

        :param hostname: Unused; the same options apply to every host.
        :param port: Unused; the same options apply to every port.
        :rtype: ``OpenSSL.SSL.Context``
        :returns: An OpenSSL context with options set.
        """
        ctx = self._contextFactory(self.method)
        # Options are bit flags, so they are combined with bitwise OR. (XOR
        # would silently drop a flag if the two constants ever shared a bit.)
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        return ctx


def main(filename=None, address=None, port=None):
    """Request the Tor exit list and write it to a file.

    If the reactor is not yet running, it is started here and stopped again
    once the request has been processed; otherwise the request is simply
    scheduled on the running reactor.

    :type filename: str or file-like object
    :param filename: Either a path, which is opened for writing and closed
        again once the download has finished, or an object with a
        ``write()`` method (e.g. ``sys.stdout``), which is written to but
        never closed here. If empty or ``None``, ``sys.stdout`` is used.
    :param str address: Only request exits which can reach this address. If
        not given, :func:`getSelfIPAddress` is asked for one.
    :param port: Only request exits which can reach this port. Left out of
        the request if ``None``.
    """
    if not address:
        address = getSelfIPAddress()
    if not address:
        # Without an address there is nothing sensible to ask for.
        log.err("get-tor-exits: Could not determine an address to check "
                "exits against.")
        return

    # Anything which can already be written to is used as it is and left
    # for its owner to close; only a file opened here is closed here.
    if not filename:
        fh, opened = sys.stdout, False
    elif hasattr(filename, 'write'):
        fh, opened = filename, False
    else:
        fh, opened = open(filename, 'w'), True

    def closeFile(result):
        """Flush and close the file opened above; pass **result** along."""
        if not fh.closed:
            fh.flush()
            fh.close()
        return result

    check = "https://check.torproject.org/torbulkexitlist"
    params = ['ip=%s' % address]
    if port is not None:
        params.append("port=%s" % port)
    check += '?' + '&'.join(params)

    log.msg("get-tor-exits: Requesting %s..." % check)

    contextFactory = WebClientContextFactory()
    agent = client.Agent(reactor, contextFactory)
    d = agent.request("GET", check)
    d.addCallback(writeToFile, fh)
    if opened:
        # Close the file from within the Deferred chain, so that it stays
        # open until the body has been written even when the reactor was
        # already running and this function returns right away.
        d.addBoth(closeFile)
    d.addErrback(handle)
    d.addCallbacks(log.msg, log.err)

    if not reactor.running:
        d.addCallback(lambda ignored: reactor.stop())
        reactor.run()
        # The reactor may have been stopped before the chain got as far as
        # closing the file; make sure that it is closed either way.
        if opened:
            closeFile(None)


if __name__ == "__main__":
    # Create the options outside of the try block, so that the usage text
    # is always available to the error handler below.
    options = GetTorExitsOptions()
    try:
        options.parseOptions()
    except usage.UsageError as error:
        log.err(error)
        raise SystemExit(options.getUsage())

    if options['stdout']:
        filename = sys.stdout
    elif options['file']:
        filename = options['file']
        log.msg("get-tor-exits: Saving Tor exit relay list to file: '%s'"
                % filename)
        # Move any previous exit list aside before it gets overwritten.
        backupFile(filename)
    else:
        # An empty --file without --stdout leaves nowhere else to write to,
        # so fall back to stdout rather than leaving filename unbound.
        filename = sys.stdout

    main(filename, options['address'], options['port'])
