import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText

# Turn on Matplotlib interactive mode so the figure below is displayed as it is built.
plt.ion()

dir_in = '/tank/chaocean/qjamet/RUNS/data_chao12/orar/'
dir_fig = '/tank/users/qjamet/Figures/publi/nature_amoc_rapid/'

#-- runs parameters --
nr = 46            # number of vertical grid points
ny = 900           # number of meridional grid points
nmem = 24          # number of members
ndump = 73         # number of outputs per year (5-day averaged)
nyr = 50           # number of years
nt = ndump*nyr     # total length of time series

#-- Model grid, hard-coded to avoid loading the model grid files --
# Depth of the nr vertical levels [m], negative downward.
zz = np.array([
      0.        ,    -6.09354544,   -12.81455231,   -19.91337967,
    -27.88404083,   -36.53453827,   -46.4291687 ,   -57.46109009,
    -70.29701233,   -84.95201111,  -102.23623657,  -122.33073425,
   -146.23381042,  -174.33418274,  -207.85084534,  -247.39579773,
   -294.39660645,  -349.63720703,  -414.65155029,  -490.23425293,
   -577.80517578,  -677.89978027,  -791.53027344,  -918.69213867,
  -1059.765625  , -1214.07836914, -1381.37036133, -1560.41552734,
  -1750.52783203, -1950.203125  , -2158.62451172, -2374.28271484,
  -2596.45898438, -2823.80761719, -3055.81591797, -3291.36083984,
  -3530.15087891, -3771.27294922, -4014.62646484, -4259.46728516,
  -4505.84130859, -4753.12841797, -5001.47802734, -5250.35986328,
  -5499.99365234, -5749.90966797])
if zz.size != nr:
  raise ValueError('hard-coded zz has %i levels but nr = %i' % (zz.size, nr))
#
# Latitude of the ny meridional grid points.
# Built from ny rather than np.arange(start, stop, step): with a non-integer step the
# number of points returned by arange depends on floating-point rounding.
yG = -19.95829963684082 + 0.08333587646484375*np.arange(ny)
#-- grid indices (latitude jjj, level kkk) of the max of AMOC total variance
#   associated with the Gulf Stream --
jjj = 695
kkk = 24


#-------------------------------------------------------------------------------
# Load the ensemble AMOC time series, organized as [nmem, nt, nr, ny].
# The 5-day averaged AMOC time series have been detrended and 1-year low-pass filtered
# (see the Supporting Information of Jamet et al. (2019); https://doi.org/10.1029/2019GL082552).
# The AMOC data are split into 2 separate files, which are combined here.
#-------------------------------------------------------------------------------
mocyzt = np.zeros((nmem, nt, nr, ny))

nhalf = nmem // 2   # number of members stored in each file
# (file name, members of mocyzt it fills); each file holds nhalf members as big-endian float32
moc_files = [('MOCyzt_orar_ensemble_detrend_1ylpf.bin',   slice(0, nhalf)),     #- first 12 members -
             ('MOCyzt_orar_ensemble_2_detrend_1ylpf.bin', slice(nhalf, None))]  #- last 12 members -
for fileN0, members in moc_files:
  # Binary mode for raw binary data; `with` closes the file even if the read/reshape fails.
  with open(dir_in + fileN0, 'rb') as f2r:
    # Assign straight into mocyzt so each temporary half-ensemble array is released
    # before the next file is read, instead of holding two of them alive at once.
    mocyzt[members, :, :, :] = np.fromfile(f2r, '>f4').reshape([nhalf, nt, nr, ny])

#------------------------------------------------------
# Compute the different variance terms.
# With <.> the ensemble mean, sigma_t^2 the temporal variance, sigma_ens^2 the
# ensemble variance and mean_t the time mean, the decomposition is expected to be
#   < \sigma_t^2(f) > = \sigma_t^2(<f>) + mean_t( \sigma_{ens}^2(f) )
#                       - < [ mean_t( f-<f> ) ]^2 >
# (the identity is exact for the biased estimator, ddof = 0).
#------------------------------------------------------

#-- Initialize --
varT = np.zeros([nr, ny])   # Total variance
varF = np.zeros([nr, ny])   # Forced variance = temporal variance of ensemble mean
varI = np.zeros([nr, ny])   # Intrinsic variance = time mean of ensemble variance
res  = np.zeros([nr, ny])   # Residual = MINUS the ensemble variance of the time mean
ddof = 0                    #'Delta Degrees of Freedom': 0 -> biased estimator; 1 -> unbiased

# Loop over depth only and vectorize over latitude: selecting a single (k, j) column
# gathers nmem*nt widely scattered values per iteration, whereas a whole depth level
# is one (nmem, nt, ny) slab handled in a few NumPy calls.
for kk in range(nr):
  print("k= %i" %kk)
  #- 50-yr long time series at all latitudes of this level, [nmem, nt, ny] -
  moct = np.ascontiguousarray(mocyzt[:, :, kk, :])
  #- ensemble mean <f>, [nt, ny] -
  mocf = moct.mean(axis=0)
  #-- ensemble mean of the time-variance of each member ( < \sigma_t^2(f) > ) --
  varT[kk, :] = np.var(moct, axis=1, ddof=ddof).mean(axis=0)
  #-- time-variance of the ensemble mean ( \sigma_t^2 (<f>) ) --
  varF[kk, :] = np.var(mocf, axis=0, ddof=ddof)
  #-- time-mean of the ensemble spread ( mean_t( \sigma_{ens}^2 (f) ) ) --
  #   The variance across members already removes <f>, so var(f) == var(f-<f>) here.
  varI[kk, :] = np.var(moct, axis=0, ddof=ddof).mean(axis=0)
  #-- residual ( -< [mean_t( f-<f> )]^2 > = -\sigma_{ens}^2( mean_t(f) ) ) --
  res[kk, :] = -np.var(moct.mean(axis=1), axis=0, ddof=ddof)

#-------------------
#   PLOT
#-------------------

fig1 = plt.figure(figsize=(16,10))
llev1 = np.arange(0, 5.1, 0.1)*1e0   # contour levels, common to all panels
scaleF = 1e3                          # amplification factor for the residual panel
fig1.clf()
#-- total variance; the red star marks the (jjj, kkk) grid point --
ax1 = fig1.add_subplot(2, 2, 1)
cs1 = ax1.contourf(yG, zz/1000, varT,
                   levels=llev1, cmap='Blues_r', extend='max')
ax1.plot(yG[jjj], zz[kkk]/1000, 'r*', markersize=20)
at1 = AnchoredText(r'$\left< \sigma^2(f_i(t)) \right>$ -- TOTAL', prop=dict(size=15), frameon=True,
                   loc='upper left')
at1.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
ax1.add_artist(at1)
#-- forced variance --
ax2 = fig1.add_subplot(2, 2, 2)
cs2 = ax2.contourf(yG, zz/1000, varF,
                   levels=llev1, cmap='Blues_r', extend='max')
at2 = AnchoredText(r'$\sigma^2(\left<f_i(t)\right>)$ -- FORCED', prop=dict(size=15), frameon=True,
                   loc='upper left')
at2.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
ax2.add_artist(at2)
#-- intrinsic variance --
ax3 = fig1.add_subplot(2, 2, 3)
cs3 = ax3.contourf(yG, zz/1000, varI,
                   levels=llev1, cmap='Blues_r', extend='max')
at3 = AnchoredText(r'$\overline{\epsilon^2(f_i(t))}$ -- INTRINSIC', prop=dict(size=15), frameon=True,
                   loc='upper left')
at3.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
ax3.add_artist(at3)
#-- residual: res is stored negative, so plot -res, amplified by scaleF --
ax4 = fig1.add_subplot(2, 2, 4)
cs4 = ax4.contourf(yG, zz/1000, -res*scaleF,
                   levels=llev1, cmap='Blues_r', extend='max')
at4 = AnchoredText(r'$\epsilon^2(\overline{f_i(t)})$ -- RESIDUAL ($\mathbf{\times %.01e}$)' % (1/scaleF),
                   prop=dict(size=15), frameon=True, loc='upper left')
at4.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
ax4.add_artist(at4)

#-- common decoration --
# Iterate over the Axes created above. Calling fig1.add_subplot(2, 2, ip+1) again does not
# reliably return the existing Axes: recent Matplotlib versions create a NEW Axes on top,
# whose grey background would then hide the contour plots.
for ip, ax in enumerate([ax1, ax2, ax3, ax4]):
  ax.set_facecolor([0.5, 0.5, 0.5])
  if ip in (0, 2):     # left column
    ax.set_ylabel('Depth [km]', fontsize='x-large')
  if ip in (2, 3):     # bottom row
    ax.set_xlabel('Latitude', fontsize='x-large')

#-- colorbar, shared by all panels (same levels and colormap) --
cbax1 = fig1.add_axes([0.92, 0.2, 0.01, 0.6])
cb1 = fig1.colorbar(cs1, ax=ax1, orientation='vertical', cax=cbax1)
cb1.set_label(r'[Sv$^{2}$]')


