import numpy as np
import matplotlib.pyplot as plt
# --- Plot styling: all text rendered through LaTeX, 12 pt serif ---
# NOTE: text.usetex requires a working LaTeX installation on the machine.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.size': 12,
})
# Detuning axis delta / Omega_0, symmetric about resonance.
delta_range = np.linspace(-3, 3, 500)
# Half-axes stopping `offset` short of delta = 0 on either side. The lab-frame
# panel evaluates np.sign(delta), which is discontinuous at 0; drawing each side
# as its own line avoids a spurious vertical segment across that jump.
offset = 0.001
delta_rangen, delta_rangep = (
    np.linspace(-3, -offset, 250),  # strictly negative detunings
    np.linspace(offset, 3, 250),    # strictly positive detunings
)
# Rabi frequencies to plot; colors[i] and labels[i] style Omega0_values[i].
Omega0_values, colors, labels = [1], ['#e74c3c'], [r'Dressed states']
# One row of two side-by-side panels: rotating frame (left), lab frame (right).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(7.5, 4))
# ============= LEFT PLOT: ROTATING FRAME (DRESSED STATES) =============
ax1.set_title('Dressed State Energies (Rotating Frame)', fontsize=13, fontweight='bold')
# Dressed states, one pair of curves per Omega_0 value. The x axis is
# delta / Omega_0 and the y axis is E / (hbar * Omega_0), so BOTH branches
# must be normalised by Omega0.
for i, Omega0 in enumerate(Omega0_values):
    # Dressed state energies: E_± = ±(1/2)√(δ² + |Ω₀|²).
    # delta_range holds δ/Ω₀, so Omega below is the generalised Rabi frequency
    # √(δ² + Ω₀²) and E_±/(ħΩ₀) = ±Omega/(2Ω₀).
    Omega = Omega0 * np.sqrt(delta_range**2 + 1)
    E_plus = 0.5 * Omega / Omega0
    E_minus = -E_plus  # mirror image, same E/(ħΩ₀) normalisation as E_plus
    ax1.plot(delta_range, E_plus, color=colors[i], linewidth=2, label=labels[i])
    ax1.plot(delta_range, E_minus, color=colors[i], linewidth=2)
# Uncoupled (bare) states E = ∓δ/2 as dashed lines. They do not depend on
# Omega_0, so they are drawn once, outside the loop (single legend entry).
ax1.plot(delta_range, -delta_range/2, 'k--', linewidth=1.5, alpha=0.5, label='Bare states')
ax1.plot(delta_range, delta_range/2, 'k--', linewidth=1.5, alpha=0.5)
ax1.axhline(y=0, color='gray', linestyle=':', alpha=0.3)
ax1.axvline(x=0, color='gray', linestyle=':', alpha=0.3)
ax1.set_xlabel('$\\delta / \\Omega_0$', fontsize=12)
ax1.set_ylabel('$E / (\\hbar \\Omega_0)$', fontsize=12)
ax1.legend(fontsize=10, loc='upper right')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(-3, 3)
ax1.set_ylim(-2, 2)
# Avoided crossing: at δ = 0 the branches sit at ±1/2 in units of ħΩ₀, so the
# double arrow spans the Autler-Townes splitting ħΩ₀.
ax1.annotate('', xy=(0, .5), xytext=(0, -.5),
             arrowprops=dict(arrowstyle='<->', color='gray', lw=1.5))
ax1.text(0.15, 0, 'Autler-Townes\nsplitting', fontsize=9, color='gray')
# ============= RIGHT PLOT: LABORATORY FRAME =============
ax2.set_title('Energies in Laboratory Frame', fontsize=13, fontweight='bold')
# For the lab frame, we add ±ω₀/2 and use sgn(δ) for the shifts.
# ω₀ is set to 3 (in units of Omega_0) purely for visualisation.
omega_0 = 3
for i, Omega0 in enumerate(Omega0_values):
    # sgn(δ) jumps at δ = 0, so the negative and positive half-axes are drawn
    # as separate lines. Only the first half carries the legend label, giving
    # one legend entry per Omega_0 value.
    for half, delta in enumerate((delta_rangen, delta_rangep)):
        Omega = np.sqrt(delta**2 + Omega0**2)
        sgn_delta = np.sign(delta)
        # E_e = (ω₀/2) + (δ/2) - sgn(δ)·√(δ² + |Ω₀|²)/2
        E_e = 0.5*omega_0 + 0.5*delta - 0.5*sgn_delta*Omega
        # E_g = -(ω₀/2) - (δ/2) + sgn(δ)·√(δ² + |Ω₀|²)/2
        E_g = -0.5*omega_0 - 0.5*delta + 0.5*sgn_delta*Omega
        # Use the per-Omega0 style, matching the rotating-frame panel.
        label = labels[i] if half == 0 else None
        ax2.plot(delta, E_e, color=colors[i], linewidth=2, label=label)
        ax2.plot(delta, E_g, color=colors[i], linewidth=2)
# Bare (uncoupled) levels ±ħω₀/2, constant in δ.
E_e_bare = np.full_like(delta_range, 0.5*omega_0)
E_g_bare = np.full_like(delta_range, -0.5*omega_0)
ax2.plot(delta_range, E_e_bare, 'k--', linewidth=1.5, alpha=0.5, label=r'$\pm\frac{\hbar\omega_0}{2}$')
ax2.plot(delta_range, E_g_bare, 'k--', linewidth=1.5, alpha=0.5)
ax2.axvline(x=0, color='gray', linestyle=':', alpha=0.3)
ax2.set_xlabel('$\\delta / \\Omega_0$', fontsize=12)
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3)
# No y tick labels: the y axis is drawn with a break (see below), so absolute
# values would be misleading.
ax2.set_yticks([])
ax2.set_xlim(-3, 3.05)
ax2.set_ylim(-2.4, 2.4)
# Double arrow between the bare levels, marking the transition energy ħω₀.
arrow_xpos = -1.5
ax2.annotate('', xy=(arrow_xpos, -omega_0/2), xytext=(arrow_xpos, omega_0/2),
             arrowprops=dict(arrowstyle='<->', color='gray', lw=1.5))
ax2.text(arrow_xpos+0.15, 0, r'$\hbar\omega_0$', fontsize=9, color='gray')
# --- Fake y-axis break on the lab-frame panel ---
# The real left/right spines are hidden and replaced by two vertical segments
# per edge, leaving a small gap around y = 0 that is marked with a rotated "//".
break_y_min = -0.1
break_y_max = 0.1
ax2.spines['left'].set_visible(False)
ax2.spines['right'].set_visible(False)

x_left, x_right_lim = ax2.get_xlim()
y_bottom, y_top = ax2.get_ylim()
# The right-hand segment is inset by 0.05 from the x limit (3.05 in the
# right-plot setup).
# NOTE(review): the left segment sits exactly on the axes edge and may be
# half-clipped by the axes clip box — confirm visually.
x_right = x_right_lim - 0.05
gap_bottom = break_y_min + 0.05  # top of the lower spine segment
gap_top = break_y_max            # bottom of the upper spine segment
mark_y = (break_y_min + break_y_max) / 2 + 0.03  # hand-tuned centre of "//"

for x_edge in (x_left, x_right):
    ax2.vlines(x_edge, y_bottom, gap_bottom, color='k', linewidth=1)
    ax2.vlines(x_edge, gap_top, y_top, color='k', linewidth=1)
    # Break symbol "//", rotated to sit across the vertical spine.
    ax2.text(x_edge, mark_y, r'$//$',
             ha='center', va='center', fontsize=14, color='k', rotation=90)

plt.tight_layout()
plt.show()