#!/usr/bin/env python

import re
from optparse import OptionParser
import matplotlib.pyplot as plt

# Command-line interface.  The module-level names ``p`` (the parser),
# ``opts`` (the parsed options) and ``fnames`` (the timing files, read as
# one file per rank) are used by the rest of the script.
p = OptionParser(usage='%prog [OPTION] FILE...',
                 description='plot timings from gpaw parallel timer.  '
                 'The timer dumps a lot of files called "timings.<...>.txt".  '
                 'This programme plots the contents of those files.  '
                 'Typically one would run "%prog timings.*.txt" to plot '
                 'timings on all cores.')
p.add_option('--interval', metavar='TIME1:TIME2',
             help='plot only timings within TIME1 and TIME2 '
             'after start of calculation.')
p.add_option('--align', metavar='NAME', default='SCF-cycle',
             help='horizontally align timings of different ranks at first '
             'call of NAME [default=%default]')
p.add_option('--nolegend', action='store_true',
             help='do not plot the separate legend figure')
p.add_option('--nointeractive', action='store_true',
             help='disable interactive legend')


opts, fnames = p.parse_args()

# Without any input file there is nothing to plot; stop with a usage
# message instead of failing later on min() of an empty sequence.
if not fnames:
    p.error('no timing files given')


# Time window to plot, relative to the start of the calculation.
# plotendtime=None means "up to the last recorded event".
if opts.interval:
    try:
        plotstarttime, plotendtime = map(float, opts.interval.split(':'))
    except ValueError:
        # Raised both for non-numeric fields and for a wrong number of
        # ':'-separated fields.
        p.error('--interval must have the form TIME1:TIME2, got {!r}'
                .format(opts.interval))
else:
    plotstarttime = 0
    plotendtime = None


# We will read/store absolute timings T1 and T2, which are probably 1e9.
# For the plot we want timings relative to some starting point.
class Call:
    """One timed invocation of a named function on one rank.

    ``T1`` and ``T2`` are the absolute start and stop times as read from
    the timing file; ``T2`` stays None until the call has been stopped.
    ``level`` is the nesting level and ``rankno`` the index of the rank
    the call belongs to.  The lowercase ``t1``/``t2`` properties give the
    same times relative to the module-level ``alignments`` entry of the
    rank.
    """

    def __init__(self, name, T1, level, rankno):
        self.name, self.level, self.rankno = name, level, rankno
        self.T1, self.T2 = T1, None

    def _relative(self, T):
        # ``alignments`` is looked up at call time, so later changes to
        # that list are reflected in t1/t2.
        return T - alignments[self.rankno]

    @property
    def t1(self):
        """Start time relative to this rank's alignment."""
        return self._relative(self.T1)

    @property
    def t2(self):
        """Stop time relative to this rank's alignment."""
        return self._relative(self.T2)


# Main figure with a single axes that holds the timeline bars.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
fig.subplots_adjust(left=0.08, right=0.95, bottom=0.07, top=0.95)
# (artist, function name) pairs; filled while plotting and searched by the
# click handler.
patch_name_p = list()


class Function:
    """Plot style and recorded calls of one named timer.

    ``num`` is the running index of the function (order of first
    appearance) and selects the colour/hatch combination: the colour
    cycles through ``thecolors`` and the hatch advances by one each time
    the colours wrap around.  ``calls`` collects the Call objects of this
    function; ``bar`` is initialised to None.
    """

    thecolors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow',
                 'darkred', 'indigo', 'springgreen', 'purple']
    thehatches = ['', '//', 'O', '*', 'o', r'\\', '.', '|']

    def __init__(self, name, num):
        self.name = name
        self.num = num
        ncolors = len(self.thecolors)
        self.color = self.thecolors[num % ncolors]
        # Wrap the hatch index as well: without the modulo, any num of
        # len(thecolors) * len(thehatches) or more raises IndexError.
        self.hatch = self.thehatches[(num // ncolors) % len(self.thehatches)]
        self.bar = None
        self.calls = []


# One timer line, e.g. 'T42 >> 12.34 name (5.51s) started' (fields are
# separated by single spaces).  The groups capture the absolute time, the
# function name and whether the call started or stopped.
pattern = r'T\d+ \S+ (\S+) (.+?) \(.*?\) (started|stopped)'

# State filled in while the timing files are parsed.
functions, functionslist = {}, []  # by name / in order of first appearance
maxlevel = 0  # deepest nesting level seen on any rank
# One entry per rank each:
alignments, firstcallsbyrank, lastcallsbyrank = [], [], []


# Parse the timing files, one per rank.  For every rank this records the
# alignment time (first call of opts.align), the first call of the file and
# a call that ends at the last timestamp of the file; every finished call is
# appended to the Function object of its name.  Malformed input is reported
# through p.error() rather than assert, so the checks survive "python -O".
for rankno, fname in enumerate(fnames):
    alignment = None  # absolute time of this rank's first opts.align call
    ongoing = []  # stack of calls that have started but not yet stopped
    call = None  # call touched by the most recent valid line
    T = None  # timestamp of the most recent valid line
    failing_line = None
    with open(fname) as fd:
        for lineno, line in enumerate(fd):
            m = re.match(pattern, line)
            if m is None:
                failing_line = line
                break  # Bad syntax

            T, name, action = m.group(1, 2, 3)
            T = float(T)

            if action == 'started':
                level = len(ongoing)
                maxlevel = max(level, maxlevel)
                call = Call(name, T1=T, level=level, rankno=rankno)
                if lineno == 0:
                    # Internal invariant: exactly one first call has been
                    # stored for each of the previous ranks.
                    assert len(firstcallsbyrank) == rankno
                    firstcallsbyrank.append(call)
                if alignment is None and name == opts.align:
                    alignment = call.T1
                if name not in functions:
                    function = Function(name, len(functions))
                    functions[name] = function
                    functionslist.append(function)
                ongoing.append(call)
            else:
                # The pattern only admits 'started' or 'stopped', so this
                # is a stop.  It must close the innermost running call.
                if not ongoing or ongoing[-1].name != name:
                    p.error('{}:{}: "{}" stopped but is not the innermost '
                            'running call'.format(fname, lineno + 1, name))
                call = ongoing.pop()
                call.T2 = T
                functions[name].calls.append(call)

        # If we failed there may still be lines left.  If there are no
        # lines left, only last line was mangled (file was incomplete) and
        # that is okay.  Otherwise it's an error:
        if failing_line is not None and next(fd, None) is not None:
            p.error('Bad syntax: {}'.format(failing_line))

    if alignment is None:
        p.error('Cannot align to "{}"'.format(opts.align))
    alignments.append(alignment)

    # End any remaining ongoing calls at the last timestamp that was read:
    for unfinished in ongoing:
        unfinished.T2 = T
        functions[unfinished.name].calls.append(unfinished)

    # Store a call whose T2 is the last timestamp of this rank: the
    # innermost unfinished call if there is one, else the call stopped by
    # the last valid line.
    lastcallsbyrank.append(ongoing[-1] if ongoing else call)

# Move the time origin to the earliest first call over all ranks, so that
# the relative timings start at 0.
tmp_tmin = min(first.t1 for first in firstcallsbyrank)
alignments = [offset + tmp_tmin for offset in alignments]
# Must be evaluated after the shift above, since t2 reads alignments.
tmax = max(last.t2 for last in lastcallsbyrank)


# Without an explicit --interval end, plot up to the last recorded time.
plotendtime = tmax if plotendtime is None else plotendtime

# Draw one group of bars per function, in alphabetical order of the names.
# Each call becomes a unit-height bar from t1 to t2; the rows of one rank
# are stacked by nesting level and the ranks are stacked on top of each
# other in blocks of (maxlevel + 1) rows.
for name, function in sorted(functions.items()):
    # Keep only the calls that overlap the viewed interval.
    plotcalls = [call for call in function.calls
                 if not (call.t2 < plotstarttime or call.t1 > plotendtime)]

    lefts = [call.t1 for call in plotcalls]
    widths = [call.T2 - call.T1 for call in plotcalls]
    bottoms = [call.level + call.rankno * (maxlevel + 1)
               for call in plotcalls]
    bar = ax.bar(lefts,
                 height=[1.0] * len(plotcalls),
                 width=widths,
                 bottom=bottoms,
                 color=function.color,
                 hatch=function.hatch,
                 edgecolor='black',
                 align='edge',
                 label=function.name)
    # Remember which artist belongs to which function for the click handler.
    patch_name_p.extend((child, name) for child in bar.get_children())
    ax.axis(xmin=plotstarttime, xmax=plotendtime)


if not opts.nolegend:
    # The legend lives in a figure of its own.
    namefig = plt.figure()
    nameax = namefig.add_subplot(1, 1, 1)

    # Zero-sized bars serve as legend handles carrying each function's
    # colour and hatch; names are cut to 20 characters.
    for function in functionslist:
        nameax.bar([0], [0], [0], [0],
                   label=function.name[:20],
                   hatch=function.hatch,
                   color=function.color)

    # One more legend column for every 32 functions.
    ncolumns = len(functionslist) // 32 + 1
    legend_options = dict(handlelength=2.5,
                          labelspacing=0.0,
                          fontsize='large',
                          ncol=ncolumns,
                          mode='expand',
                          frameon=True,
                          loc='best')
    nameax.legend(**legend_options)


if not opts.nointeractive:
    default_text = 'Click on a patch to get the name'
    # Named ``pars`` rather than ``p`` so that the module-level option
    # parser ``p`` is not rebound to the subplot parameters.
    pars = fig.subplotpars

    # Create interactive axes in the strip between the top of the main
    # plot and the top of the figure.
    box = (pars.left, pars.top, pars.right - pars.left, 1 - pars.top)
    fc = fig.get_facecolor()
    try:
        iax = fig.add_axes(box, facecolor=fc)  # matplotlib 2.x
    except AttributeError:
        iax = fig.add_axes(box, axisbg=fc)  # matplotlib 1.x

    # Hide everything in the new axes except its background patch.
    ibg = iax.patch
    for s in iax.get_children():
        if s is not ibg:
            s.set_visible(False)
    itext = iax.text(0.5, 0.5, default_text, va='center', ha='center',
                     transform=iax.transAxes)

    def print_name_event(event):
        """Show the name of the clicked bar in the text strip.

        The text is set to the name stored with the first artist in
        ``patch_name_p`` that contains ``event``, or to the default hint
        if no artist contains it.
        """
        text = default_text
        # Do action based on the chosen tool
        # tb = fig.canvas.manager.toolbar
        # if tb.mode != '':
        #     return
        for patch, name in patch_name_p:
            if patch.contains(event)[0]:
                text = name
                break
        itext.set_text(text)
        # The next lines can be replaced with
        # fig.canvas.draw()
        # but it'll be very slow
        iax.draw_artist(ibg)
        iax.draw_artist(itext)

        canvas = fig.canvas
        if hasattr(canvas, 'update'):
            canvas.update()  # matplotlib 0.x
        else:
            canvas.blit()  # matplotlib 2.x
        canvas.flush_events()

    fig.canvas.mpl_connect('button_press_event', print_name_event)


# Hand control to matplotlib to display the figure(s).
plt.show()
