#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Multicast functionality.
'''

# Copyright (C) 2009-2022  Xyne
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# (version 2) as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import ipaddress
import logging
import socket
import socketserver
import urllib.parse
import threading
import time

from .common import (
    DAEMON_THREADS,
    format_seconds,
    get_local_ipv6_addresses,
    get_all_interfaces,
    get_ip_addresses,
    replace_interfaces_with_ips,
    replace_uri_host_and_get_port,
    unbound_address
)

# Default IPv4 multicast group used by the announcer and the listening server.
MULTICAST_GROUP = '224.4.4.4'
# Default port of the listening server and default destination port of
# announcements.
MULTICAST_PORT = 32768
# Default number of seconds between announcement rounds (five minutes).
MULTICAST_INTERVAL = 5 * 60


# -------------------------------- Functions --------------------------------- #

def _send_multicast_datagrams(address_family, message, group, ports, via, ttl):
    '''
    Send one datagram containing message to (group, port) for each port.

    A fresh UDP socket of the given address family is opened and closed
    afterwards. The multicast TTL (IPv4) or hop limit (IPv6) is set to ttl
    with the socket option that matches the address family. Errors from
    individual sendto() calls are logged and do not stop the remaining ports;
    errors from socket creation or setsockopt() propagate to the caller.
    The via argument is only used in log messages.
    '''
    if address_family == socket.AF_INET6:
        ttl_level, ttl_option = socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS
    else:
        ttl_level, ttl_option = socket.IPPROTO_IP, socket.IP_MULTICAST_TTL

    with socket.socket(address_family, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        sock.setsockopt(ttl_level, ttl_option, ttl)
        # NOTE: the socket is neither bound to `via` nor pinned to an interface
        # with IP_MULTICAST_IF / IPV6_MULTICAST_IF, so the outgoing interface
        # is chosen by the operating system.
        for port in ports:
            logging.info(
                'sending multicast message to (%s, %d) via %s',
                group, port, via
            )
            try:
                sock.sendto(message, (group, port))
            except OSError as err:
                logging.error('failed to send multicast message: %s', err)


def multicast(message, group, ports, bind_address=None, ttl=2):  # pylint: disable=too-many-branches
    '''
    Send a multicast message.

    Args:
        message: The message, as str (encoded as UTF-8) or bytes.
        group: The multicast group address.
        ports: A single port (int) or an iterable of ports to send to.
        bind_address: An ipaddress.IPv4Address/IPv6Address, or a string
            parsable as one, for which the message is announced. It is used to
            filter by IP version, to check IPv6 addresses against the local
            interfaces, and in log messages. If None, every address returned
            by get_all_interfaces() is used.
        ttl: Multicast TTL (IPv4) or hop limit (IPv6) of the sent packets.

    Socket errors are logged rather than raised, so a failure for one address
    does not prevent sending for the others.
    '''
    try:
        message = message.encode('UTF-8')
    except AttributeError:
        pass  # Already bytes.

    # This is used to get the IP type of the group (IPv4 or IPv6).
    group_ip_type = type(ipaddress.ip_address(group))

    # Materialize the ports so that a one-shot iterator serves every address.
    ports = (ports,) if isinstance(ports, int) else tuple(ports)

    if bind_address is None:
        # Binding to 0.0.0.0 did not work.
        addresses = (ip for (_, ip) in get_all_interfaces())
    elif isinstance(bind_address, str):
        try:
            addresses = (ipaddress.ip_address(bind_address),)
        except ValueError:
            # Reported below as an unsupported address type.
            addresses = (bind_address,)
    else:
        addresses = (bind_address,)

    # Map of local IPv6 addresses to interface indices, built only once an
    # IPv6 address is actually used.
    ipv6_iface_indices = None

    for address in addresses:
        if not isinstance(address, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            logging.warning(
                'multicast function only accepts addresses that are instances of '
                'ipaddress.IPv4Address or ipaddress.IPV6Address: ignoring address of type %s',
                type(address)
            )
            continue

        # Skip mismatched IP versions (e.g. IPv4 multicast groups on IPv6 networks).
        if not isinstance(address, group_ip_type):
            continue

        try:
            if isinstance(address, ipaddress.IPv4Address):
                family = socket.AF_INET
            else:
                if ipv6_iface_indices is None:
                    ipv6_iface_indices = {
                        addr: socket.if_nametoindex(iface)
                        for (iface, addr) in get_local_ipv6_addresses()
                    }
                # The index itself is currently unused because the outgoing
                # interface is not pinned; the lookup only filters out
                # addresses that are not assigned to a local interface.
                if address not in ipv6_iface_indices:
                    logging.error(
                        'failed to find associated interface for address %s, skipping...',
                        address
                    )
                    continue
                family = socket.AF_INET6
            _send_multicast_datagrams(family, message, group, ports, address, ttl)
        # OSError covers PermissionError and socket.gaierror as well as socket
        # creation / setsockopt failures, which would otherwise escape.
        except OSError as err:
            if isinstance(err, socket.gaierror):
                err_msg = err.strerror
            else:
                err_msg = str(err)

            if bind_address is not None:
                logging.error('announcement failed via %s: %s', address, err_msg)
            else:
                logging.error('announcement failed: %s', err_msg)


# --------------------------------- Threads ---------------------------------- #

def multicast_announcer(
    message_prefix,
    get_server_uris,
    group=MULTICAST_GROUP,
    ports=MULTICAST_PORT,
    interval=MULTICAST_INTERVAL,
    delay=1,
    interfaces=None
):  # pylint: disable=too-many-arguments
    '''
    Periodically announce presence via multicast. This loops forever and is
    meant to be the target of a daemon thread.

    Args:
        message_prefix: String prepended to each announced server URI.
        get_server_uris: Callable returning an iterable of server URIs.
        group: Multicast group to which announcements are sent.
        ports: Port or ports to which announcements are sent.
        interval: Seconds to wait between announcement rounds.
        delay: Seconds to wait before the first round.
        interfaces: Interfaces or addresses through which to announce. If it
            is empty/None or contains an unbound address, all interfaces are
            used.

    Only URIs whose host is an IP address among the selected interface
    addresses are announced. URIs whose host is missing or is not an IP
    literal are logged and skipped instead of terminating the loop.
    '''
    time.sleep(delay)
    # If any of the interfaces is an unbound address then all interfaces will be
    # used regardless of the other interfaces. In that case, just use an unbound
    # address.
    if not interfaces or any(unbound_address(i) for i in interfaces):
        interfaces = None
    while True:
        # Do this here to ensure that changes are detected if an interface address
        # changes.
        addresses = set(replace_interfaces_with_ips(interfaces))
        for server_uri in get_server_uris():
            try:
                parsed_uri = urllib.parse.urlsplit(server_uri)
                bind_address = ipaddress.ip_address(parsed_uri.hostname)
            except ValueError:
                # An uncaught ValueError here would end the announcer thread.
                logging.warning(
                    'skipping multicast announcement for %s: host is not an IP address',
                    server_uri
                )
                continue
            if bind_address not in addresses:
                continue
            message, _ = replace_uri_host_and_get_port(server_uri)
            multicast(message_prefix + message, group, ports, bind_address)
        time.sleep(interval)


# ----------------------------- MulticastServer ------------------------------ #

class MulticastServer(socketserver.UDPServer):
    '''
    Server for listening for multicast announcements.

    The socket is bound to the multicast group and port. Group membership is
    then requested on each selected local address: all interface addresses
    when the server address is unbound, otherwise the server address itself.
    Memberships are added with IP_ADD_MEMBERSHIP on a socket of UDPServer's
    default address family (AF_INET), so only IPv4 groups are supported.
    '''
    # ThreadingMixIn attribute; only has an effect when this class is combined
    # with socketserver.ThreadingMixIn.
    daemon_threads = DAEMON_THREADS

    def __init__(
        self,
        server_address,
        handler,
        multicast_group,
        *args,
        **kwargs
    ):
        # multicast_group is the group address as a string (it is passed to
        # socket.bind() in server_bind()).
        self.multicast_group = multicast_group
        if not unbound_address(server_address[0]):
            # Replace an interface name or address with its first IP address.
            for ip_address in get_ip_addresses(server_address[0]):
                server_address = (ip_address, server_address[1])
                break
        socketserver.UDPServer.__init__(self, server_address, handler, *args, **kwargs)

    # For details, see
    # https://www.tldp.org/HOWTO/Multicast-HOWTO-2.html#ss2.4
    # https://bbs.archlinux.org/viewtopic.php?pid=1833194#p1833194
    # https://stackoverflow.com/questions/10692956/what-does-it-mean-to-bind-a-multicast-udp-socket/10739443

    def server_bind(self):
        '''
        Bind to the multicast group and join it on the selected addresses.

        A failed join on one address is logged and skipped. For example, on
        Linux a second address on an interface that has already joined fails
        with EADDRINUSE. The last error is raised only if no join succeeded.
        '''
        group = ipaddress.ip_address(self.multicast_group)
        # The IP protocol version type.
        ip_type = type(group)

        # Determine which addresses are available.
        if unbound_address(self.server_address[0]):
            addresses = (ip for (_, ip) in get_all_interfaces())
        else:
            addresses = (ipaddress.ip_address(self.server_address[0]),)

        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Bind to the group address rather than a wildcard (see links above).
        self.socket.bind((self.multicast_group, self.server_address[1]))

        joined = False
        last_error = None
        for address in addresses:
            # Skip mismatched IP versions.
            if not isinstance(address, ip_type):
                continue
            # struct ip_mreq: group address followed by local interface address.
            mreq = group.packed + address.packed
            try:
                self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
            except OSError as err:
                logging.warning(
                    'failed to join multicast group %s via %s: %s',
                    self.multicast_group, address, err
                )
                last_error = err
            else:
                joined = True

        if not joined and last_error is not None:
            raise last_error


# ---------------------------- MulticastSubserver ---------------------------- #

class MulticastSubserver(MulticastServer):
    '''
    Multicast listening server run on behalf of a main server.

    It behaves like MulticastServer but also stores the owning main server,
    which MulticastSubserverRequestHandler reaches as self.server.main_server
    to process received announcements.
    '''

    def __init__(self, main_server, *args, **kwargs):
        # Remember the owner first; every other argument goes unchanged to
        # MulticastServer.__init__().
        self.main_server = main_server
        super().__init__(*args, **kwargs)


# --------------------- MulticastSubserverRequestHandler --------------------- #

class MulticastSubserverRequestHandler(socketserver.BaseRequestHandler):
    '''
    The request handler passed to the multicast subserver. This invokes a method
    in the parent server to handle the data from the multicast and add the peer
    to the pool of known peers.

    A datagram is processed only if it decodes as UTF-8 and starts with the
    main server's multicast prefix. The remainder is passed, with the
    sender's host and the main server's default peer scheme and port, to
    replace_uri_host_and_get_port(). The resulting URI is handed to
    main_server.handle_peer(uri, 'multicast'). Other datagrams are logged
    and dropped.
    '''

    def handle(self):
        client_host = self.client_address[0]
        main_server = self.server.main_server

        try:
            # For a UDP server, request is (data, socket). The data comes from
            # the network, so it is not guaranteed to be valid UTF-8.
            data = self.request[0].decode('UTF-8')
        except UnicodeDecodeError:
            main_server.log_warning('undecodable multicast message from %s', client_host)
            return

        multicast_prefix = main_server.get_multicast_prefix()
        if not data.startswith(multicast_prefix):
            main_server.log_warning(
                'unrecognized multicast message from %s',
                client_host
            )
            return

        try:
            uri, _port = replace_uri_host_and_get_port(
                data[len(multicast_prefix):],
                client_host,
                scheme=main_server.DEFAULT_PEER_SCHEME,
                port=main_server.DEFAULT_PEER_PORT
            )
        except ValueError:
            main_server.log_error('invalid multicast message from %s', client_host)
            return
        main_server.handle_peer(uri, 'multicast')


# --------------------------- MulticastPeerManager --------------------------- #

class MulticastPeerManager():
    '''
    Manage the multicast subserver and announcer threads. This should be
    subclassed by servers that wish to acquire peers through multicasts.
    Subclasses must have the following attributes from one of the server classes:

        * options (with all options added by add_multicast_argparse_groups())
        * handle_peer
        * get_server_uris
        * log_error and log_warning (used by MulticastSubserverRequestHandler)
        * DEFAULT_PEER_SCHEME and DEFAULT_PEER_PORT (likewise)
    '''

    def __init__(self, options, handler):
        # options: parsed command-line options (see the class docstring).
        # handler: request handler class whose version_string() forms the
        # multicast prefix (see get_multicast_prefix()).
        self.options = options
        self.handler = handler
        # Populated by start_multicast_threads().
        self.multicast_server = None
        self.multicast_server_thread = None
        self.multicast_announcer_thread = None

    @staticmethod
    def get_server_uris():
        '''
        Placeholder for returning the server URIs, overridden in child classes.
        Returns an empty iterator.
        '''
        return iter(())

    def shutdown(self):
        '''
        Shutdown multicasting: stop the multicast subserver's serve_forever()
        loop and close its socket. Does nothing if the subserver was never
        created. The announcer thread is not stopped.
        '''
        # getattr: subclasses may not have run this class's __init__.
        server = getattr(self, 'multicast_server', None)
        if server is None:
            return
        try:
            server.shutdown()
        finally:
            # Release the listening socket once the serve loop has stopped.
            server.server_close()

    def get_multicast_prefix(self):
        '''
        The multicast prefix is prepended to the contents of the multicasts. Only
        multicasts with the same prefix will be processed.
        '''
        # The handler is a class, not an instance. The method should be static but
        # the base class defines a class method. Invoke the class method here with
        # None as it is not used by the subclass overrides.
        return self.handler.version_string(None) + ' '

    def get_multicast_info(self):
        '''
        Yield (label, value) pairs describing the multicast configuration.
        '''
        if self.options.multicast:
            multicast_address = self.options.multicast_server_address
            if unbound_address(multicast_address):
                multicast_address = 'all interfaces'

            if self.options.multicast_interfaces:
                multicast_interfaces = '\n'.join(self.options.multicast_interfaces)
            else:
                multicast_interfaces = 'all'
            # A single int when left at its default, a sequence when given
            # explicitly.
            multicast_ports = self.options.multicast_ports
            if isinstance(multicast_ports, (list, tuple)):
                multicast_ports = ' '.join(str(p) for p in multicast_ports)
            yield from (
                ('Multicast listening address', multicast_address),
                ('Multicast listening port', self.options.multicast_server_port),
                ('Multicast group', self.options.multicast_group),
                ('Multicast interval (s)', format_seconds(self.options.multicast_interval)),
                ('Multicast interfaces', multicast_interfaces),
                ('Multicast ports', multicast_ports)
            )
        else:
            yield ('Multicast', False)

    def start_multicast_threads(self):
        '''
        Start the multicast subserver to listen for multicast requests and start a
        thread to periodically announce the presence of this server.
        '''
        self.multicast_server = MulticastSubserver(
            self,
            (self.options.multicast_server_address, self.options.multicast_server_port),
            MulticastSubserverRequestHandler,
            self.options.multicast_group
        )
        self.multicast_server_thread = threading.Thread(
            target=self.multicast_server.serve_forever
        )
        self.multicast_server_thread.daemon = True  # DAEMON_THREADS
        self.multicast_server_thread.start()

        self.multicast_announcer_thread = threading.Thread(
            target=multicast_announcer,
            args=(
                self.get_multicast_prefix(),
                self.get_server_uris
            ),
            kwargs={
                'group': self.options.multicast_group,
                'ports': self.options.multicast_ports,
                'interval': self.options.multicast_interval,
                'interfaces': self.options.multicast_interfaces,
            }
        )
        self.multicast_announcer_thread.daemon = True  # DAEMON_THREADS
        self.multicast_announcer_thread.start()


# ------------------------------- TestHandler -------------------------------- #

class TestHandler(socketserver.BaseRequestHandler):
    '''
    Minimal request handler for manual testing: prints the sender's host on
    one line and the raw datagram payload on the next.
    '''

    def handle(self):
        sender_host = self.client_address[0]
        payload = self.request[0]
        print(f'{sender_host}:', payload, sep='\n')


# -------------------------- Command-line arguments -------------------------- #

def add_multicast_argparse_groups(
    parser,
    multicast_address='0.0.0.0',
    multicast_port=MULTICAST_PORT,
    multicast_group=MULTICAST_GROUP,
    multicast_interval=MULTICAST_INTERVAL
):
    '''
    Add the "Multicast Options" argument group to an argparse parser.

    The keyword arguments supply the option defaults; multicast_port is the
    default of both --multicast-server-port and --multicast-ports. The parser
    is returned to allow chaining.
    '''
    options_group = parser.add_argument_group(
        title="Multicast Options",
        description="Options that affect the behavior of the multicast (sub)server system.",
    )
    add = options_group.add_argument

    interface_help = '''
        The interface or address through which to announce presence with
        multicast packets. If not given, all interfaces on which the server is
        listening are used. Interfaces on which the server is not listening are
        ignored.
        '''.strip()

    add('--multicast', action='store_true',
        help='Use multicasting to announce presence and detect other servers.')
    add('--multicast-server-address', metavar='<interface|address>',
        default=multicast_address,
        help='The multicast server listening address. Default: %(default)s.')
    add('--multicast-server-port', metavar='<port>', type=int,
        default=multicast_port,
        help='The multicast server listening port. Default: %(default)s.')
    add('--multicast-group', metavar='<group>', default=multicast_group,
        help='The multicast group. Default: %(default)s.')
    add('--multicast-interval', metavar='<seconds>', type=int,
        default=multicast_interval,
        help='The multicast announcement interval. Default: %(default)s.')
    # Repeatable: each occurrence appends to options.multicast_interfaces.
    add('--multicast-interface', metavar='<interface|address>',
        dest='multicast_interfaces', default=[], action='append',
        help=interface_help)
    # The default is a single int, whereas ports given on the command line are
    # parsed into a list (nargs='+').
    add('--multicast-ports', metavar='<port>', type=int, nargs='+',
        default=multicast_port,
        help='The multicast ports to which to send announcement messages. Default: %(default)s.')
    return parser


if __name__ == "__main__":
    # Manual test: run a MulticastServer with an empty host on the default
    # port and group, printing every received datagram via TestHandler.
    listen_address = ('', MULTICAST_PORT)
    server = MulticastServer(listen_address, TestHandler, MULTICAST_GROUP)
    server.serve_forever()
