# Function name: elda.cv
# Date: 07/2005
# By: Ronglai Shen
#----------------
# Description: The ELDA method based on the shrunken centroids method and a spectral decomposition
#----------------
# Usage: elda.cv(data, nfold=10, folds=NULL, cost=NULL, threshold=NULL, n.threshold=30, prior=NULL, gene.center=FALSE, gene.scale=FALSE)
#----------------
# Arguments:
# data: A list with two components: x- an expression matrix with genes in the rows, samples in the columns(rownames=genenames); y- a vector of the class labels for each sample. 
# nfold: Number of cross-validation folds. Default is ten. If nfold is set equal to the number of samples n, leave-one-out CV is performed.
# folds: A list with nfold components, each component a vector of indices of the samples in that fold.
# cost: A K by K matrix with the row being the actual outcome class and column being the predicted class. The entry c(i,j)=c(j|i), and the diagonals==0.
#       The ij-th entry of the cost matrix is the penalty of assigning a sample to class j when it actually belongs to class i. 
#       Default is equal costs
#----------------
# Value:
# cv.errors: The number of cross-validation errors at each shrinkage threshold value.
# split.errors: The number of CV errors tabularized into each misclassification category at each shrinkage threshold value. 
# ave.nonzero: The number of genes after shrinkage selection at each threshold value averaged over all CV steps.
# ave.size: The number of genes remaining after the further spectral decomposition step at each threshold value, averaged over all CV steps.
# Dset: The appended list of genes (with redundant appearance) from all cross-validation steps; its length should equal ave.size*nfold.


# elda.cv: cross-validate onestep() over a grid of shrinkage thresholds.
#
# Arguments:
#   data         list with components x (matrix; samples are indexed by column)
#                and y (class labels, one per column of x).
#   nfold        number of folds generated when `folds` is NULL. When `folds`
#                is supplied, the number of folds is taken from length(folds).
#   folds        optional list; each component holds the sample (column)
#                indices held out in that fold.
#   cost         passed through unchanged to onestep().
#   threshold    optional vector of threshold values. When NULL, the values
#                returned by onestep() on the full data are used.
#   n.threshold  number of thresholds; replaced by length(threshold) when
#                `threshold` is supplied.
#   prior        optional class prior. Defaults to the class frequencies of y.
#   gene.center  if TRUE, subtract each row mean of x before cross-validation.
#   gene.scale   if TRUE, divide each row of x by its standard deviation
#                before cross-validation.
#
# Value: a list of class "loocv" holding
#   - prior, threshold, centroid.overall, delta.ori, se.scale, sd, centroids,
#     sign.contrast, threshold.scale: copied from the full-data onestep() fit;
#   - norm.cen, norm.sd: row means / row sds used for centering / scaling
#     (NULL when the corresponding option is off);
#   - Dset: per threshold, the most frequently selected names pooled over folds;
#   - cv.errors: per threshold, the count of held-out predictions != y;
#   - ave.nonzero, ave.size: per threshold, fold averages of the `nonzero` and
#     `size` components returned by onestep();
#   - split.errors: per threshold, the off-diagonal cells of
#     table(predicted, actual); column "i|j" is the cell in row i, column j.
elda.cv <- function(data, nfold=10, folds=NULL, cost=NULL, threshold=NULL, n.threshold=30, prior=NULL, gene.center=FALSE, gene.scale=FALSE)
{
    x <- data$x
    y <- data$y
    n <- length(y)

    # Fall back to the empirical class frequencies only when no prior is given.
    # (Previously a user-supplied prior left prior.onall undefined.)
    prior.onall <- if (is.null(prior)) table(y) / length(y) else prior
    k <- length(prior.onall)

    if (!is.null(threshold)) {
        n.threshold <- length(threshold)
    }

    # Fit on all samples (before any centering/scaling below); this fit
    # supplies the default threshold grid and most of the returned components.
    # NOTE(review): threshold=NULL is passed even when the caller supplied
    # `threshold`, so the returned `threshold` component comes from this fit
    # rather than from the caller's values -- confirm this is intended.
    temp <- onestep(x, y, threshold = NULL, prior = prior.onall, n.threshold = n.threshold, cost = cost)
    if (is.null(threshold)) {
        threshold <- temp$threshold
    }

    if (is.null(folds)) {
        folds <- split(sample(1:n), rep(1:nfold, length.out = n))
    }
    # Always iterate over the folds that actually exist: covers user-supplied
    # folds whose length differs from nfold, and nfold > n for generated folds.
    nfold <- length(folds)

    norm.cen <- NULL
    norm.sd <- NULL
    if (gene.center) {
        norm.cen <- apply(x, 1, mean)
        x <- t(scale(t(x), center = norm.cen, scale = FALSE))
    }
    if (gene.scale) {
        norm.sd <- apply(x, 1, sd)
        x <- t(scale(t(x), center = FALSE, scale = norm.sd))
    }

    cv.errors <- rep(NA, n.threshold)
    split.errors <- matrix(NA, nrow = n.threshold, ncol = k * (k - 1))
    predy <- matrix(NA, nrow = n.threshold, ncol = n)
    nonzero <- NULL
    size <- NULL
    Dset <- list()
    Dset.temp <- list()

    for (i in seq_len(nfold)) {
        cat("elda","fold",i,"\n")
        held.out <- folds[[i]]
        a <- onestep(x = as.matrix(x[, -held.out]), y = y[-held.out], prior = prior.onall,
                     cost = cost, threshold = threshold, n.threshold = n.threshold)
        # One column per fold, one row per threshold.
        nonzero <- cbind(nonzero, a$nonzero)
        size <- cbind(size, a$size)
        Dset.temp[[i]] <- a$Dset

        # Standardise the held-out samples with the training-fold centroid and
        # sd; this does not depend on the threshold, so compute it once.
        test.x <- (as.matrix(x[, held.out]) - a$centroid.overall) / a$sd
        for (j in seq_len(n.threshold)) {
            pred.scores <- sdsc.diagdisc(test.x, a$delta.ori, prior = a$prior, cost = a$cost, weight = a$geneid[[j]])
            predy[j, held.out] <- sdsc.softmin(pred.scores)
        }
    }

    # Pad a fold's Dset list with NA entries up to n.threshold elements so it
    # can be indexed by threshold position.
    pad.dset <- function(d) {
        short <- n.threshold - length(d)
        if (short > 0) c(d, as.list(rep(NA, short))) else d
    }

    for (ii in seq_len(n.threshold)) {
        # Pool the names selected at this threshold across folds and keep the
        # most frequent ones.
        cons <- NULL
        for (l in seq_len(nfold)) {
            cons <- c(cons, pad.dset(Dset.temp[[l]])[[ii]])
        }
        # NOTE(review): size[ii] reads the first fold's column of `size`; the
        # across-fold average may have been intended -- confirm.
        # seq_len() keeps the selection empty when the rounded size is 0
        # (1:0 would have picked the first name).
        Dset[[ii]] <- names(sort(table(cons), decreasing = TRUE))[seq_len(round(size[ii], 0))]

        cv.errors[ii] <- sum(predy[ii, ] != y)
        tabular.errors <- table(predy[ii, ], y)

        # table() drops the row of any class that was never predicted; add
        # all-zero rows for those classes and restore the column ordering.
        rnames <- rownames(tabular.errors)
        cnames <- colnames(tabular.errors)
        if (dim(tabular.errors)[1] < k) {
            sdiff <- setdiff(cnames, rnames)
            adj <- matrix(0, nrow = k, ncol = length(sdiff))
            tabular.errors <- t(cbind(t(tabular.errors), adj))
            rownames(tabular.errors) <- c(rnames, sdiff)
            tabular.errors <- tabular.errors[cnames, ]
        }

        # Flatten the off-diagonal cells row by row.
        l <- 1
        for (i in seq_len(k)) {
            for (j in seq_len(k)) {
                if (j != i) {
                    split.errors[ii, l] <- tabular.errors[i, j]
                    l <- l + 1
                }
            }
        }
    }

    # Column labels follow the same row-by-row, off-diagonal order as above.
    l <- 1
    nam <- NULL
    for (i in seq_len(k)) {
        for (j in seq_len(k)) {
            if (j != i) {
                nam[l] <- paste(i, "|", j, sep = "")
                l <- l + 1
            }
        }
    }
    colnames(split.errors) <- nam

    ave.nonzero <- apply(nonzero, 1, mean)
    ave.size <- apply(size, 1, mean)

    out <- list(prior = temp$prior, threshold = temp$threshold, centroid.overall = temp$centroid.overall,
                delta.ori = temp$delta.ori, se.scale = temp$se.scale, sd = temp$sd,
                centroids = temp$centroids, sign.contrast = temp$sign.contrast,
                norm.cen = norm.cen, norm.sd = norm.sd, threshold.scale = temp$threshold.scale,
                Dset = Dset, cv.errors = cv.errors[seq(n.threshold)],
                ave.nonzero = round(ave.nonzero[seq(n.threshold)], 3),
                ave.size = round(ave.size[seq(n.threshold)], 3),
                split.errors = split.errors[seq(n.threshold), ])
    class(out) <- "loocv"
    out
}
