import numpy as np
import matplotlib.pyplot as plt

# Model parameters (masses in M_sun, matching the x-axis label).
# Broken power-law slopes: ALPHA_1 applies for m < M_GAP_HI, ALPHA_2 otherwise.
ALPHA_1, ALPHA_2 = -2.16, -1.46
# Depth of the notch between M_GAP_LO and M_GAP_HI (notch factor -> 1 - A inside).
A = 0.97
M_GAP_LO, M_GAP_HI = 2.72, 6.13
# Sharpness exponents of the notch edges (larger -> steeper transition).
ETA_GAP_LO, ETA_GAP_HI = 50, 50
# Low-mass cutoff location and sharpness.
M_MIN, ETA_MIN = 1.16, 50
# High-mass cutoff location and sharpness.
M_MAX, ETA_MAX = 54.38, 4.91
# NOTE(review): BETA is not referenced in this section; presumably used
# elsewhere (e.g. a pairing exponent) — confirm or remove.
BETA = 1.89

def lopass(m, m_0, eta):
    """Smooth low-pass (step-down) filter in ``m``.

    Computes ``1 / (1 + (m / m_0) ** eta)``. The value is exactly 1/2 at
    ``m == m_0``. For positive ``eta`` it tends to 1 well below ``m_0`` and to
    0 well above it; a larger ``eta`` makes the transition sharper. Operates
    elementwise when ``m`` is a NumPy array.
    """
    scaled = m / m_0
    denominator = 1 + scaled**eta
    return 1 / denominator

def hipass(m, m_0, eta):
    """Smooth high-pass (step-up) filter: the complement of ``lopass``.

    Returns ``1 - lopass(m, m_0, eta)``. The value is 1/2 at ``m == m_0``; for
    positive ``eta`` it tends to 0 well below ``m_0`` and to 1 well above it.
    """
    passed_low = lopass(m, m_0, eta)
    return 1 - passed_low

def bandpass(m, m_lo, m_hi, eta_lo, eta_hi, A):
    """Smooth notch (band-stop) filter, despite the name.

    ``hipass(m, m_lo, eta_lo) * lopass(m, m_hi, eta_hi)`` is close to 1 for
    ``m_lo < m < m_hi`` and close to 0 outside (given ``m_lo < m_hi`` and
    large positive etas). The returned factor ``1 - A * rise * fall`` is
    therefore about ``1 - A`` inside the band and about 1 outside it.

    Parameters
    ----------
    m : mass value(s)
    m_lo, m_hi : lower and upper band edges
    eta_lo, eta_hi : sharpness of the lower and upper edges
    A : depth of the notch
    """
    rise = hipass(m, m_lo, eta_lo)
    fall = lopass(m, m_hi, eta_hi)
    # Keep the original evaluation order ((A * rise) * fall) for identical floats.
    return 1 - A * rise * fall

def mass_distribution_1d(m):
    """Evaluate the one-dimensional mass model ``p(m)`` (no normalisation applied).

    The result is the product of:
      * a notch between ``M_GAP_LO`` and ``M_GAP_HI`` with depth ``A``;
      * a smooth low-mass turn-on at ``M_MIN`` and a smooth high-mass
        turn-off at ``M_MAX``;
      * a broken power law pivoting at ``M_GAP_HI`` (equal to 1 there), with
        slope ``ALPHA_1`` below the pivot and ``ALPHA_2`` at or above it.

    Works elementwise on NumPy arrays.
    """
    dip = bandpass(m, M_GAP_LO, M_GAP_HI, ETA_GAP_LO, ETA_GAP_HI, A)
    low_cutoff = hipass(m, M_MIN, ETA_MIN)
    high_cutoff = lopass(m, M_MAX, ETA_MAX)
    slope = np.where(m < M_GAP_HI, ALPHA_1, ALPHA_2)
    power_law = (m / M_GAP_HI) ** slope
    # Same left-to-right product order as before, so results are bit-identical.
    return dip * low_cutoff * high_cutoff * power_law

# Evaluate the model on a log-spaced grid and plot m * p(m) on log-log axes.
M_PLOT_LO, M_PLOT_HI = 1, 100
m = np.geomspace(M_PLOT_LO, M_PLOT_HI, 2000)
fig, ax = plt.subplots()
ax.set_xscale("log")
ax.set_yscale("log")

# Matplotlib's named color "navy" is #000080. Alternative colors noted by the
# author: violet '#9400D3' and a darker navy '#001F75'.
ax.plot(
    m,
    m * mass_distribution_1d(m),
    color="navy",
    linewidth=2,
    label="Mass distribution",
)

ax.set_xlim(M_PLOT_LO, M_PLOT_HI)
# A log-scaled axis cannot start at 0: matplotlib ignores a non-positive bottom
# limit (with a warning) and keeps the autoscaled one, so cap only the top.
ax.set_ylim(top=100)
ax.set_xlabel(r"mass, $m$ [$M_\odot$]")
ax.set_ylabel(r"$m\,p(m|\lambda)$")

# Shaded band at 2.4 +/- 0.5 with dashed edges, labelled near the bottom.
# NOTE(review): this band is independent of M_GAP_LO / M_GAP_HI; presumably an
# externally quoted mass-gap estimate — confirm its source.
GAP_CENTER, GAP_HALF_WIDTH = 2.4, 0.5
gap_edges = (GAP_CENTER - GAP_HALF_WIDTH, GAP_CENTER + GAP_HALF_WIDTH)
# axvspan spans the full axes height (axes coordinates), so it stays correct
# even if the y-limits change later, unlike filling between captured limits.
ax.axvspan(*gap_edges, color="blue", alpha=0.12, zorder=1)
for edge in gap_edges:
    ax.axvline(x=edge, color="blue", linestyle="--", alpha=0.7)
# Blended transform: x in data coordinates, y in axes coordinates. Subtracting
# a constant from a data-space y is meaningless on a log axis and could place
# the label outside the axes (or at a non-positive, unplottable y).
ax.text(
    GAP_CENTER,
    0.01,
    rf"${GAP_CENTER}^{{+{GAP_HALF_WIDTH}}}_{{-{GAP_HALF_WIDTH}}}$",
    transform=ax.get_xaxis_transform(),
    ha="center",
    va="bottom",
    fontsize=11,
    fontweight="bold",
    color="blue",
)

# Secondary top axis marking the model's characteristic masses.
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xscale(ax.get_xscale())
ax2.set_xticks([M_MIN, M_GAP_LO, M_GAP_HI, M_MAX])
ax2.set_xticklabels(
    [
        r"$M_\mathrm{min}$",
        r"$M^\mathrm{gap}_\mathrm{low}$",
        r"$M^\mathrm{gap}_\mathrm{high}$",
        r"$M_\mathrm{max}$",
    ]
)
ax2.grid(axis="x")
fig.tight_layout()
# Figure.show() does not run a GUI event loop, so in a plain script the window
# may flash and close immediately; plt.show() blocks until it is closed.
plt.show()