# -*- coding: utf-8 -*-
"""
/***************************************************************************
 CDVI
                                 A QGIS plugin
 City Disaster Vulnerability Index
                             -------------------
        begin                : 2020-06-11
        copyright            : (C) 2020 by WB
        email                : andlang@outlook.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import functools
import operator
import os
import typing

import numpy as np
import pandas as pd
from cdvi.utilities.transformations import linear, z_score, logarithmic, sigmoid, quadratic


class CityDataLoader(object):
    """Load the CDVI CSV files of one city and prepare them for index computation.

    Expected layout below ``<plugin_dir>/data/<CityName>_<country_abbr>/``:

    * ``<city_abbr>_CDVI_<pillar>.csv`` -- indicator data of one pillar
    * ``Setup_Matrix/<city_abbr>_CDVI_<pillar>.csv`` -- one row per indicator;
      the indicator name is in the second (unnamed) column, and the columns
      ``Norm_Base``, ``Sign`` and ``Transform`` hold its settings
    * ``Initial_Data/<city_abbr>_CDVI_InitialData.csv``

    Frames built by :meth:`combine_data` stack the setup rows on top of the
    data rows, so data rows start at position ``_SETUP_ROWS``.
    """

    # Number of setup rows stacked above the data rows by ``combine_data``.
    # NOTE(review): assumes the setup matrix has exactly three columns besides
    # the two leading unnamed ones (Norm_Base, Sign, Transform) -- confirm.
    _SETUP_ROWS = 3

    def __init__(self, plugin_dir: str, country_abbr: str, city_name: str, city_abbr: str):
        """
        :param plugin_dir: root directory of the plugin; data is read from its ``data`` subdirectory
        :param country_abbr: country abbreviation, appended to the city directory name
        :param city_name: city name; spaces are removed to form the city directory name
        :param city_abbr: city abbreviation, used as prefix of every CSV file name
        """
        self.plugin_dir = plugin_dir
        self.city_abbr = city_abbr
        self.city_dir = self.plugin_dir + '/data/' + city_name.replace(" ", "") + '_' + country_abbr + '/'

    def get_initial_data_matrix(self) -> pd.DataFrame:
        """Read ``Initial_Data/<city_abbr>_CDVI_InitialData.csv`` from the city directory."""
        path = self.city_dir + 'Initial_Data/' + self.city_abbr + '_CDVI_InitialData.csv'
        return pd.read_csv(path)

    def get_multipliers(self) -> pd.DataFrame:
        """Build the default multiplier matrix, with every indicator weight set to 1.0.

        :returns: one row per data row. The leading columns are the
            ``Norm_Base == 0`` columns of the first pillar (in sorted order),
            copied unchanged. They are followed by one all-ones column per
            indicator whose setup ``Norm_Base`` is > 0, named
            ``<pillar>_<indicator>``.
        """
        pillars = self.get_pillars()
        columns = [pillar + '_' + indicator
                   for pillar in pillars
                   for indicator in self.get_setup_matrix(pillar).query('Norm_Base > 0')['Unnamed: 1'].tolist()]

        combined_data = self.combine_data(pillars[0])
        data_rows = combined_data.iloc[self._SETUP_ROWS:]
        # Non-indicator columns (Norm_Base == 0) are carried over as-is.
        passthrough = np.where(combined_data.loc['Norm_Base'] == 0)[0]
        ones = np.ones((len(data_rows.index), len(columns)))

        return pd.concat([data_rows.iloc[:, passthrough], pd.DataFrame(ones, columns=columns)], axis=1)

    def get_pillars(self) -> typing.List[str]:
        """Return the pillar names found directly in the city directory, sorted.

        A pillar is any ``<city_abbr>_CDVI_<pillar>.csv`` file. Other ``.csv``
        files are ignored. Sorting makes the result independent of
        ``os.listdir``, whose order is arbitrary. That also fixes which pillar
        is "first" for :meth:`get_multipliers`.
        """
        prefix = self.city_abbr + '_CDVI_'
        suffix = '.csv'
        return sorted(file[len(prefix):-len(suffix)] for file in os.listdir(self.city_dir)
                      if file.startswith(prefix) and file.endswith(suffix))

    def combine_data_and_apply_transformation(self, pillar: str, multipliers: pd.DataFrame) -> pd.DataFrame:
        """Combine, scale and transform one pillar, returning only its data rows."""
        return self.apply_transformation(self.combine_data(pillar, multipliers), multipliers).iloc[3:]

    def combine_data(self, pillar: str, multipliers: typing.Optional[pd.DataFrame] = None) -> pd.DataFrame:
        """Stack a pillar's setup rows on top of its data rows.

        :param pillar: pillar name (see :meth:`get_pillars`)
        :param multipliers: optional frame whose ``<pillar>_<column>`` columns
            scale the matching data column, aligned on the row index.
            ``None`` (the default) applies no scaling.
        :returns: frame whose first ``_SETUP_ROWS`` rows are the setup values,
            followed by the data rows. Only columns without any NaN are kept.
        """
        cdvi_data = pd.read_csv(self.city_dir + self.city_abbr + '_CDVI_' + pillar + '.csv')

        # Transpose so each indicator becomes a column, named by the 'Unnamed: 1' row.
        setup_matrix = self.get_setup_matrix(pillar).T
        setup_matrix.columns = setup_matrix.iloc[1]

        # Columns present on only one side get NaN in the concat and are dropped.
        combined = pd.concat([setup_matrix.drop(['Unnamed: 0', 'Unnamed: 1'], axis=0), cdvi_data], axis=0, sort=False) \
            .dropna(axis=1)

        if multipliers is not None:
            data_rows = combined.index[self._SETUP_ROWS:]
            for column in combined.columns:
                multipliers_column = pillar + '_' + column
                if multipliers_column in multipliers:
                    combined.loc[data_rows, column] *= multipliers[multipliers_column]

        return combined

    def get_setup_matrix(self, pillar: str) -> pd.DataFrame:
        """Read ``Setup_Matrix/<city_abbr>_CDVI_<pillar>.csv`` from the city directory."""
        return pd.read_csv(self.city_dir + 'Setup_Matrix/' + self.city_abbr + '_CDVI_' + pillar + '.csv')

    def apply_transformation(self, df: pd.DataFrame, multipliers: pd.DataFrame) -> pd.DataFrame:
        """Normalize and transform the data rows of a combined frame in place.

        For each column:

        * ``Norm_Base < 2``: left untouched.
        * ``Norm_Base > 2``: divided by a denominator row of the combined
          ``'Desc'`` pillar. A zero denominator is replaced by 1, so that
          value is left undivided.
        * ``Transform`` 2-6: passed to linear / z_score / logarithmic /
          sigmoid / quadratic, with ``inverse`` set when ``Sign == 3``.

        :param df: frame as returned by :meth:`combine_data`
        :param multipliers: forwarded to ``combine_data('Desc', ...)``
        :returns: ``df`` itself, modified in place
        """
        denominators = self.combine_data('Desc', multipliers).T
        # Number of 'Desc' columns with Norm_Base == 0; assumes these come before
        # the denominator rows in the transposed frame -- confirm against the data.
        offset = denominators.query('Norm_Base == 0').count()['Norm_Base']
        data_rows = df.index[self._SETUP_ROWS:]

        for column in list(df.columns):
            norm_base = df.at['Norm_Base', column]

            if norm_base < 2:
                continue

            inverse = df.at['Sign', column] == 3
            transform = df.at['Transform', column]
            values = df.loc[data_rows, column]

            # normalize with denominator: Norm_Base codes >= 3 select denominator (code - 3)
            if norm_base > 2:
                divisor = denominators.iloc[int(norm_base) - 3 + offset, self._SETUP_ROWS:]
                values = values / divisor.mask(divisor == 0, 1)

            # transformations (names are resolved only when the code matches)
            if transform == 2:
                values = linear(values, inverse)
            elif transform == 3:
                values = z_score(values, inverse)
            elif transform == 4:
                values = logarithmic(values, inverse)
            elif transform == 5:
                values = sigmoid(values, inverse)
            elif transform == 6:
                values = quadratic(values, inverse)

            # Write back through the frame itself. Chained assignment such as
            # ``df[column].iloc[3:] = ...`` is silently lost under pandas Copy-on-Write.
            df.loc[data_rows, column] = values

        return df
