#!/usr/bin/env python3

# Software Name : uprofile
# SPDX-FileCopyrightText: Copyright (c) 2022 Orange
# SPDX-License-Identifier: BSD-3-Clause
#
# This software is distributed under the BSD License;
# see the LICENSE file for more details.
#
# Author: Cédric CHEDALEUX <cedric.chedaleux@orange.com> et al.


import argparse

import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Supported metric identifiers (first field of every event) mapped to the
# title displayed above the matching subplot. Insertion order is the default
# top-to-bottom order of the subplots.
METRICS = dict(
    time_exec='Execution task',
    cpu='CPU load',
    sys_mem='System memory (in MB)',
    proc_mem='Process Memory (in MB)',
    gpu='GPU load',
    gpu_mem='GPU memory (in MB)',
)


def read(csv_file):
    """
    Load the ';'-separated metrics file into a DataFrame
    Every event is made of its type, its timestamp and up to 3 extra parameters;
    the file has no header row so the column names are provided here
    :param csv_file: path or file-like object accepted by pandas.read_csv
    :return: Dataframe with columns 'metric', 'timestamp', 'extra_1', 'extra_2', 'extra_3'
    """
    column_names = ['metric', 'timestamp']
    column_names += ['extra_{}'.format(position) for position in (1, 2, 3)]
    return pd.read_csv(csv_file, sep=';', names=column_names)


def filter_dataframe(df, metric):
    """
    Keep only the events of the given metric type
    :param df: dataframe holding a 'metric' column
    :param metric: metric identifier to select (compared with ==)
    :return: Dataframe restricted to the matching rows (original index preserved)
    """
    is_requested_metric = df['metric'] == metric
    return df[is_requested_metric]


def gen_time_exec_df(df):
    """
    Turn 'time_exec' events into the task table expected by a gantt chart
    Event format is 'time_exec:<end_timestamp>:<start_timestamp>:<task_name>', hence
    'timestamp' is the end of the task, 'extra_1' its start and 'extra_2' its name
    :param df: dataframe holding only 'time_exec' events
    :return: dataframe with 'Task', 'Start', 'Finish' and 'Description' columns,
             or None when there is no event
    """
    if df.empty:
        return None

    def describe(task):
        # Must run while Start/Finish are still raw millisecond values
        elapsed_ms = int(task['Finish']) - int(task['Start'])
        return "Task: {} (duration = {} ms)".format(task['Task'], elapsed_ms)

    # Work on a copy so that the caller's dataframe is left untouched
    gantt_df = df[['extra_2', 'extra_1', 'timestamp']].copy()
    gantt_df.columns = ['Task', 'Start', 'Finish']
    gantt_df['Description'] = gantt_df.apply(describe, axis=1)
    for bound in ('Start', 'Finish'):
        gantt_df[bound] = pd.to_datetime(gantt_df[bound], unit='ms')
    return gantt_df


def create_gantt_graph(df):
    """
    Create a gantt graph to represent the tasks, one color per task
    (see https://chart-studio.plotly.com/~empet/15242/gantt-chart-in-a-subplot-httpscommun/ and https://plotly.com/python/gantt/)
    :param df: dataframe with 'Task', 'Start' and 'Finish' columns (as built by gen_time_exec_df)
    :return: graph
    """
    size = len(df['Start'])
    if size > 1:
        # Spread the tasks evenly over the whole [0, 1] range of the colorscale
        sample_points = [n / (size - 1) for n in range(size)]
    else:
        # With a single task 'n / (size - 1)' would divide by zero:
        # pick the first color of the scale instead (empty list when there is no task)
        sample_points = [0.0] * size
    colors = px.colors.sample_colorscale("turbo", sample_points)
    return ff.create_gantt(df, colors=colors, show_colorbar=True, showgrid_x=True, showgrid_y=True,
                           show_hover_fill=True)


def create_cpu_graphs(df):
    """
    Generate one scatter trace per CPU found in the 'cpu' events
    Event format is 'cpu:<timestamp>:<cpu_number>:<percentage_usage>', hence
    'extra_1' identifies the CPU and 'extra_2' is the plotted value
    :param df: dataframe holding only 'cpu' events
    :return: generator of scatter traces (yields nothing when there is no event)
    """
    if df.empty:
        return

    cpu_ids = df['extra_1']
    # pd.unique keeps the order of first appearance, which is also the trace order
    for cpu_id in pd.unique(cpu_ids):
        samples = df[cpu_ids == cpu_id]
        timestamps = pd.to_datetime(samples['timestamp'], unit='ms')
        usage = pd.to_numeric(samples['extra_2'])
        yield go.Scatter(x=timestamps, y=usage, name="CPU {}".format(cpu_id), showlegend=True)


def create_sys_mem_graphs(df):
    """
    Generate the three scatter traces of the 'sys_mem' events
    'extra_1', 'extra_2' and 'extra_3' are plotted as "Total", "Available" and "Free";
    each value is divided by 1024 before being plotted
    :param df: dataframe holding only 'sys_mem' events
    :return: generator of scatter traces (yields nothing when there is no event)
    """
    if df.empty:
        return

    timestamps = pd.to_datetime(df['timestamp'], unit='ms')
    columns = ("extra_1", "extra_2", "extra_3")
    labels = ("Total", "Available", "Free")
    for column, label in zip(columns, labels):
        raw_values = pd.to_numeric(df[column], downcast="integer")
        yield go.Scatter(x=timestamps, y=(raw_values / 1024), name=label, showlegend=True)


def create_proc_mem_graphs(df):
    """
    Generate the two scatter traces of the 'proc_mem' events
    'extra_1' and 'extra_2' are plotted as "RSS" and "Shared";
    each value is divided by 1024 before being plotted
    :param df: dataframe holding only 'proc_mem' events
    :return: generator of scatter traces (yields nothing when there is no event)
    """
    if df.empty:
        return

    timestamps = pd.to_datetime(df['timestamp'], unit='ms')
    for column, label in (("extra_1", "RSS"), ("extra_2", "Shared")):
        raw_values = pd.to_numeric(df[column], downcast="integer")
        yield go.Scatter(x=timestamps, y=(raw_values / 1024), name=label, showlegend=True)


def create_gpu_mem_graphs(df):
    """
    Generate the two scatter traces of the 'gpu_mem' events
    'extra_1' is plotted as "Used" and 'extra_2' as "Total";
    each value is divided by 1024 before being plotted
    NOTE(review): the previous comment described the event as
    'gpu_mem:<timestamp>:<total>:<used>', which is the opposite of the labels
    applied here - confirm the field order against the uprofile writer
    :param df: dataframe holding only 'gpu_mem' events
    :return: generator of scatter traces (yields nothing when there is no event)
    """
    if df.empty:
        return

    timestamps = pd.to_datetime(df['timestamp'], unit='ms')
    for column, label in (("extra_1", "Used"), ("extra_2", "Total")):
        raw_values = pd.to_numeric(df[column], downcast="integer")
        yield go.Scatter(x=timestamps, y=(raw_values / 1024), name=label, showlegend=True)


def create_gpu_usage_graphs(df):
    """
    Generate the single scatter trace of the 'gpu' events
    Event format is 'gpu:<timestamp>:<percentage_usage>', hence 'extra_1' is the plotted value
    :param df: dataframe holding only 'gpu' events
    :return: generator of scatter traces (yields nothing when there is no event)
    """
    if df.empty:
        return

    timestamps = pd.to_datetime(df['timestamp'], unit='ms')
    usage = pd.to_numeric(df['extra_1'])
    yield go.Scatter(x=timestamps, y=usage, name="GPU usage", showlegend=True)


def build_graphs(input_file, metrics):
    """
    Build a single figure made of one subplot per requested metric
    Subplots (https://plotly.com/python/subplots/) are stacked vertically, in the order
    of 'metrics', and share their x axis:
    - 'time_exec' is drawn as a gantt graph of the task execution durations
    - every other metric is drawn as scatter traces sharing the same subplot
    :param input_file: path of the metrics file
    :param metrics: sized iterable of metric identifiers (keys of METRICS)
    :return: figure holding all the subplots
    """
    # Metrics drawn as scatter traces: (trace generator, whether the y axis must include zero)
    scatter_metrics = {
        'cpu': (create_cpu_graphs, False),
        'sys_mem': (create_sys_mem_graphs, True),
        'proc_mem': (create_proc_mem_graphs, True),
        'gpu': (create_gpu_usage_graphs, True),
        'gpu_mem': (create_gpu_mem_graphs, True),
    }

    figs = make_subplots(rows=len(metrics),
                         cols=1,
                         shared_xaxes=True,
                         vertical_spacing=0.025,
                         subplot_titles=[METRICS[name] for name in metrics])

    with open(input_file) as file:
        global_df = read(file)

        # Subplot rows are 1-based and follow the order of the requested metrics
        for row, name in enumerate(metrics, start=1):
            if name == 'time_exec':
                gantt_df = gen_time_exec_df(filter_dataframe(global_df, name))
                if gantt_df is not None:
                    for trace in create_gantt_graph(gantt_df).data:
                        figs.add_trace(trace, row=row, col=1)
                continue

            if name not in scatter_metrics:
                # Nothing to draw: the row is left empty
                continue

            make_traces, include_zero = scatter_metrics[name]
            for trace in make_traces(filter_dataframe(global_df, name)):
                figs.add_trace(trace, row=row, col=1)
            if include_zero:
                figs.update_yaxes(row=row, col=1, rangemode="tozero")

    figs.update_layout(height=1200, xaxis_tickformat='%M:%S', showlegend=False)
    return figs


def main():
    """
    Show the requested metric graphs on a single view, and optionally save them
    The tool reads the metrics file generated by the uprofile library
    Command line:
    - INPUT_FILE: file that contains the profiling data (required)
    - --output/-o: HTML file to save the graph to (optional)
    - --metric: metric to display, can be repeated (all the metrics when omitted)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('INPUT_FILE', type=str, help='Input file that contains profiling data')
    parser.add_argument('--output', '-o', type=str,
                        help='Save the graph to the given HTML file')
    parser.add_argument('--metric', type=str, dest='metrics', choices=METRICS.keys(), action='append', default=[],
                        help='Select the metric to display')
    # INPUT_FILE is a required positional argument: argparse exits by itself with a
    # usage error when it is missing, so no extra 'is None' check is needed
    args = parser.parse_args()

    # Without any --metric option, display every known metric (in METRICS order);
    # a plain list is handed over rather than a live view of the dict
    metrics = args.metrics if args.metrics else list(METRICS)

    graphs = build_graphs(args.INPUT_FILE, metrics)
    graphs.show()

    if args.output is not None:
        print("Saving graph to '{}'".format(args.output))
        graphs.write_html(args.output)


if __name__ == '__main__':
    main()