###Fit integrated GAMs - parity 3+

##Setup
#Load ONS data
# NOTE(review): presumably this file supplies ONS_births_dat, ONS_expos_dat and
# ONS_rates_dat, which are read in the ONS data setup below -- confirm.
load("data/ONS2018_allc.RData")

#Source functions
# NOTE(review): stan() and left_join() are called unqualified later in this
# script, so rstan and dplyr must be attached by this point -- presumably by
# the sourced files; confirm.
source("r/integrated_functions.r")

# Backtesting switch. When TRUE, the ONS data are restricted to cells with
# age + coh <= 2013 and the fitted model is saved with a "_2013" suffix.
back <- FALSE # change to TRUE if performing backtesting


## GAM setup
# NOTE(review): p3_setup.r is presumably where the basis matrices (B_*), the
# penalty matrices (S_*) and data3w come from -- confirm.
source("r/p3_setup.r")

# Design matrix: the four basis blocks side by side. betadim is its number of
# columns, i.e. the length of the coefficient vector.
X <- cbind(B_A2, B_C2, B_T2, B_AC4)
betadim <- ncol(B_A2) + ncol(B_C2) + ncol(B_T2) + ncol(B_AC4)

# Column positions of each basis block inside X.
idx_A <- 1:ncol(B_A2)
idx_C <- ncol(B_A2) + 1:ncol(B_C2)
idx_T <- ncol(B_A2) + ncol(B_C2) + 1:ncol(B_T2)
idx_AC <- ncol(B_A2) + ncol(B_C2) + ncol(B_T2) + 1:ncol(B_AC4)

# Each penalty is embedded in a betadim x betadim zero matrix so that it acts
# only on the coefficients of its own block. S_AF and S_CF both act on the
# B_AC4 block.
S1 <- S2 <- S3 <- S4 <- S5 <- matrix(0, nrow = betadim, ncol = betadim)
S1[idx_A, idx_A] <- S_A2
S2[idx_C, idx_C] <- S_C2
S3[idx_T, idx_T] <- S_T2
S4[idx_AC, idx_AC] <- S_AF
S5[idx_AC, idx_AC] <- S_CF

# The index vectors were only needed to place the penalties; drop them so the
# workspace holds the same objects as before.
rm(idx_A, idx_C, idx_T, idx_AC)

# Per-row weight: ratio of the nw and n columns of data3w.
wt <- data3w$nw / data3w$n


## ONS data setup
# One row per age/cohort cell, taken from element 4 of each ONS list.
# NOTE(review): assumes the births, exposure and rate tables are row-aligned;
# confirm.
data <- data.frame(
  age = ONS_births_dat[[4]]$Group.1,
  coh = ONS_births_dat[[4]]$Group.2,
  N   = ONS_expos_dat[[4]]$x,
  n   = ONS_births_dat[[4]]$x,
  x   = ONS_rates_dat[[4]]$x
)

# Keep cells with positive N, 0 <= n <= N, and ages 15 to 44. which() drops
# rows where the condition is NA.
data <- data[with(data, which(N > 0 & n >= 0 & n <= N & age %in% 15:44)), ]

# Backtesting: keep only cells with age + coh <= 2013.
if (back) {
  data <- data[which(data$age + data$coh <= 2013), ]
}

ONSN <- data$N
ONSn <- data$n
Nm <- nrow(data)

# Full age x cohort x gap grid (gap 1..11) with the ONS columns attached;
# grid cells without an ONS row get NA.
data3f <- left_join(
  expand.grid(age = agerange, coh = cohrange, gapc = 1:11),
  data,
  by = c("age", "coh")
)

# Positions, within the first gap slice (the first 1/11 of the grid rows), of
# the cells that have ONS data. Must be computed before the grid is filtered.
ACind <- which(!is.na(data3f$N[seq_len(nrow(data3f) / 11)]))

# Drop grid cells without ONS data.
data3f <- data3f[!is.na(data3f$N), ]

# Design matrix on the retained grid rows: basis rows are looked up through
# the index vectors using age - 14, coh - 1944 and the gap value; the last
# block repeats the ACind rows once per gap value.
Xf <- cbind(
  B_A2all[A_indall[data3f$age - 14], ],
  B_C2all[C_indall[data3f$coh - 1944], ],
  B_T2[T_ind[data3f$gapc], ],
  B_ACfull3[rep(ACind, times = 11), ]
)
Nf <- nrow(Xf)

# Age index (age - 14) for the first Nm rows of the retained grid.
# NOTE(review): assumes those rows correspond one-to-one with the rows of
# `data` -- confirm.
Afind <- (data3f$age - 14)[1:Nm]


## TA model setup
# For every value of age_3: the number of records (y) and the number of
# records at each gap value 1..11 (t1..t11).
newdata3 <- aggregate(
  gapc_3 ~ age_3,
  FUN = function(x) {
    per_gap <- vapply(1:11, function(k) length(x[x == k]), integer(1))
    c(y = length(x), setNames(per_gap, paste0("t", 1:11)))
  }
)

# aggregate() returns the counts as a single matrix column; spread it out into
# ordinary data-frame columns.
newdata3 <- data.frame(newdata3$age_3, newdata3$gapc_3)

# Weighted total per age and its ratio to the unweighted count.
newdata3$yw <- aggregate(weights_3st ~ age_3, FUN = sum)$weights_3st
newdata3$wtmult <- newdata3$yw / newdata3$y

colnames(newdata3) <- c("a", "y", as.character(1:11), "yw", "wtmult")

# Prepend three all-zero rows, then label the rows by agerange.
# NOTE(review): presumably agerange starts three ages before the first
# observed age_3 -- confirm.
newdata3 <- rbind(0, 0, 0, newdata3)
rownames(newdata3) <- agerange

# Counts by age (rows) and gap value 1..11 (columns).
y <- newdata3[, as.character(1:11)]
Na <- nrow(y)

# Ages (taken from the row names) centred on their median, and the square.
a <- as.numeric(rownames(y))
a <- a - median(a)
a2 <- a^2

# Index table with the same layout as y. For gap column i (2..11), row r holds
# r - i + 3, with any value below 1 set to 1. Column 1 is dropped.
ab <- y
ab[, 2] <- 2:(Na + 1)
for (i in 3:11) {
  ab[, i] <- c(rep(0, i - 3), 1:(Na - (i - 3)))
}
ab <- ab[, -1]
ab[ab == 0] <- 1
Nab <- max(ab)

# Weight multiplier for each row of y.
wtc <- newdata3[rownames(y), "wtmult"]


## Fit integrated models
# Everything passed to Stan as data, grouped by kind.
standata <- list(
  # dimensions
  N = N, Nm = Nm, Nf = Nf, Na = Na, Nab = Nab, betadim = betadim,
  # counts
  succ = succ, tot = tot, ONSN = ONSN, ONSn = ONSn,
  # design matrices and weights
  X = X, Xf = Xf, wt = wt, wtc = wtc,
  # TA model inputs
  y = y, a = a, ab = ab, Afind = Afind,
  # penalty matrices
  S1 = S1, S2 = S2, S3 = S3, S4 = S4, S5 = S5
)

# 1:1
stanout <- stan(
  file = "stan/p3_1_1.stan",
  data = standata,
  chains = 1,
  iter = 2000
)

# Backtesting runs are saved under a separate "_2013" file name.
save(
  stanout,
  file = if (back) "output/p3_1_1_2013.RData" else "output/p3_1_1.RData"
)

# #9:1
# stanout <- stan(file="stan/p3_9_1.stan",data=standata,chains=1,iter=2000)
# if (!back) save(stanout,file="output/p3_9_1.RData")
# if ( back) save(stanout,file="output/p3_9_1_2013.RData")
# 
# #2:1
# stanout <- stan(file="stan/p3_2_1.stan",data=standata,chains=1,iter=2000)
# if (!back) save(stanout,file="output/p3_2_1.RData")
# if ( back) save(stanout,file="output/p3_2_1_2013.RData")
# 
# #1:2
# stanout <- stan(file="stan/p3_1_2.stan",data=standata,chains=1,iter=2000)
# if (!back) save(stanout,file="output/p3_1_2.RData")
# if ( back) save(stanout,file="output/p3_1_2_2013.RData")
# 
# #1:9
# stanout <- stan(file="stan/p3_1_9.stan",data=standata,chains=1,iter=2000)
# if (!back) save(stanout,file="output/p3_1_9.RData")
# if ( back) save(stanout,file="output/p3_1_9_2013.RData")
