from astropy.visualization import quantity_support
from astropy import units as u
from matplotlib import pyplot as plt
from matplotlib import rcParams
from matplotlib.transforms import blended_transform_factory
import numpy as np

# Allow astropy Quantity objects to be passed straight to matplotlib calls.
quantity_support()

# A 1 x 256 image holding a square-root ramp from 0 to 1; each panel draws it
# just after its alert line to fade the panel out.
gradient = np.linspace(0, 1, 256).reshape(1, -1) ** 0.5

# Processing steps for each alert stage, top panel first. Every step is
# [label, start time, duration], with times relative to the merger; within a
# stage the first step is drawn as the lowest bar.
plot_data = [
    # Early Warning alert
    [
        ['Sky Localization', -20 * u.second, 5 * u.second],
        ['Classification', -20 * u.second, 5 * u.second],
        ['Detection', -50 * u.second, 30 * u.second],
    ],
    # 1st Preliminary alert
    [
        ['Sky Localization', 30 * u.second, 5 * u.second],
        ['Classification', 30 * u.second, 5 * u.second],
        ['Automated Vetting', 30 * u.second, 5 * u.second],
        ['Detection', 16 * u.second, 14 * u.second],
    ],
    # 2nd Preliminary alert
    [
        ['Re-annotate', 3.6 * u.minute, 1 * u.minute],
        ['Cluster additional events', 0.8 * u.minute, 2.8 * u.minute],
    ],
    # Initial alert or retraction
    [
        ['Classification', 4 * u.hour, 10 * u.minute],
        ['Human Vetting', 5 * u.minute, 4 * u.hour],
        ['Parameter Estimation', 75 * u.second, 4 * u.hour],
    ],
    # Update alert
    [
        ['Classification', 6 * u.day, 6 * u.hour],
        ['Parameter Estimation', 4 * u.hour, 6 * u.day],
    ],
]

# Text written next to each panel's alert line, in the same order as plot_data.
alert_labels = [
    'Early Warning\nAlert Sent',
    '1st Preliminary\nAlert Sent',
    '2nd Preliminary\nAlert Sent',
    'Initial Alert or\nRetraction Sent',
    'Update\nAlert Sent',
]

# Thickness of each bar in y-axis units (consecutive rows are 1 unit apart).
bar_height = 0.8

# Time range covered by the plot.
xlim = [-10 * u.minute, 100 * u.day]

# One stacked panel per alert stage, all sharing the time axis. A panel's
# relative height equals the y range it will be given (number of bars plus
# 1 - bar_height of margin), so bars come out the same physical height in
# every panel.
panel_heights = [len(stage) + 1 - bar_height for stage in plot_data]
fig, axs = plt.subplots(
    len(plot_data),
    sharex=True,
    figsize=(8, 4),
    gridspec_kw={
        'height_ratios': panel_heights,
        'top': 0.9,
        'bottom': 0.1,
        'left': 0,
        'right': 1,
        'hspace': 0.05,
    },
)

# Draw one panel per alert stage.
# Calling the prop cycler returns an endless iterator over its style dicts.
# Every panel therefore gets a colour even if the active style defines fewer
# colours than there are panels. Iterating the cycler directly would make
# zip() silently drop the extra panels.
for ax, data, alert_label, props in zip(axs, plot_data, alert_labels,
                                         rcParams['axes.prop_cycle']()):
    # Strip the vertical axis; each bar is identified by its own text label.
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_yticks([])
    ax.xaxis.label.set_visible(False)
    ax.set_facecolor('0.95')

    # Each row of ``data`` is [label, start, duration]. Convert all times to
    # seconds so that every quantity in the panel uses one unit.
    labels, starts, durations = zip(*data)
    starts = u.Quantity(starts).to(u.second)
    durations = u.Quantity(durations).to(u.second)

    # t is when the last step of this stage finishes. Three positions follow it:
    # the alert line (tp), the alert text (tp2) and the end of the fade (tp10).
    # For positive t the offsets scale with t, presumably to suit the
    # logarithmic part of the symlog axis. Otherwise fixed offsets in seconds
    # are used.
    t = (starts + durations).max()
    if t.value > 0:
        tp = t * 1.1
        tp2 = t * 1.25
        tp10 = 10 * t
    else:
        tp = t + 3 * u.second
        tp2 = t + 8 * u.second
        tp10 = t + 40 * u.second
    ax.axvline(tp, color='black')
    # x is in data coordinates and y in axes coordinates, so the text is
    # vertically centred in the panel.
    ax.text(tp2, 0.5, alert_label,
            transform=blended_transform_factory(ax.transData, ax.transAxes),
            fontweight='bold', va='center', ha='left')

    # One horizontal bar per step (row i sits at y = i), labelled just left of
    # where the bar starts.
    ax.barh(np.arange(len(labels)), width=durations,
            left=starts, height=bar_height,
            facecolor=props['color'], edgecolor='black')
    for i, (start, label) in enumerate(zip(starts, labels)):
        ax.text(start, i, ' ' + label + ' ', ha='right', va='center')
    # Leave the same gap (1 - bar_height) below the first bar and above the
    # last bar as there is between adjacent bars.
    ax.set_ylim(0.5 * bar_height - 1, len(labels) - 0.5 * bar_height)
    # Fade from mid-grey to white between tp and tp10. The gradient values
    # 0..1 are mapped onto Greys_r with vmin=-1, vmax=1, so 0 is mid-grey.
    # extent takes plain floats, here in seconds like the rest of the panel.
    # The y extent (-10, 10) exceeds the y limits above, so the fade fills
    # the panel height.
    ax.imshow(gradient, extent=[tp.value, tp10.value, -10, 10],
              cmap='Greys_r', vmin=-1, vmax=1, aspect='auto')
    # Blank out everything from the end of the fade to the right edge.
    ax.axvspan(tp10, xlim[1], color='white')

# Axis-wide formatting. The panels share their x axis (sharex=True), so scale,
# limits and tick locations set on the bottom panel apply to all of them.
# Address the bottom panel explicitly rather than relying on whatever value
# the plotting loop's variable was left holding.
bottom_ax = axs[-1]

fig.suptitle('Time relative to gravitational-wave merger')
# symlog: linear within +/-40 (axis units, i.e. seconds as the panels are
# plotted) of the merger, logarithmic beyond that.
bottom_ax.set_xscale('symlog', linthresh=40)
bottom_ax.set_xlim(*xlim)
ticks = [-30 * u.second, 0 * u.second, 30 * u.second, 3 * u.minute,
         1 * u.hour, 1 * u.day, 1 * u.week]
bottom_ax.set_xticks(ticks)
# Sub-minute ticks use the abbreviated unit ("30 s"); longer ones spell the
# unit out ("3 minute", "1 hour").
bottom_ax.set_xticklabels([
    '{0.value:g} {0.unit.short_names[0]}'.format(tick)
    if abs(tick) < 1 * u.minute else
    '{0.value:g} {0.unit.long_names[0]}'.format(tick)
    for tick in ticks
])
bottom_ax.minorticks_off()
# NOTE(review): this label text is not displayed, because the plotting loop
# hides every panel's x-axis label with ax.xaxis.label.set_visible(False).
# Kept unchanged; confirm whether it is meant to be shown.
bottom_ax.set_xlabel('Time since GW signal')

# Draw the time axis as an arrow along the bottom panel's lower edge, matching
# that spine's line width and colour.
bottom_spine = bottom_ax.spines['bottom']
bottom_ax.arrow(0, 0, 1, 0,
                transform=bottom_ax.transAxes, clip_on=False,
                head_width=0.1, head_length=0.01,
                linewidth=bottom_spine.get_linewidth(),
                edgecolor=bottom_spine.get_edgecolor(),
                facecolor=bottom_spine.get_edgecolor(),
                length_includes_head=True)

# Hide x tick marks and tick labels on every panel except the bottom one.
# tick_params changes this axis's tick style persistently. Toggling
# individual Tick artists would be lost if matplotlib rebuilds the ticks.
for ax in axs[:-1]:
    ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)