# Copyright 2022 Q-CTRL. All rights reserved.
#
# Licensed under the Q-CTRL Terms of service (the "License"). Unauthorized
# copying or use of this file, via any medium, is strictly prohibited.
# Proprietary and confidential. You may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#    https://q-ctrl.com/terms
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS. See the
# License for the specific language.
"""
Functions for plotting filter functions.
"""

from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
from qctrlcommons.preconditions import check_argument

from .style import (
    FIG_HEIGHT,
    FIG_WIDTH,
    qctrl_style,
)
from .utils import figure_as_kwarg_only


def _extract_filter_function_arrays(samples):
    """
    Convert the samples of a single filter function into NumPy arrays.

    Parameters
    ----------
    samples : list[dict]
        The samples of one filter function, each a dictionary with 'frequency',
        'inverse_power', and optionally 'inverse_power_uncertainty' or
        'inverse_power_precision' keys.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        The frequencies, inverse powers, and inverse power uncertainties of the
        samples, in the order the samples were given. A sample with neither
        uncertainty key gets an uncertainty of zero; if both keys are present,
        'inverse_power_uncertainty' takes precedence.
    """
    frequencies = np.array([sample["frequency"] for sample in samples], dtype=float)
    inverse_powers = np.array(
        [sample["inverse_power"] for sample in samples], dtype=float
    )
    uncertainties = np.array(
        [
            sample["inverse_power_uncertainty"]
            if "inverse_power_uncertainty" in sample
            else sample.get("inverse_power_precision", 0.0)
            for sample in samples
        ],
        dtype=float,
    )
    return frequencies, inverse_powers, uncertainties


@qctrl_style()
@figure_as_kwarg_only
def plot_filter_functions(
    filter_functions: Dict,
    *,
    figure: plt.Figure,
):
    """
    Create a plot of the specified filter functions.

    Parameters
    ----------
    filter_functions : dict
        The dictionary of filter functions to plot. The keys should be the names of the filter
        functions, and the values the list of samples representing that filter function. Each such
        sample must be a dictionary with 'frequency', 'inverse_power', and
        'inverse_power_uncertainty' (optional) keys, giving the frequency (in Hertz) at which the
        sample was taken, the inverse power (in seconds) of the filter function at the sample, and
        the optional uncertainty of that inverse power (in seconds).

        The key 'inverse_power_precision' can be used instead of 'inverse_power_uncertainty'. If
        both are provided then the value corresponding to 'inverse_power_uncertainty' is used.

        If the uncertainty of an inverse power is provided, it must be non-negative. Each filter
        function must contain at least one sample.

        Both axes of the plot use a logarithmic scale, so samples with a non-positive frequency
        or inverse power cannot be displayed.

        For example, the following would be a valid ``filter_functions`` input::

            {
             'Primitive': [
                {'frequency': 0.0, 'inverse_power': 15.},
                {'frequency': 1.0, 'inverse_power': 12.},
                {'frequency': 2.0, 'inverse_power': 3., 'inverse_power_uncertainty': 0.2},
             ],
             'CORPSE': [
                {'frequency': 0.0, 'inverse_power': 10.},
                {'frequency': 0.5, 'inverse_power': 8.5},
                {'frequency': 1.0, 'inverse_power': 5., 'inverse_power_uncertainty': 0.1},
                {'frequency': 1.5, 'inverse_power': 2.5},
             ],
            }
    figure : matplotlib.figure.Figure, optional
        A matplotlib Figure in which to place the plot.
        If passed, its dimensions and axes will be overridden.
    """

    check_argument(
        filter_functions,
        "At least one filter function must be provided.",
        {"filter_functions": filter_functions},
    )

    figure.set_figwidth(FIG_WIDTH)
    figure.set_figheight(FIG_HEIGHT)

    axes = figure.subplots(nrows=1, ncols=1)

    for name, samples in filter_functions.items():
        # Materialize once so the samples can be both length-checked and iterated.
        samples = list(samples)

        # Without this check an empty sample list fails later with an opaque
        # tuple-unpacking error instead of a descriptive argument error.
        check_argument(
            len(samples) > 0,
            "Each filter function must contain at least one sample.",
            {"filter_functions": filter_functions},
            extras={"name": name},
        )

        (
            frequencies,
            inverse_powers,
            inverse_power_uncertainties,
        ) = _extract_filter_function_arrays(samples)

        check_argument(
            np.all(inverse_power_uncertainties >= 0.0),
            "Uncertainties must all be non-negative in filter functions",
            {"filter_functions": filter_functions},
            extras={"samples": samples},
        )

        inverse_powers_upper = inverse_powers + inverse_power_uncertainties
        inverse_powers_lower = inverse_powers - inverse_power_uncertainties

        lines = axes.plot(frequencies, inverse_powers, label=name)
        # Draw the uncertainty band as a hatched region in the same color as the line.
        axes.fill_between(
            frequencies,
            inverse_powers_lower,
            inverse_powers_upper,
            alpha=0.35,
            hatch="||",
            facecolor="none",
            edgecolor=lines[0].get_color(),
            linewidth=0,
        )

    axes.legend()

    axes.set_xscale("log")
    axes.set_yscale("log")

    axes.autoscale(axis="x", tight=True)

    axes.set_xlabel("Frequency (Hz)")
    axes.set_ylabel("Inverse power (s)")
