# Function name: elda.predict
# Date: 07/2005
# By: Ronglai Shen
#----------------
# Description: A function giving test errors from the elda.cv fit
#----------------
# Usage: elda.predict(fit, newx, cost=NULL,...)
#----------------
# Arguments:
# fit: model fit from elda.cv
# newx: Test data. A list with two components: x - an expression matrix (genes in the rows, samples in the columns), and y - a vector of the class labels for each sample.
# cost: A K by K matrix with the row being the actual outcome class and column being the predicted class. The entry c(i,j)=c(j|i), and the diagonals==0.
#       The ij-th entry of the cost matrix is the penalty of assigning a sample to class j when it actually belongs to class i. 
#       Default is equal costs.
#----------------
# Value:
# pred.errors: The number of test errors at each shrinkage threshold value.
# pred.split.errors: The number of test errors tabularized into each misclassification category at each shrinkage threshold value. 
# Dset: The final classifier trained from elda.cv, of size fit$ave.size, the set of genes with top appearance frequency.
#       For example, pred=elda.predict(fit, newx, cost=NULL), then pred$Dset is a list of final classifiers at each threshold value.



# elda.predict: apply the per-threshold classifiers stored in an elda.cv fit
# to test data and, when test labels are supplied, count the test errors.
#
# Arguments:
#   fit             Fit object. Components read here: norm.cen, norm.sd, sd,
#                   centroid.overall, delta.ori, Dset, nonzero (and threshold,
#                   prior, threshold.scale as argument defaults).
#   newx            List with x (matrix whose row names are matched against
#                   the entries of fit$Dset) and y (class labels; may be NULL).
#   cost            Passed through unchanged to sdsc.diagdisc.
#   threshold       Threshold values; only its length and its value (returned
#                   as-is) are used here.
#   n.threshold     Ignored: always recomputed as length(threshold). Kept so
#                   existing calls that pass it keep working.
#   prior           Class priors; length(prior) is taken as the number of
#                   classes k and the vector is passed to sdsc.diagdisc.
#   threshold.scale Accepted for backward compatibility; not used.
#
# Value: a list with
#   pred.errors        Number of misclassified samples per threshold
#                      (NULL when newx$y is NULL).
#   pred.split.errors  n.threshold x k*(k-1) matrix. Column "i|j" is entry
#                      [i, j] of table(predicted, newx$y), i.e. the i-th
#                      predicted class against the j-th true class. A row is
#                      left NA when newx$y is NULL or when that table does not
#                      cover k classes.
#   size               Length of fit$Dset[[ii]] for each threshold.
#   Dset               fit$Dset, returned unchanged.
#   dd                 List of the matrices returned by sdsc.diagdisc, with
#                      row names names(newx$y) and column names 1..k.
#   nonzero            fit$nonzero, returned unchanged.
#   threshold          The threshold argument, returned unchanged.
elda.predict <- function (fit, newx, cost=NULL, threshold=fit$threshold, n.threshold=20 ,prior = fit$prior, threshold.scale = fit$threshold.scale) 
{
    testx <- newx$x
    testy <- newx$y
    k <- length(prior)
    # The n.threshold argument is deliberately overridden (as it always was).
    n.threshold <- length(threshold)
    has.labels <- !is.null(testy)

    # Apply the centring / scaling stored in the fit, row by row.
    norm.cen <- fit$norm.cen
    norm.sd <- fit$norm.sd
    if (!is.null(norm.cen)) {
        testx <- t(scale(t(testx), center = norm.cen, scale = FALSE))
    }
    if (!is.null(norm.sd)) {
        testx <- t(scale(t(testx), center = FALSE, scale = norm.sd))
    }

    Dset.out <- fit$Dset
    delta.ori <- fit$delta.ori

    # Standardised test data; identical for every threshold, so computed once.
    x.std <- (testx - fit$centroid.overall) / fit$sd

    # Ordered (i, j) pairs with i != j, i varying slowest. This fixes both the
    # column order and the column names of pred.split.errors.
    pair.i <- rep(seq_len(k), each = k)
    pair.j <- rep(seq_len(k), times = k)
    off.diag <- pair.i != pair.j
    pair.idx <- cbind(pair.i[off.diag], pair.j[off.diag])
    pair.names <- paste0(pair.idx[, 1], "|", pair.idx[, 2])

    dd <- vector("list", n.threshold)
    size <- integer(n.threshold)
    pred.errors <- if (has.labels) integer(n.threshold) else NULL
    pred.split.errors <- matrix(NA, nrow = n.threshold, ncol = k * (k - 1))
    incomplete.table <- FALSE

    for (ii in seq_len(n.threshold)) {
        classifier <- Dset.out[[ii]]
        size[ii] <- length(classifier)
        # Logical mask over the rows of testx: TRUE for rows named in the
        # classifier of this threshold.
        geneid <- is.element(rownames(testx), classifier)

        # NOTE(review): sdsc.diagdisc is defined elsewhere; it is assumed to
        # return a samples x k matrix, since row/column names are set below.
        disc <- sdsc.diagdisc(x = x.std, centroids = delta.ori,
              prior = prior, cost = cost, weight = geneid)
        rownames(disc) <- names(testy)
        colnames(disc) <- seq_len(k)
        dd[[ii]] <- disc

        # NOTE(review): sdsc.softmin is defined elsewhere; presumably it
        # returns one predicted label per row of disc - confirm.
        pred <- sdsc.softmin(disc)

        # Error counting needs the true labels; without them only dd/size are
        # produced (previously the table() call failed on NULL labels).
        if (has.labels) {
            pred.errors[ii] <- sum(pred != testy)

            # Rows = predicted class, columns = true class.
            tab <- unclass(table(pred, testy))
            true.classes <- colnames(tab)

            # table() drops the row of a class that was never predicted.
            # Re-insert such rows as zeros and order rows like the columns.
            if (nrow(tab) < k) {
                never.predicted <- setdiff(true.classes, rownames(tab))
                if (length(never.predicted) > 0) {
                    pad <- matrix(0L, nrow = length(never.predicted),
                                  ncol = ncol(tab),
                                  dimnames = list(never.predicted, true.classes))
                    tab <- rbind(tab, pad)
                }
                tab <- tab[true.classes, , drop = FALSE]
            }

            if (all(dim(tab) >= k)) {
                pred.split.errors[ii, ] <- tab[pair.idx]
            } else {
                # Fewer than k classes are present, so positions i, j cannot
                # be resolved; leave this row NA instead of failing.
                incomplete.table <- TRUE
            }
        }
    }

    if (incomplete.table) {
        warning("newx$y does not cover all ", k,
                " classes; pred.split.errors left as NA for the affected thresholds",
                call. = FALSE)
    }

    colnames(pred.split.errors) <- pair.names

    list(pred.errors = pred.errors, pred.split.errors = pred.split.errors,
         size = size, Dset = Dset.out, dd = dd, nonzero = fit$nonzero,
         threshold = threshold)
}
