"""
    Copyright (c) 2016 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------
    Aggregate flow data (format described in parse.py) into SQLite structured containers,
    one per aggregate type and sample resolution. Implementations are collected in lib/aggregates/.
"""
import os
import datetime
import sqlite3
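
# Rough sketch of how this module is meant to be driven (the actual loop lives in the calling
# netflow scripts; names below are illustrative only): flows parsed by parse.py are pushed into
# one aggregator instance per resolution, while AggMetadata keeps track of parse progress.
#
#   metadata = AggMetadata()
#   for agg_class in aggregate_classes:            # concrete classes from lib/aggregates/
#       for resolution in agg_class.resolutions():
#           agg = agg_class(resolution)
#           for flow in parsed_flows:              # dicts produced by parse.py
#               agg.add(flow)
#           agg.commit()
#           agg.cleanup()
#   metadata.update_sync_time(last_flow_timestamp)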

def convert_timestamp(val):
    """ convert a sqlite timestamp, stored either as a formatted date/time string or as
        seconds since epoch, into a datetime.datetime object
    """
    if val.find('-') > -1:
        # formatted date/time
        if val.find(" ") > -1:
            datepart, timepart = val.split(" ")
        else:
            datepart = val
            timepart = "0:0:0,0"
        year, month, day = map(int, datepart.split("-"))
        timepart_full = timepart.split(".")
        hours, minutes, seconds = map(int, timepart_full[0].split(":"))
        if len(timepart_full) == 2:
            microseconds = int('{:0<6.6}'.format(timepart_full[1].decode()))
        else:
            microseconds = 0

        val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds)
    else:
        # timestamp stored as seconds since epoch, convert to utc
        val = datetime.datetime.utcfromtimestamp(float(val))

    return val

sqlite3.register_converter('timestamp', convert_timestamp)
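
# Illustration of the two storage formats the converter accepts (made-up values):
#
#   convert_timestamp('2016-05-01 12:30:15.5')   # -> datetime.datetime(2016, 5, 1, 12, 30, 15, 500000)
#   convert_timestamp('2016-05-01')              # -> datetime.datetime(2016, 5, 1, 0, 0)
#   convert_timestamp('1462105815')              # -> seconds since epoch, returned as a UTC datetime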

class AggMetadata(object):
    """ store some metadata needed to keep track of parse progress
    """
    def __init__(self):
        self._filename = '/var/netflow/metadata.sqlite'
        # make sure the target directory exists
        target_path = os.path.dirname(self._filename)
        if not os.path.isdir(target_path):
            os.makedirs(target_path)
        # open sqlite database and cursor
        self._db_connection = sqlite3.connect(self._filename,
                                              detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
        self._db_cursor = self._db_connection.cursor()
        # known tables
        self._tables = list()
        # cache known tables
        self._update_known_tables()

    def __del__(self):
        """ close database on destruct
        :return: None
        """
        if self._db_connection is not None:
            self._db_connection.close()

    def _update_known_tables(self):
        """ request known tables
        :return: None
        """
        self._tables = list()
        self._db_cursor.execute('SELECT name FROM sqlite_master')
        for record in self._db_cursor.fetchall():
            self._tables.append(record[0])

    def update_sync_time(self, timestamp):
        """ update (last) sync timestamp
        """
        if 'sync_timestamp' not in self._tables:
            self._db_cursor.execute('create table sync_timestamp(mtime timestamp)')
            self._db_cursor.execute('insert into sync_timestamp(mtime) values(0)')
            self._db_connection.commit()
            self._update_known_tables()
        # only move the last sync timestamp forward when the provided value is more recent
        self._db_cursor.execute('update sync_timestamp set mtime = :mtime where mtime < :mtime', {'mtime': timestamp})
        self._db_connection.commit()

    def last_sync(self):
        """ request the last known sync timestamp
        :return: timestamp of the last sync (0.0 when nothing was synchronised yet)
        """
        if 'sync_timestamp' not in self._tables:
            return 0.0
        else:
            self._db_cursor.execute('select max(mtime) from sync_timestamp')
            return self._db_cursor.fetchall()[0][0]
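
# Minimal usage sketch for AggMetadata (illustrative, assumes /var/netflow is writable):
#
#   metadata = AggMetadata()
#   metadata.update_sync_time(1462105815.0)    # e.g. mtime of the last flow log processed
#   metadata.last_sync()                       # highest timestamp pushed so far, 0.0 when empty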

class BaseFlowAggregator(object):
    # target location ('/var/netflow/<store>.sqlite')
    target_filename = None
    # list of fields to use in this aggregate
    agg_fields = None

    @classmethod
    def resolutions(cls):
        """ sample resolutions for this aggregation
        :return: list of sample resolutions
        """
        return list()

    @classmethod
    def history_per_resolution(cls):
        """ history to keep in seconds per sample resolution
        :return: dict sample resolution / expire time (seconds)
        """
        return dict()

    @classmethod
    def seconds_per_day(cls, days):
        """
        :param days: number of days
        :return: number of seconds
        """
        return 60*60*24*days

    def __init__(self, resolution):
        """ construct a new flow aggregator for the given sample resolution
        :param resolution: sample resolution in seconds
        :return: None
        """
        self.resolution = resolution
        # data is stored per resolution in its own database file (target_filename % resolution), table 'timeserie'
        self._db_connection = None
        self._update_cur = None
        self._known_targets = list()
        # construct update and insert sql statements
        tmp = 'update timeserie set last_seen = :flow_end, '
        tmp += 'octets = octets + :octets_consumed, packets = packets + :packets_consumed '
        tmp += 'where mtime = :mtime and %s '
        self._update_stmt = tmp % (' and '.join(map(lambda x: '%s = :%s' % (x, x), self.agg_fields)))
        tmp = 'insert into timeserie (mtime, last_seen, octets, packets, %s) '
        tmp += 'values (:mtime, :flow_end, :octets_consumed, :packets_consumed, %s)'
        self._insert_stmt = tmp % (','.join(self.agg_fields), ','.join(map(lambda x: ':%s' % x, self.agg_fields)))
        # open database
        self._open_db()
        self._fetch_known_targets()

    def __del__(self):
        """ close database on destruct
        :return: None
        """
        if self._db_connection is not None:
            self._db_connection.close()

    def _fetch_known_targets(self):
        """ read known target table names from the sqlite db
        :return: None
        """
        if self._db_connection is not None:
            self._known_targets = list()
            cur = self._db_connection.cursor()
            cur.execute('SELECT name FROM sqlite_master')
            for record in cur.fetchall():
                self._known_targets.append(record[0])
            cur.close()

    def _create_target_table(self):
        """ construct target aggregate table, using resulution and list of agg_fields
        :return: None
        """
        if self._db_connection is not None:
            # construct new aggregate table
            sql_text = list()
            sql_text.append('create table timeserie ( ')
            sql_text.append('   mtime timestamp')
            sql_text.append(',  last_seen timestamp')
            for agg_field in self.agg_fields:
                sql_text.append(', %s varchar(255)' % agg_field)
            sql_text.append(',  octets numeric')
            sql_text.append(',  packets numeric')
            sql_text.append(',  primary key(mtime, %s)' % ','.join(self.agg_fields))
            sql_text.append(')')
            cur = self._db_connection.cursor()
            cur.executescript('\n'.join(sql_text))
            cur.close()
            # read table names
            self._fetch_known_targets()

    def is_db_open(self):
        """ check if target database is open
        :return: database connected (True/False)
        """
        return self._db_connection is not None

    def _open_db(self):
        """ open / create database
        :return: None
        """
        if self.target_filename is not None:
            # make sure the target directory exists
            target_path = os.path.dirname(self.target_filename % self.resolution)
            if not os.path.isdir(target_path):
                os.makedirs(target_path)
            # open sqlite database
            self._db_connection = sqlite3.connect(self.target_filename % self.resolution,
                                                  detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
            # open update/insert cursor
            self._update_cur = self._db_connection.cursor()

    def commit(self):
        """ commit data
        :return: None
        """
        if self._db_connection is not None:
            self._db_connection.commit()


    def add(self, flow):
        """ calculate timeslices per flow depending on sample resolution
        :param flow: flow data (from parse.py)
        :return: None
        """
        # make sure target exists
        if 'timeserie' not in self._known_targets:
            self._create_target_table()

        # push record(s) depending on resolution
        start_time = int(flow['flow_start'] / self.resolution) * self.resolution
        while start_time <= flow['flow_end']:
            consume_start_time = max(flow['flow_start'], start_time)
            consume_end_time = min(start_time + self.resolution, flow['flow_end'])
            if flow['duration_ms'] != 0:
                consume_perc = (consume_end_time - consume_start_time) / (flow['duration_ms'] / 1000.0)
            else:
                consume_perc = 1
            if self.is_db_open():
                # upsert data
                flow['octets_consumed'] = consume_perc * flow['octets']
                flow['packets_consumed'] = consume_perc * flow['packets']
                flow['mtime'] = datetime.datetime.utcfromtimestamp(start_time)
                self._update_cur.execute(self._update_stmt, flow)
                if self._update_cur.rowcount == 0:
                    self._update_cur.execute(self._insert_stmt, flow)
            # next start time
            start_time += self.resolution
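
    # Worked example of the slicing above (illustrative numbers): with resolution=300 and a
    # flow carrying flow_start=1000130, flow_end=1000430, duration_ms=300000, octets=1000,
    # the flow is spread over the slices starting at 999900 and 1000200:
    #   slice 999900  : consumes (1000200 - 1000130) / 300 ~ 23% of octets/packets
    #   slice 1000200 : consumes (1000430 - 1000200) / 300 ~ 77%
    # Each slice is keyed on its mtime (slice start, UTC) plus the values of agg_fields, so the
    # flow dict must contain those keys next to flow_start, flow_end, duration_ms, octets, packets.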

    def cleanup(self, do_vacuum = False):
        """ cleanup timeserie table, remove data older than the expiry time configured
            in history_per_resolution() for this sample resolution
        :param do_vacuum: vacuum database after cleanup
        :return: None
        """
        if self.is_db_open() and 'timeserie' in self._known_targets \
          and self.resolution in self.history_per_resolution():
            self._update_cur.execute('select max(mtime) as "[timestamp]" from timeserie')
            last_timestamp = self._update_cur.fetchall()[0][0]
            if type(last_timestamp) == datetime.datetime:
                expire = self.history_per_resolution()[self.resolution]
                expire_timestamp = last_timestamp - datetime.timedelta(seconds=expire)
                self._update_cur.execute('delete from timeserie where mtime < :expire', {'expire': expire_timestamp})
                self.commit()
                if do_vacuum:
                    # vacuum database if requested
                    self._update_cur.execute('vacuum')

    def _parse_timestamp(self, timestamp):
        """ convert input to datetime.datetime or return if it already was of that type
        :param timestamp: timestamp to convert
        :return: datetime.datetime object
        """
        if type(timestamp) in (int, float):
            return datetime.datetime.utcfromtimestamp(timestamp)
        elif type(timestamp) != datetime.datetime:
            return datetime.datetime.utcfromtimestamp(0)
        else:
            return timestamp

    def _valid_fields(self, fields):
        """ cleanse fields (return only valid ones)
        :param fields: field list
        :return: list
        """
        # validate field list (can only select fields in self.agg_fields)
        select_fields = list()
        for field in fields:
            if field.strip() in self.agg_fields:
                select_fields.append(field.strip())

        return select_fields

    def get_timeserie_data(self, start_time, end_time, fields):
        """ fetch data from aggregation source, groups by mtime and selected fields
        :param start_time: start timestamp
        :param end_time: end timestamp
        :param fields: fields to retrieve
        :return: iterator returning dict records (start_time, end_time, [fields], octets, packets)
        """
        if self.is_db_open() and 'timeserie' in self._known_targets:
            # validate field list (can only select fields in self.agg_fields)
            select_fields = self._valid_fields(fields)
            if len(select_fields) == 0:
                # select "none", add static null as field
                select_fields.append('null')
            sql_select = 'select mtime as "start_time [timestamp]", %s' % ','.join(select_fields)
            sql_select += ', sum(octets) as octets, sum(packets) as packets\n'
            sql_select += 'from timeserie \n'
            sql_select += 'where mtime >= :start_time and mtime < :end_time\n'
            sql_select += 'group by mtime, %s\n'% ','.join(select_fields)

            # execute select query
            cur = self._db_connection.cursor()
            cur.execute(sql_select, {'start_time': self._parse_timestamp(start_time),
                                     'end_time': self._parse_timestamp(end_time)})
            # map column names so every record can be returned as a dict
            field_names = [x[0] for x in cur.description]
            for record in cur.fetchall():
                result_record = dict()
                for field_indx in range(len(field_names)):
                    if len(record) > field_indx:
                        result_record[field_names[field_indx]] = record[field_indx]
                if 'start_time' in result_record:
                    result_record['end_time'] = result_record['start_time'] + datetime.timedelta(seconds=self.resolution)
                    # send data
                    yield result_record
            # close cursor
            cur.close()
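
    # Hypothetical call (the requested field names must appear in the concrete class' agg_fields):
    #
    #   for record in agg.get_timeserie_data(start_time, end_time, ['protocol']):
    #       ...  # record holds start_time, end_time, protocol, octets, packets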

    def get_top_data(self, start_time, end_time, fields, value_field, data_filters=None, max_hits=100):
        """ Retrieve top (usage) from this aggregation.
            Fetch data from aggregation source, groups by selected fields, sorts by value_field descending
            use data_filters to filter before grouping.
        :param start_time: start timestamp
        :param end_time: end timestamp
        :param fields: fields to retrieve
        :param value_field: field to sum
        :param data_filters: comma separated list of field=value filters, applied before grouping
        :param max_hits: maximum number of results, remaining rows are summed into one extra record
        :return: list of dict records ([fields], total, last_seen)
        """
        result = list()
        select_fields = self._valid_fields(fields)
        filter_fields = []
        query_params = {}
        if value_field == 'octets':
            value_sql = 'sum(octets)'
        elif value_field == 'packets':
            value_sql = 'sum(packets)'
        else:
            value_sql = '0'

        # query filters, correct start_time for resolution
        query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution)
        query_params['end_time'] = self._parse_timestamp(end_time)
        if data_filters:
            for data_filter in data_filters.split(','):
                tmp = data_filter.split('=')[0].strip()
                if tmp in self.agg_fields and data_filter.find('=') > -1:
                    filter_fields.append(tmp)
                    query_params[tmp] = '='.join(data_filter.split('=')[1:])

        if len(select_fields) > 0:
            # construct sql query to filter and select data
            sql_select = 'select %s' % ','.join(select_fields)
            sql_select += ', %s as total, max(last_seen) last_seen \n' % value_sql
            sql_select += 'from timeserie \n'
            sql_select += 'where mtime >= :start_time and mtime < :end_time\n'
            for filter_field in filter_fields:
                sql_select += ' and %s = :%s \n' % (filter_field, filter_field)
            sql_select += 'group by %s\n'% ','.join(select_fields)
            sql_select += 'order by %s desc ' % value_sql

            # execute select query
            cur = self._db_connection.cursor()
            cur.execute(sql_select, query_params)

            # fetch all rows; the first [max_hits] become result records, the remainder is summed into one extra row
            field_names = [x[0] for x in cur.description]
            for record in cur.fetchall():
                result_record = dict()
                for field_indx in range(len(field_names)):
                    if len(record) > field_indx:
                        result_record[field_names[field_indx]] = record[field_indx]
                if len(result) < max_hits:
                    result.append(result_record)
                else:
                    if len(result) == max_hits:
                        # generate row for "rest of data"
                        result.append({'total': 0})
                        for key in result_record:
                            if key not in result[-1]:
                                result[-1][key] = ""
                    result[-1]['total'] += result_record['total']
            # close cursor
            cur.close()

        return result
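
    # Hypothetical call: top 10 entries by octet count, pre-filtered on one of the aggregate's
    # own fields (field names below are placeholders, they must exist in agg_fields):
    #
    #   top = agg.get_top_data(start_time, end_time, fields=['src_addr'],
    #                          value_field='octets', data_filters='direction=in', max_hits=10)
    #
    # When more than max_hits rows match, the last record sums the remainder into 'total',
    # with the other fields left empty.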

    def get_data(self, start_time, end_time):
        """ get detail data
        :param start_time: start timestamp
        :param end_time: end timestamp
        :return: iterator
        """
        query_params = {}
        query_params['start_time'] = self._parse_timestamp((int(start_time/self.resolution))*self.resolution)
        query_params['end_time'] = self._parse_timestamp(end_time)
        sql_select = 'select mtime start_time, '
        sql_select += '%s, octets, packets, last_seen as "last_seen [timestamp]"  \n' % ','.join(self.agg_fields)
        sql_select += 'from timeserie \n'
        sql_select += 'where mtime >= :start_time and mtime < :end_time\n'
        cur = self._db_connection.cursor()
        cur.execute(sql_select, query_params)

        # stream all rows within the requested time range
        field_names = [x[0] for x in cur.description]
        while True:
            record = cur.fetchone()
            if record is None:
                break
            else:
                result_record=dict()
                for field_indx in range(len(field_names)):
                    if len(record) > field_indx:
                        result_record[field_names[field_indx]] = record[field_indx]
                yield result_record
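
# A concrete aggregator (normally shipped in lib/aggregates/) only has to fill in the class
# attributes and resolution metadata; a minimal sketch with made-up values:
#
#   class FlowProtocolTotals(BaseFlowAggregator):
#       target_filename = '/var/netflow/protocol_%06d.sqlite'
#       agg_fields = ['protocol']
#
#       @classmethod
#       def resolutions(cls):
#           return [300, 3600]
#
#       @classmethod
#       def history_per_resolution(cls):
#           return {300: cls.seconds_per_day(2), 3600: cls.seconds_per_day(31)}
#
#       def __init__(self, resolution):
#           super(FlowProtocolTotals, self).__init__(resolution)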