library(gclus)
library(plyr)
library(reshape2)
library(dplyr)
library(zoo)
library(RColorBrewer)
library(scales)
library(xlsx)

# Summarise UMI counts per cluster ("niche") for every well-covered gene and
# publish the results as global objects read by the other scr_* helpers.
#
# Arguments:
#   sc_object : clustering object; uses sc_object@scmat@mat (genes x cells UMI
#               matrix) and sc_object@clusts (cluster id per cell, named by cell).
#   T_totumi  : genes whose total UMI count is <= T_totumi are dropped.
#
# Reads the global `cluster_order` (cluster ids, numeric or character). Output
# columns follow that order and are matched by cluster NAME.
#
# Side effects (assigned with <<-):
#   scr_niche_counts      : genes x clusters, summed UMIs.
#   scr_cluster_umifrac   : scr_niche_counts per 1000 UMIs of each cluster.
#   scr_cluster_umifrac_n : (umifrac + 0.1) divided by the per-gene median of
#                           (umifrac + 0.1) over the other clusters.
#   scr_cluster_geomean   : per-cluster exp(mean(log(1 + UMI))) - 1, scaled by
#                           800 / mean cell size (UMIs of kept genes) of the cluster.
#   scr_cluster_geomean_n : geomean normalised like scr_cluster_umifrac_n, using
#                           every other cluster in the data as the reference.
# Note: an earlier version called rm() on these globals here; rm() inside a
# function only searches the local frame, so that call did nothing except warn.
scr_gen_niche_counts <- function(sc_object, T_totumi = 4) {
  # Select columns by name. Using a numeric cluster_order as a column index
  # selects by position, which re-permutes columns already in cluster_order.
  ord <- as.character(cluster_order)

  # For each cluster in `ord`: (m[, cl] + 0.1) / row median of the other columns (+0.1).
  relative_to_others <- function(m) {
    out <- vapply(ord, function(cl) {
      rest <- m[, setdiff(colnames(m), cl), drop = FALSE]
      (m[, cl] + 0.1) / apply(rest + 0.1, 1, median)
    }, numeric(nrow(m)), USE.NAMES = FALSE)
    matrix(out, nrow = nrow(m), ncol = length(ord),
           dimnames = list(rownames(m), ord))
  }

  scr_umis <- as.matrix(sc_object@scmat@mat)
  nms <- names(which(rowSums(scr_umis) > T_totumi))
  # drop = FALSE keeps a matrix even when a single gene passes the filter.
  us <- scr_umis[nms, names(sc_object@clusts), drop = FALSE]

  niche_counts <- t(apply(us, 1, function(x) tapply(x, sc_object@clusts, sum)))
  niche_counts <- niche_counts[, ord, drop = FALSE]
  scr_niche_counts <<- niche_counts

  niche_umifrac <- sweep(niche_counts * 1000, 2, colSums(niche_counts), "/")
  scr_cluster_umifrac <<- niche_umifrac
  scr_cluster_umifrac_n <<- relative_to_others(niche_umifrac)

  clust_geomean <- t(apply(us, 1, function(x) {
    tapply(x, sc_object@clusts, function(y) exp(mean(log(1 + y))) - 1)
  }))
  clust_meansize <- tapply(colSums(us), sc_object@clusts, mean)
  niche_geomean <- t(800 * t(clust_geomean) / as.vector(clust_meansize))

  scr_cluster_geomean <<- niche_geomean[, ord, drop = FALSE]
  scr_cluster_geomean_n <<- relative_to_others(niche_geomean)
}


# Draw a marker-gene heatmap (markers x clusters) to `heatmap_file` and,
# optionally, one barplot PNG per marker into `barplot_dir`.
#
# Markers are the row names of `markers_file` (tab separated, no header).
# Column 1 is drawn as a label on the right. An optional column 2 replaces the
# left-hand label, which otherwise shows "<gene> <total UMIs> <max cluster UMI fraction>".
# A marker is kept when all of the following hold:
#   - it is present in the footprint and in `scr_niche_counts`;
#   - it has more than `T_totumi` UMIs in total;
#   - its BH q-value from scr_chisq_intraniche_markers() is below `chisq_threshold`;
#   - its maximum footprint value exceeds `min_gene_fc`.
# Kept markers are ordered by the cluster where their footprint peaks.
#
# Globals read:
#   `cluster_order`       : column order, matched by cluster name.
#   `scr_niche_counts`    : computed here and assigned with `<<-` when it does not exist.
#   `scr_cluster_umifrac` : only when print_barplots = TRUE.
#
# Arguments:
#   w        : heatmap width in pixels. The height is 30 px per marker,
#              clamped to [400, 12000].
#   pmin     : cap applied to log2(footprint + 1) before plotting.
#   set_pmin : if TRUE, the cap becomes max(99th percentile of the raw marker
#              footprint values, 4.5).
#   col, nor_statistic : accepted for backward compatibility; not used.
# Returns NULL; called for the PNG files it writes.
scr_barplot_heatmap_markers <- function(sc_cl_object, markers_file, barplot_dir, heatmap_file, w = 4000,
                                        chisq_threshold, T_totumi = 15,
                                        col = NULL,
                                        print_barplots = FALSE, pmin = 3,
                                        nor_statistic = median, min_gene_fc = 2, set_pmin = FALSE) {
  clusts <- sc_cl_object@clusts
  # Select clusters by name. A numeric cluster_order used as a column index
  # selects by position, which re-permutes matrices already in cluster_order.
  ord <- as.character(cluster_order)
  footprint_table <- sc_cl_object@clust_fp[, ord, drop = FALSE]

  if (!exists("scr_niche_counts")) {
    message("I will compute scr_niche_counts matrix")
    scr_niche_counts <<- .row_stats_by_factor(sc_cl_object@scmat@mat[rownames(sc_cl_object@clust_fp), ], clusts, rowSums)
  }
  niche_counts <- scr_niche_counts

  markers_table <- read.table(markers_file, header = FALSE, row.names = 1, sep = "\t")
  # Markers absent from the count matrix would make the subsetting below fail.
  markers <- Reduce(intersect, list(rownames(markers_table),
                                    rownames(footprint_table),
                                    rownames(niche_counts)))

  markers_qvalues <- scr_chisq_intraniche_markers(markers)
  keep <- rowSums(niche_counts[markers, , drop = FALSE]) > T_totumi &
    markers_qvalues < chisq_threshold &
    apply(sc_cl_object@clust_fp[markers, , drop = FALSE], 1, max) > min_gene_fc
  markers <- markers[which(keep)]

  if (length(markers) == 0) {
    message("No markers survived the filters; nothing to plot")
    return(invisible(NULL))
  }

  h <- min(max(length(markers) * 30, 400), 12000)

  peak <- apply(footprint_table[markers, , drop = FALSE], 1,
                function(x) which.max(rollmean(x, 1)))
  markers_sorted <- markers[order(peak)]

  sorted_counts <- niche_counts[markers_sorted, , drop = FALSE]
  markers_tot_umi <- rowSums(sorted_counts)
  marker_max_umi_frac <- round(apply(sorted_counts / markers_tot_umi, 1, max), 2)

  print(paste0(length(markers_sorted), "  markers survived"))
  shades2 <- colorRampPalette(c("white", "white", "orange", "red", "purple", "black"))(1000)

  sorted_fp <- footprint_table[markers_sorted, , drop = FALSE]
  # `pmin` is a parameter name here; keep the cap in its own variable.
  fp_cap <- pmin
  if (set_pmin) {
    fp_cap <- pmax(quantile(sorted_fp, 0.99), 4.5)
  }
  message("I'm using PMIN   ", fp_cap)
  marker_fp_to_plot <- pmin(log2(sorted_fp + 1), fp_cap)

  png(heatmap_file, height = h, width = w)
  par(mar = c(5, 40, 5, 60))
  image(t(marker_fp_to_plot), xaxt = "n", yaxt = "n", col = shades2)

  n_cl <- length(ord)
  n_mk <- length(markers_sorted)
  cl_at <- seq(0, 1, length.out = n_cl)
  mk_at <- seq(0, 1, length.out = n_mk)
  mtext(ord, side = 1, at = cl_at, cex = 2, line = 2, las = 2)
  mtext(ord, side = 3, at = cl_at, cex = 2, line = 2, las = 2)

  if (ncol(markers_table) == 1) {
    mtext(paste(markers_sorted, markers_tot_umi, marker_max_umi_frac, sep = " "),
          line = 2, side = 2, at = mk_at, las = 1, cex = 2, adj = 1)
  } else {
    mtext(markers_table[markers_sorted, 2], side = 2, at = mk_at, line = 2, las = 1, cex = 3, adj = 1)
  }
  mtext(markers_table[markers_sorted, 1], side = 4, at = mk_at, line = 2, las = 1, cex = 3, adj = 0)

  # Grid lines between image cells; skipped for a single row/column (division by zero).
  if (n_cl > 1) abline(v = (seq_len(n_cl) - 0.5) / (n_cl - 1), lwd = 0.5)
  if (n_mk > 1) abline(h = (seq_len(n_mk) - 0.5) / (n_mk - 1), lwd = 0.5)
  dev.off()

  if (print_barplots) {
    for (nm in markers) {
      fn <- sprintf("%s/%s_barplot.png", barplot_dir, nm)
      png(fn, height = 300, width = 1000)
      barplot(scr_cluster_umifrac[nm, ord], main = markers_table[nm, 1], cex.axis = 3, cex.main = 2,
              las = 2, border = FALSE, col = "gray33", space = 0.1)
      dev.off()
    }
  }
}


# Plot a 2D projection for every marker gene listed in `markers_file`
# (tab separated, no header, gene ids as row names; column 1 is appended to the
# plot label), delegating each plot to scp_plot_gene_2d().
#
# A gene is plotted when its total UMI count in sc_object@scmat@mat exceeds
# `T_totumi`. When `filter_by_FC` is TRUE, its maximum value in
# sc_object@clust_fp must also exceed `min_FC`.
# With `auto_reg_factor`, the regularisation factor is set per gene to
# round(25000 / total UMIs), clamped to [4, 10].
scr_2d_project_markers <- function(sc2d_object, sc_object, markers_file, base_dir = "./gene_projections/",
                                   reg_factor = 9, auto_reg_factor = FALSE, min_FC = 1.5,
                                   filter_by_FC = TRUE, T_totumi = 30) {
  markers_table <- as.data.frame(read.table(markers_file, header = FALSE, row.names = 1, sep = "\t"))
  umis <- as.matrix(sc_object@scmat@mat)

  # drop = FALSE: a single matching gene must stay a 1-row matrix for rowSums/apply.
  cov_candidates <- intersect(rownames(umis), rownames(markers_table))
  f_cov <- names(which(rowSums(umis[cov_candidates, , drop = FALSE]) > T_totumi))

  if (filter_by_FC) {
    fc_candidates <- intersect(rownames(sc_object@clust_fp), rownames(markers_table))
    max_fc <- apply(sc_object@clust_fp[fc_candidates, , drop = FALSE], 1, max)
    genes <- intersect(f_cov, names(which(max_fc > min_FC)))
  } else {
    genes <- f_cov
  }

  markers_table <- markers_table[genes, , drop = FALSE]

  for (gene in intersect(rownames(markers_table), rownames(umis))) {
    if (auto_reg_factor) {
      tot <- sum(umis[gene, ])
      reg_factor <- pmin(pmax(round(25000 / tot), 4), 10)
    }
    scp_plot_gene_2d(sc2d_object, gene, w = 800, h = 800, base_dir = base_dir, reg_factor = reg_factor,
                     label = paste(gene, markers_table[gene, 1], sep = "_"))
  }
}

# Chi-square test of each gene's cluster distribution against the pooled
# distribution of all other UMIs, with Benjamini-Hochberg adjustment.
#
# Reads the global `scr_niche_counts` (genes x clusters UMI counts). The
# remaining genes are first pooled into an "other" row. Each gene in `genes`
# is then tested against the column sums of every other row in that reduced
# matrix (the other listed genes plus "other").
#
# Arguments:
#   genes            : row names of scr_niche_counts to test.
#   gen_niche_counts : if TRUE, rebuild scr_niche_counts first via
#                      scr_gen_niche_counts(sc_object).
#   sc_object        : clustering object, required when gen_niche_counts = TRUE.
# Returns a numeric vector of BH q-values named by gene.
scr_chisq_intraniche_markers <- function(genes, gen_niche_counts = FALSE, sc_object = NULL) {
  if (gen_niche_counts) {
    if (is.null(sc_object)) {
      stop("gen_niche_counts = TRUE requires sc_object to rebuild scr_niche_counts", call. = FALSE)
    }
    scr_gen_niche_counts(sc_object)
  }
  counts <- scr_niche_counts

  # drop = FALSE keeps matrices when only one row is selected.
  in_genes <- rownames(counts) %in% genes
  mat <- rbind(counts[genes, , drop = FALSE],
               colSums(counts[!in_genes, , drop = FALSE]))
  rownames(mat) <- c(genes, "other")

  marker_pvalues <- vapply(genes, function(gene) {
    contingency <- rbind(mat[gene, ],
                         colSums(mat[rownames(mat) != gene, , drop = FALSE]))
    chisq.test(contingency)$p.value
  }, numeric(1))
  names(marker_pvalues) <- genes

  p.adjust(marker_pvalues, method = "BH")
}


# Phylostratum enrichment of per-cluster gene modules.
#
# Genes are "seen" when they have more than `T_totumis_seen` UMIs in total.
# A cluster's module is the set of seen genes whose footprint exceeds
# `fc_threshold` in that cluster. Seen genes in no module form the
# "other_genes" row. Genes are counted per phylostratum, using
# `PS_table_file` (no header; gene id, stratum).
#
# Every (cluster, stratum) cell gets a Fisher exact test against the rest of
# the table, with BH adjustment. Results go to two PNG files in the working
# directory:
#   - "Phylostratigraphy_cell_modules_heatmap.png": log2 enrichment over the
#     background stratum frequencies, with '*' where q < 0.01.
#   - "Pylostratigraphy_stratiphiedFC.png": boxplots of per-gene max log2 FC
#     by stratum.
#
# Arguments:
#   footprint : genes x clusters matrix; defaults to sc_object@clust_fp with
#               columns in the global `cluster_order` (matched by name).
#   PS_order  : stratum names in display order.
# Returns the pairwise.wilcox.test (one-sided "greater", unadjusted) of
# per-gene max log2 footprint between strata.
scr_phylostratigraphy_cell_modules <- function(sc_object, footprint = NULL, PS_order, PS_table_file,
                                               T_totumis_seen = 10, fc_threshold = 2) {
  if (is.null(footprint)) {
    footprint <- sc_object@clust_fp[, as.character(cluster_order), drop = FALSE]
  }
  umis <- sc_object@scmat@mat
  seen <- intersect(rownames(footprint), rownames(umis))
  genes <- names(which(rowSums(as.matrix(umis[seen, , drop = FALSE])) > T_totumis_seen))
  message("N seen genes: ", length(genes))

  PS_table <- read.table(PS_table_file, header = FALSE, row.names = 1)
  colnames(PS_table) <- "PS"
  ps_levels <- as.character(sort(unique(PS_table$PS)))
  # Genes per stratum. Strata with no genes count as 0 rather than NA, so the
  # Fisher tests and row sums below stay defined.
  count_ps <- function(ids) {
    as.vector(table(factor(PS_table[ids, "PS"], levels = ps_levels)))
  }

  df <- matrix(nrow = ncol(footprint), ncol = length(ps_levels),
               dimnames = list(colnames(footprint), ps_levels))
  niche_genes <- c()  # genes in any cluster module; the rest are the background
  for (niche in colnames(footprint)) {
    module_genes <- unique(genes[which(footprint[genes, niche] > fc_threshold)])
    df[niche, ] <- count_ps(module_genes)
    niche_genes <- union(niche_genes, module_genes)
  }
  df <- rbind(df, count_ps(setdiff(genes, niche_genes)))
  rownames(df) <- c(colnames(footprint), "other_genes")
  df <- df[, PS_order, drop = FALSE]  # custom stratum order ("Eukaryota", ...)

  # 2x2 Fisher test per (row, stratum): this row/stratum vs everything else.
  mat_ps_chisq <- matrix(NA_real_, nrow = nrow(df), ncol = ncol(df), dimnames = dimnames(df))
  for (niche in rownames(df)) {
    in_niche <- rownames(df) == niche
    for (ps in colnames(df)) {
      in_ps <- colnames(df) == ps
      contingency <- rbind(
        c(df[niche, ps], sum(df[in_niche, !in_ps])),
        c(sum(df[!in_niche, in_ps]), sum(df[!in_niche, !in_ps]))
      )
      dimnames(contingency) <- list(c(niche, "other niche"), c(ps, "other ps"))
      mat_ps_chisq[niche, ps] <- fisher.test(contingency)$p.value
    }
  }
  mat_ps_chisq[is.nan(mat_ps_chisq)] <- 1
  # BH over all tests, keeping row/column names so q-values are looked up by
  # name (not position) when the stars are drawn.
  mat_ps_chisq_qv <- matrix(p.adjust(mat_ps_chisq, method = "BH"),
                            nrow = nrow(mat_ps_chisq), dimnames = dimnames(mat_ps_chisq))

  df_freq <- df / rowSums(df)
  df_freq <- df_freq[!grepl("Unclass*", rownames(df_freq)), , drop = FALSE]
  bck_freq <- colSums(df) / sum(df)
  mat <- apply(df_freq, 1, function(x) log2(x / bck_freq))  # strata x clusters
  mat[which(!is.finite(mat))] <- -2

  png("Phylostratigraphy_cell_modules_heatmap.png", height = 1000, width = 4000)
  par(mar = c(0, 0, 0, 0))
  par(fig = c(0.3, 0.95, 0.3, 0.9))
  mat_to_plot <- pmax(pmin(mat, 1), -1)
  shades <- colorRampPalette(c("#A6611A", "#DFC27D", "white", "#80CDC1", "#018571"))(100)
  image(t(mat_to_plot), col = shades, xaxt = "n", yaxt = "n", zlim = c(-1, 1))

  n_ps <- nrow(mat)
  n_niche <- ncol(mat)
  ps_totals <- colSums(df)
  for (i in seq_len(n_ps)) {
    y_at <- (i - 1) / (n_ps - 1)
    mtext(rownames(mat)[i], side = 2, at = y_at, adj = 1, las = 2, line = 1, cex = 5)
    abline(h = (i - 0.5) / (n_ps - 1), lwd = 0.5)
    mtext(ps_totals[i], side = 4, at = y_at, adj = 0, las = 2, line = 1, cex = 3)
  }

  niche_totals <- rowSums(df[rownames(df_freq), , drop = FALSE])
  star_y <- seq(0, 1, length.out = n_ps)
  for (i in seq_len(n_niche)) {
    x_at <- (i - 1) / (n_niche - 1)
    mtext(colnames(mat)[i], side = 1, at = x_at, adj = 1, las = 2, line = 1, cex = 5)
    abline(v = (i - 0.5) / (n_niche - 1), lwd = 0.5)
    mtext(niche_totals[i], side = 3, at = x_at, adj = 0, las = 2, line = 1, cex = 3)
    qv <- mat_ps_chisq_qv[colnames(mat)[i], rownames(mat)]
    text(labels = ifelse(qv < 0.01, "*", ""), x = x_at, y = star_y, cex = 7, col = "gray22")
  }

  # Colour legend spans the same [-1, 1] range as the heatmap's zlim.
  par(fig = c(0.2, 0.3, 0.15, 0.2), new = TRUE)
  image(x = seq(-1, 1, length.out = length(shades)),
        y = c(0, 1), col = shades, yaxt = "n",
        z = matrix(nrow = length(shades), ncol = 1, data = seq_along(shades)),
        ylab = "", xlab = "", cex.axis = 1)
  mtext("log2 FC", side = 1, line = 5, cex = 3)
  dev.off()

  png("Pylostratigraphy_stratiphiedFC.png", height = 1500, width = 1500)
  tmp_fp_n <- footprint
  par(mar = c(1, 30, 15, 1))
  boxplot(apply(log2(tmp_fp_n[genes, ]), 1, max) ~ ordered(as.factor(PS_table[genes, ]), levels = PS_order),
          las = 2, border = "black", outcex = 2, pch = 20, horizontal = TRUE, col = "grey", cex.axis = 5,
          outcol = alpha("darkgrey", 0.6), boxwex = 0.5, xaxt = "n", ylim = c(0, 3), cex.names = 6,
          lwd = 3, frame = FALSE, cex = 6)
  axis(side = 3, at = 0:5, labels = 0:5, cex.axis = 4, mgp = c(3, 3, 0))
  mtext(text = "log2 Max FC", side = 3, line = 8, cex = 5)
  dev.off()

  pairwise.wilcox.test(apply(log2(footprint[genes, ]), 1, max),
                       ordered(as.factor(PS_table[genes, ]), levels = PS_order),
                       p.adjust.method = "none", alternative = "greater")
}


# Write one Excel sheet per cluster ("C<cluster>") listing the genes whose
# footprint exceeds `rat_threshold` in that cluster, in decreasing order.
# Each row holds total UMIs, UMIs in the cluster, the cluster's UMI percentage,
# log2 footprint and the matching row of `gene_annot_file` (tab separated,
# with header, gene ids in column 1).
#
# Arguments:
#   cluster_list : clusters to export; defaults to all footprint columns.
#   excel_fn     : output workbook; any existing file is removed first.
# Returns the value of xlsx::saveWorkbook().
scr_export_cluster_annotation <- function(sc_object, cluster_list = NULL, rat_threshold = 2,
                                          excel_fn = "Cluster_annotation.xlsx",
                                          gene_annot_file) {
  # NOTE(review): rJava reads java.parameters when the JVM starts. xlsx is
  # attached at the top of this file, so this may have no effect here.
  options(java.parameters = "-Xmx8000m")

  # Remove any previous workbook without shelling out (unlink is silent if absent).
  unlink(excel_fn)

  footprint_table <- sc_object@clust_fp
  if (is.null(cluster_list)) {
    cluster_list <- colnames(footprint_table)
  }

  annot <- read.table(gene_annot_file, header = TRUE, sep = "\t", fill = TRUE, quote = "", row.names = 1)
  wb <- createWorkbook()

  niche_counts <- .row_stats_by_factor(sc_object@scmat@mat[rownames(sc_object@clust_fp), ],
                                       sc_object@clusts, rowSums)

  for (cluster in cluster_list) {
    fp <- footprint_table[, cluster]
    ranked <- fp[order(fp, decreasing = TRUE)]
    genes <- intersect(names(which(ranked > rat_threshold)), rownames(niche_counts))

    # drop = FALSE and positional alignment keep single-gene clusters working.
    counts <- niche_counts[genes, , drop = FALSE]
    total_umis <- rowSums(counts)
    niche_umis <- counts[, cluster]
    gene_niche_umifrac <- niche_umis * 100 / total_umis

    annot_sign_genes <- cbind.data.frame(total_umis, niche_umis, gene_niche_umifrac,
                                         log2(footprint_table[genes, cluster]), annot[genes, ])
    colnames(annot_sign_genes) <- c("Total umis", "Niche umis", "niche.umi.fraction(%)", "log2FC", colnames(annot))

    jgc()

    sheet <- createSheet(wb, sheetName = paste0("C", cluster))
    addDataFrame(annot_sign_genes, sheet)
  }

  saveWorkbook(wb, excel_fn)
}

# Run garbage collection in R, then ask the Java VM to collect as well.
# scr_export_cluster_annotation calls this between sheets while building an
# xlsx workbook (Java-backed via rJava); presumably to keep the Java heap small.
# Returns whatever .jcall returns for System.gc().
jgc <- function() {
  invisible(gc())
  .jcall(obj = "java/lang/System", method = "gc")
}



# Cluster transcription factors (or any gene list) into co-expression modules
# and draw a combined figure to `output_file`. The figure has two panels:
#   - top: TF-TF Pearson correlation of footprints, clipped to +/- cor_zlim;
#   - bottom: footprint heatmap capped at `pmin`.
#
# Genes come from `tf_file` (tab separated, no header; column 1 = label,
# column 3 = colour). A gene is kept with more than `T_totumi` UMIs and a
# maximum footprint above `min_FC`. Ward.D2 clustering on (1 - correlation) is
# cut into `nmodules` groups (at most one per gene). Singleton groups, and
# groups with mean internal correlation <= 0.02, are discarded.
#
# When `fuse_modules2` is TRUE, modules whose summed footprints correlate above
# `module_fuse_thrs` are merged greedily. Within a round, each module takes
# part in at most one merge; rounds repeat until no pair exceeds the threshold.
# Modules are then ordered by the cluster of their peak mean footprint, and
# genes within a module by the cluster of their peak footprint.
#
# Reads the global `cluster_order`. When `PS_table_file` is given and
# `plot_gene_names` is FALSE, it also needs `phylo_col` and
# `phylo_col_simplified`: named colour vectors keyed by phylostratum, expected
# in an enclosing (e.g. global) environment. `fuse_modules` is accepted for
# backward compatibility and is not used.
# Returns the value of the final dev.off().
scr_tf_tf_cor_footprint <- function(footprint, sc_object, output_file = "TF_map.png", cor_zlim = 0.79, tf_file,
                                    T_totumi = 50, min_FC = 2.5, h = 5000, w = 4000, plot_gene_names = TRUE,
                                    nmodules = 80,
                                    PS_table_file = NULL,
                                    PS_order = c("Eukaryota", "Holozoa", "Metazoa", "Porifera", "Demospongia", "Amphimedon"),
                                    fuse_modules = FALSE, module_fuse_thrs = 0.7, fuse_modules2 = FALSE, pmin = 2.5) {
  tfs <- read.table(tf_file, header = FALSE, sep = "\t", fill = TRUE, quote = "", row.names = 1)
  if (!is.null(PS_table_file)) {
    PS_table <- read.table(PS_table_file, header = FALSE, row.names = 1)
    colnames(PS_table) <- "PS"
  }

  umis <- sc_object@scmat@mat
  seen <- intersect(rownames(tfs), rownames(umis))
  f_tf_cov <- names(which(rowSums(as.matrix(umis[seen, , drop = FALSE])) > T_totumi))
  f_tf_cov <- intersect(f_tf_cov, rownames(footprint))
  f_tf_min_fc <- names(which(apply(footprint[f_tf_cov, , drop = FALSE], 1, max) > min_FC))

  message(paste0("N TFs seen with ", T_totumi, " coverage and min_FC over", min_FC, " : ", length(f_tf_min_fc)))
  tf_fp <- footprint[f_tf_min_fc, cluster_order, drop = FALSE]

  tf_tf_corr <- cor(t(tf_fp))
  tmp_cor <- tf_tf_corr
  diag(tmp_cor) <- 0
  message("Max cor: ", max(tmp_cor))
  message("Min cor: ", min(tf_tf_corr))

  message("I will compute hierarchical clustering of genes based on correlation...")
  x_cor <- tf_tf_corr
  x_clust <- hclust(as.dist(1 - x_cor), method = "ward.D2")
  n_modules <- min(nmodules, nrow(x_cor))  # cutree cannot make more groups than genes
  cls <- cutree(x_clust, n_modules)
  message("Finished clustering genes...")

  ## Keep multi-gene modules with mean internal correlation above 0.02
  modules <- list()
  for (k in seq_len(n_modules)) {
    members <- names(which(cls == k))
    if (length(members) == 1) next
    block_cor <- x_cor[members, members]
    if (mean(block_cor[lower.tri(block_cor)]) > 0.02) {
      modules[[length(modules) + 1]] <- members
    }
  }
  if (length(modules) == 0) {
    stop("No gene module passed the internal-correlation filter", call. = FALSE)
  }
  names(modules) <- seq_along(modules)

  if (fuse_modules2) {  # merge modules whose summed footprints correlate above module_fuse_thrs
    repeat {
      module_fps <- t(vapply(modules, function(g) colSums(footprint[g, , drop = FALSE]),
                             numeric(ncol(footprint))))
      module_cor <- cor(t(module_fps))
      module_cor[is.na(module_cor)] <- 0
      module_cor[upper.tri(module_cor, diag = TRUE)] <- 0
      max_cor <- max(module_cor)
      message(paste0("Now the max cor is: ", max_cor))
      if (max_cor <= module_fuse_thrs) {
        message("I stop because max_cor is already below threshold")
        break
      }
      repeat {  # merge pairs in decreasing order of correlation
        v <- which(module_cor == max(module_cor), arr.ind = TRUE)[1, ]
        pair_cor <- module_cor[v[1], v[2]]
        if (pair_cor <= module_fuse_thrs) break
        rown <- rownames(module_cor)[v[1]]
        coln <- colnames(module_cor)[v[2]]
        modules[[rown]] <- c(modules[[rown]], modules[[coln]])
        modules <- modules[names(modules) != coln]
        # Kill both clusters for the rest of this round (rows and columns).
        module_cor[c(v[1], v[2]), ] <- 0
        module_cor[, c(v[1], v[2])] <- 0
      }
    }
  }
  names(modules) <- seq_along(modules)

  ## Reorder modules by the cluster of maximal average expression
  mod_means <- t(vapply(modules,
                        function(g) apply(footprint[g, cluster_order, drop = FALSE], 2, mean),
                        numeric(length(cluster_order))))
  modules <- modules[order(apply(mod_means, 1, which.max))]

  ## Sort genes WITHIN modules by the cluster where each gene peaks
  modules <- lapply(modules, function(g) {
    peak <- apply(footprint[g, cluster_order, drop = FALSE], 1, function(x) which.max(rollmean(x, 1)))
    g[order(peak)]
  })
  nms <- unlist(modules)

  tf_tf_corr_to_plot <- pmax(pmin(tf_tf_corr[nms, nms], cor_zlim), -cor_zlim)
  png(output_file, height = h, width = w)
  par(mar = c(0, 0, 0, 0))

  ########### Top panel, tf-tf corr
  par(fig = c(0.05, 0.8, 0.36, 0.98))
  cor_shades <- colorRampPalette(c("darkblue", "blue", "white", "gold", "brown"))(1000)
  image(tf_tf_corr_to_plot, col = cor_shades, zlim = c(-cor_zlim, cor_zlim), xaxt = "n", yaxt = "n")
  gene_at <- seq(0, 1, length.out = length(nms))
  gene_cols <- as.character(tfs[nms, 3])
  if (plot_gene_names) {
    mtext(paste(nms, tfs[nms, 1], sep = "_"), side = 4, at = gene_at, col = gene_cols,
          las = 1, adj = 0, cex = 1.8, line = 0.5)
    mtext(nms, side = 3, at = gene_at, las = 2, cex = 2, col = gene_cols, line = 0.5)
  }

  show_ps <- !is.null(PS_table_file) && !plot_gene_names
  if (show_ps) {
    gene_ps <- as.character(PS_table[nms, ])
    strip <- as.matrix(seq_along(nms))
    par(fig = c(0.805, 0.82, 0.36, 0.98), new = TRUE)
    image(t(strip), col = phylo_col[gene_ps], axes = FALSE, xaxt = "n", yaxt = "n", xaxs = "i", yaxs = "i")
    par(fig = c(0.825, 0.84, 0.36, 0.98), new = TRUE)
    image(t(strip), col = phylo_col_simplified[gene_ps], axes = FALSE, xaxt = "n", yaxt = "n", xaxs = "i", yaxs = "i")

    par(fig = c(0.05, 0.8, 0.07, 0.09), new = TRUE)
    image(strip, col = phylo_col[gene_ps], axes = FALSE, xaxt = "n", yaxt = "n", xaxs = "i", yaxs = "i")
    par(fig = c(0.05, 0.8, 0.04, 0.06), new = TRUE)
    image(strip, col = phylo_col_simplified[gene_ps], axes = FALSE, xaxt = "n", yaxt = "n", xaxs = "i", yaxs = "i")
  }

  ########### Bottom panel, footprints
  par(fig = c(0.05, 0.8, 0.1, 0.34), new = TRUE)
  shades2 <- colorRampPalette(c("white", "white", "white", "orange", "red", "purple", "black"))(1000)
  # tf_fp already has its columns in cluster_order; do not index them again.
  tf_fp_nms <- tf_fp[nms, , drop = FALSE]
  tf_fp_to_plot <- pmin(tf_fp_nms, pmin)
  print(quantile(tf_fp_nms, seq(0, 1, by = 0.1)))
  image(tf_fp_to_plot, xaxt = "n", yaxt = "n", col = shades2)
  mtext("Cell niches", side = 4, line = 8, cex = 6)
  big_k <- length(cluster_order)
  for (i in seq_len(big_k)) {
    mtext(cluster_order[i], side = 4, at = (i - 1) / (big_k - 1), adj = 0, las = 2, line = 2, cex = 1.5)
  }

  all_labs2 <- tfs[nms, 1]
  n_lab <- length(all_labs2)
  for (i in seq_len(n_lab)) {
    x_at <- (i - 1) / (n_lab - 1)
    if (plot_gene_names) {
      mtext(all_labs2[i], side = 1, at = x_at, adj = 1, las = 2, line = 1, cex = 1.6, col = gene_cols[i])
    }
    if (show_ps) {
      mtext(as.character(PS_table[nms[i], 1]), side = 1, at = x_at, adj = 1, las = 2, line = 1, cex = 1.6)
    }
    abline(v = (i - 0.5) / (n_lab - 1), lwd = 0.5)
  }

  ## Expression legend
  par(fig = c(0.85, 0.95, 0.05, 0.065), new = TRUE)
  image(x = seq(0, max(tf_fp_to_plot), length.out = length(shades2)),
        y = c(0, 1), col = shades2, yaxt = "n",
        z = matrix(nrow = length(shades2), ncol = 1, data = seq_along(shades2)),
        cex.axis = 2)
  mtext("normalized\nlog expression", side = 1, line = 5, cex = 2)
  dev.off()
}


# Plot cluster marker genes twice.
#   1. output_file: log2(footprint + 1) heatmap capped at 5, genes x clusters.
#   2. <output_file>_sc_cells.png: per-cell profile of the same genes.
#
# Per-cell profile: log2(1 + 7 * UMIs per 800) minus the gene's median,
# floored at 0, smoothed with a 5-cell rolling mean and capped at 4.
#
# Gene selection, when `gene_list` is NULL:
#   - the top `per_clust_genes` genes per cluster with footprint > gene_min_fold;
#   - minus `black_list`;
#   - minus "transversal" genes, whose transverality_N-th largest footprint
#     (ignoring clusters in transv_excluded_niches) exceeds 1.4.
# Otherwise `gene_list` is used as given.
#
# When `clust_ord` is empty, clusters are ordered by Ward.D2 clustering on
# (1 - Pearson correlation) of the selected footprints. That order is then:
#   - written to "tmp_cell_clusts_ordered_by_scr_markers_plot.txt";
#   - plotted as a tree to <output_file>_TREE.png;
#   - assigned to the global `scr_tmp_cluster_order`.
# Genes are ordered by the cluster of their peak footprint.
#
# `gene_annot_file` is tab separated, with header, gene ids in column 1.
# Its first two columns are drawn as gene labels.
# `clust_col`, when given, is a colour vector named by cluster.
# Returns the value of the final dev.off().
scr_plot_cmod_markers <- function(sccl_object, black_list = c(), output_file, clust_ord = c(), height = 8000, width = 3000,
                                  clust_col = NULL, per_clust_genes = 20, gene_min_fold = 3, gene_annot_file, gene_list = NULL,
                                  transverality_N = ncol(sccl_object@clust_fp), transv_excluded_niches = NULL) {
  annot <- read.table(gene_annot_file, header = TRUE, sep = "\t", fill = TRUE, quote = "", row.names = 1)
  niche_geomean_n <- sccl_object@clust_fp

  if (is.null(gene_list)) {
    top_per_cluster <- apply(niche_geomean_n, 2,
                             function(x) names(head(sort(-x[x > gene_min_fold]), n = per_clust_genes)))
    genes <- setdiff(unique(as.vector(unlist(top_per_cluster))), black_list)
    kept_niches <- setdiff(as.character(colnames(niche_geomean_n)), transv_excluded_niches)
    transversal_genes <- names(which(apply(niche_geomean_n[, kept_niches, drop = FALSE], 1,
                                           function(x) sort(x, decreasing = TRUE)[transverality_N] > 1.4)))
    genes <- setdiff(genes, transversal_genes)
  } else {
    genes <- gene_list
  }

  genes <- intersect(genes, rownames(niche_geomean_n))
  message("Will use ", length(genes), " genes")
  mat_niche <- niche_geomean_n[genes, , drop = FALSE]

  if (length(clust_ord) == 0) {
    message("recomputing cell ord")
    hc1 <- hclust(as.dist(1 - cor(mat_niche, method = "pearson")), "ward.D2")
    # hc1$order holds column positions; translate them to cluster names.
    clust_ord <- colnames(mat_niche)[hc1$order]
    write.table(clust_ord, file = "tmp_cell_clusts_ordered_by_scr_markers_plot.txt",
                quote = FALSE, col.names = FALSE, row.names = FALSE)
    png(paste0(output_file, "_TREE.png"), height = 500, width = 1000)
    plot(hc1, xlab = "", xaxt = "n", hang = -1, ylab = "", main = "", cex = 1)
    dev.off()
    scr_tmp_cluster_order <<- clust_ord
  }
  clust_ord_chr <- as.character(clust_ord)

  gene_ord <- order(apply(mat_niche[, clust_ord_chr, drop = FALSE], 1,
                          function(x) which.max(rollmean(x, 1))))
  ordered_genes <- genes[gene_ord]

  ################## CLUSTER-LEVEL HEATMAP ########################
  png(output_file, height = height, width = width)
  par(mar = c(0, 0, 0, 0))
  par(fig = c(0.25, 0.75, 0.1, 0.9))
  shades <- colorRampPalette(c("white", "white", "orange", "red", "purple", "black"))(1000)
  image(t(pmin(log2(niche_geomean_n[ordered_genes, clust_ord_chr, drop = FALSE] + 1), 5)),
        col = shades, xaxt = "n", yaxt = "n")
  gene_at <- seq(0, 1, length.out = length(ordered_genes))
  mtext(annot[ordered_genes, 2], side = 4, at = gene_at, las = 1, line = 1)
  mtext(paste(annot[ordered_genes, 1], ordered_genes, sep = "||"), side = 2, at = gene_at, las = 1, line = 1)
  # One label per plotted cluster (clust_ord may be a subset of the columns).
  clust_at <- seq(0, 1, length.out = length(clust_ord))
  mtext(clust_ord, side = 1, at = clust_at, las = 2, line = 1, cex = 2)
  mtext(clust_ord, side = 3, at = clust_at, las = 2, line = 1, cex = 2)
  if (!is.null(clust_col)) {
    par(fig = c(0.2, 0.90, 0.03, 0.06), new = TRUE)
    image(as.matrix(seq_along(clust_ord)), col = clust_col[clust_ord_chr], axes = FALSE, xaxt = "n", yaxt = "n")
  }
  dev.off()

  print(length(genes))

  ################## PLOT SINGLE-CELL PROFILE ########################
  cell_order <- unlist(lapply(clust_ord, function(niche) {
    names(sccl_object@clusts[which(sccl_object@clusts == niche)])
  }), use.names = FALSE)
  n_cells_cluster <- as.matrix(table(sccl_object@clusts))[clust_ord_chr, 1]

  umis <- as.matrix(sccl_object@scmat@mat)
  mat <- umis[genes, cell_order, drop = FALSE]
  totu <- colSums(umis[, cell_order, drop = FALSE])
  mat <- t(t(mat) / totu) * 800

  lus_1 <- log2(1 + 7 * mat[ordered_genes, cell_order, drop = FALSE])
  # Subtract each gene's median and floor at 0 (pmax keeps the matrix shape).
  lus <- pmax(lus_1 - apply(lus_1, 1, median), 0)
  lus_smoo <- t(apply(lus[ordered_genes, cell_order, drop = FALSE], 1, function(x) rollmean(x, 5, fill = 0)))

  png(paste0(output_file, "_sc_cells.png"), height = height, width = width * 2)
  par(mar = c(0, 0, 0, 0))
  par(fig = c(0.2, 0.90, 0.1, 0.9))
  image(t(pmin(lus_smoo, 4)), col = shades, xaxt = "n", yaxt = "n")
  n_total <- sum(n_cells_cluster)
  offset <- 0
  for (i in seq_along(n_cells_cluster)) {
    boundary <- offset + n_cells_cluster[i]
    # Cell j sits at (j - 1) / (n_total - 1); the gap after cell m is at (m - 0.5) / (n_total - 1).
    abline(v = (boundary - 0.5) / (n_total - 1), lwd = 2)
    label_at <- (n_cells_cluster[i] / 1.5 + offset) / (n_total - 1)
    mtext(clust_ord[i], side = 1, at = label_at, adj = 1, las = 2, line = 1, cex = 4)
    mtext(clust_ord[i], side = 3, at = label_at, las = 2, line = 1, cex = 4)
    offset <- boundary
  }

  if (!is.null(clust_col)) {
    par(fig = c(0.2, 0.90, 0.05, 0.08), new = TRUE)
    image(as.matrix(seq_along(clust_ord)), col = clust_col[clust_ord_chr], axes = FALSE, xaxt = "n", yaxt = "n")
  }
  dev.off()
}


# Plot the bootstrap co-clustering frequency matrix to "Bootstrap_matrix.png".
#
# Cells are grouped by cluster in `cluster_order`. Each value is
# bootstrap_list$coclust / bootstrap_list$num_trials, capped at 0.4. Both
# matrices are indexed by cell name. Dotted lines and labels mark where each
# cluster ends.
#
# `cluster_order` defaults to the global scr_sn_table$niches.
# Returns the value of the final dev.off().
scr_plot_bootstrap_matrix <- function(bootstrap_list, sc_object, cluster_order = as.character(scr_sn_table$niches)) {
  cell_order <- unlist(lapply(cluster_order, function(niche) {
    names(sc_object@clusts[which(sc_object@clusts == niche)])
  }), use.names = FALSE)

  bootstraps <- bootstrap_list$coclust[cell_order, cell_order]
  num_trials <- bootstrap_list$num_trials[cell_order, cell_order]

  boot_colorscale <- colorRampPalette(c("white", brewer.pal(name = "YlGnBu", n = 9), "navyblue", "black"))(1000)
  # Pairs never sampled together give 0/0 = NaN (drawn as empty).
  boot_to_plot <- pmin(bootstraps / num_trials, 0.4)

  png("Bootstrap_matrix.png", height = 2000, width = 2000)
  par(mar = c(0, 0, 0, 0))
  par(fig = c(0.1, 0.9, 0.1, 0.9))
  image(boot_to_plot, col = boot_colorscale, xaxt = "n", yaxt = "n")

  # Cluster sizes looked up by cluster name, consistent with cell_order above.
  cluster_sizes <- as.numeric(table(sc_object@clusts)[as.character(cluster_order)])
  n_total <- sum(cluster_sizes)
  # Cell j sits at (j - 1) / (n_total - 1); the gap after cell m is at (m - 0.5) / (n_total - 1).
  boundary_at <- (cumsum(cluster_sizes) - 0.5) / (n_total - 1)
  mtext(cluster_order, side = 3, at = boundary_at, las = 2, line = 0.5, cex = 2)
  mtext(cluster_order, side = 1, at = boundary_at, las = 2, line = 0.5, cex = 2)
  abline(v = boundary_at, col = "grey", lwd = 2, lty = "dotted")
  abline(h = boundary_at, col = "grey", lwd = 2, lty = "dotted")

  ### plot legend
  par(fig = c(0.1, 0.2, 0.05, 0.075), new = TRUE)
  legend_max <- max(boot_to_plot, na.rm = TRUE)
  image(x = seq(0, legend_max, length.out = length(boot_colorscale)), y = c(0, 1),
        col = boot_colorscale, yaxt = "n",
        z = matrix(nrow = length(boot_colorscale), ncol = 1, data = seq_along(boot_colorscale)),
        ylab = "", xlab = "", cex.axis = 1)
  dev.off()
}


# Merge clusters whose cells are frequently co-clustered across bootstraps.
#
# Similarity of clusters k and l is the summed co-clustering count of their
# cells over the summed co-sampling count (boot_list$coclust /
# boot_list$num_trials, aggregated per cluster pair). Pairs never co-sampled
# are treated as 0.
#
# In each round, pairs above `cl_fuse_threshold` are merged greedily in
# decreasing order. A cluster takes part in at most one merge per round, and
# a merged cluster's similarity to the others is the max of its parts.
# Rounds repeat until no pair exceeds the threshold.
#
# Assumes sc_object@clusts holds integer ids 1..K and that the columns of
# sc_object@clust_fp are named "1".."K".
# Returns a copy of sc_object with renumbered clusts, post-processed by
# .reorder_knn_clusts() and .scc_postprocess(), and an empty knn_ordered.
scr_merge_clusters_by_bootstrap <- function(sc_object, boot_list, cl_fuse_threshold = 0.2) {
  clusts <- sc_object@clusts
  coclust_ij <- boot_list$coclust
  cosample_ij <- boot_list$num_trials
  # Align the cell x cell matrices with the clusts vector when they carry names.
  if (!is.null(rownames(coclust_ij)) && !is.null(colnames(coclust_ij))) {
    coclust_ij <- coclust_ij[names(clusts), names(clusts)]
  }
  if (!is.null(rownames(cosample_ij)) && !is.null(colnames(cosample_ij))) {
    cosample_ij <- cosample_ij[names(clusts), names(clusts)]
  }

  isclust_ci <- diag(max(clusts))[, clusts]  # cluster x cell indicator matrix
  coclust_kl <- isclust_ci %*% coclust_ij %*% t(isclust_ci)
  cosamp_kl <- isclust_ci %*% cosample_ij %*% t(isclust_ci)
  # Probability that two cells of the two clusters are co-clustered in a bootstrap.
  sim <- as.matrix(coclust_kl / cosamp_kl)
  sim[is.na(sim)] <- 0
  sim <- pmax(sim, t(sim))  # work with a symmetric similarity
  colnames(sim) <- rownames(sim) <- seq_len(ncol(sim))

  clust_list <- list()
  for (niche in colnames(sc_object@clust_fp)) {
    clust_list[[niche]] <- names(which(clusts == niche))
  }

  n_round <- 0
  repeat {
    if (n_round > 0) {
      message("I'm going for round ", n_round + 1, " !!")
    }
    low <- sim
    low[upper.tri(low, diag = TRUE)] <- 0
    max_boot <- max(low)
    message(paste0("Now the max_boot is: ", max_boot))
    if (max_boot <= cl_fuse_threshold) {
      message("I stop because max_boot is already below threshold")
      break
    }

    avail <- low  # candidate pairs for this round
    repeat {
      v <- which(avail == max(avail), arr.ind = TRUE)[1, ]
      boot <- avail[v[1], v[2]]
      if (boot <= cl_fuse_threshold) break
      rown <- rownames(avail)[v[1]]
      coln <- colnames(avail)[v[2]]
      message("....mergin...", rown, "..and...", coln, " which have a similarity of ", round(boot, 4))

      clust_list[[rown]] <- c(clust_list[[rown]], clust_list[[coln]])
      clust_list <- clust_list[setdiff(names(clust_list), coln)]

      # The merged cluster keeps rown's name; its similarity to every other
      # cluster is the max over the two parts (both row and column updated).
      merged <- pmax(sim[rown, ], sim[coln, ])
      sim[rown, ] <- merged
      sim[, rown] <- merged
      sim[rown, rown] <- 0
      keep <- setdiff(colnames(sim), coln)
      sim <- sim[keep, keep, drop = FALSE]

      # Kill both clusters for the rest of this round.
      avail[c(rown, coln), ] <- 0
      avail[, c(rown, coln)] <- 0
    }
    n_round <- n_round + 1
  }

  # Renumber merged clusters 1..n and order the assignment by cell name.
  new_clusts <- unlist(lapply(seq_along(clust_list), function(k) {
    ids <- clust_list[[k]]
    setNames(rep(k, length(ids)), ids)
  }))
  new_clusts_2 <- new_clusts[sort(names(new_clusts))]

  fused_cl <- sc_object
  fused_cl@scmat <- sc_object@scmat
  fused_cl@feat_mat <- sc_object@feat_mat
  fused_cl@clusts <- new_clusts_2
  fused_cl@clusts <- .reorder_knn_clusts(fused_cl@feat_mat, fused_cl@clusts)
  fused_cl <- .scc_postprocess(fused_cl)
  fused_cl@knn_ordered <- matrix(nrow = 0, ncol = 0)

  fused_cl
}

# Per-cluster quality summary based on each cluster's own marker genes.
#
# For every cluster (column of sc_object@clust_fp):
#   good genes : up to 50 genes with the highest footprint above fc_threshold.
#   Frac_cells_over_quant* : fraction of the cluster's cells whose summed
#       normalised (per 800 UMIs) good-gene expression exceeds the 99th
#       percentile of the same score computed in:
#         - all other cells;
#         - all cells outside the cluster and its closest neighbour (suffix 1);
#         - all cells outside the cluster and its two closest neighbours (suffix 2).
#       Neighbours are the clusters with the highest mean good-gene footprint.
#   N_genes / Median_umis_good_genes : marker genes (footprint > fc_threshold)
#       with more UMIs inside the cluster than outside. Their count and the
#       median per-cell UMI sum, reported only when there are at least two
#       such genes (otherwise 0).
#   N_cells : number of cells in the cluster.
# `plot_boxplots` is accepted for backward compatibility and is not used.
# Returns a 6 x n_clusters numeric matrix.
scr_find_bad_clusters <- function(sc_object, fc_threshold = 2, plot_boxplots = FALSE) {
  fp <- sc_object@clust_fp
  genes <- unique(as.vector(unlist(apply(fp, 2, function(x) names(which(x > fc_threshold))))))
  niches <- colnames(fp)
  all_cells <- sc_object@scmat@cells

  mat <- as.matrix(sc_object@scmat@mat)
  totu <- colSums(sc_object@scmat@mat)
  mat_norm <- t(t(mat) / totu) * 800

  cells_of <- function(cl) names(sc_object@clusts[which(sc_object@clusts == cl)])

  # Fraction of `scores` above the 99th percentile of the good-gene score in
  # `background_cells`. drop = FALSE keeps single-gene / single-cell cases valid.
  frac_above_background <- function(scores, good_genes, background_cells) {
    background <- colSums(mat_norm[good_genes, background_cells, drop = FALSE])
    sum(scores > quantile(background, 0.99)) / length(scores)
  }

  Good_genes_matrix <- matrix(ncol = length(niches), nrow = 6,
                              dimnames = list(c("Frac_cells_over_quant", "Frac_cells_over_quant1",
                                                "Frac_cells_over_quant2", "N_genes",
                                                "Median_umis_good_genes", "N_cells"),
                                              niches))

  for (niche in niches) {
    # Index gene names directly; names() of a single selected element is NULL.
    sub_genes <- genes[which(fp[genes, niche] > fc_threshold)]
    fp_col <- setNames(fp[, niche], rownames(fp))
    fp_hits <- fp_col[which(fp_col > fc_threshold)]
    good_genes <- names(head(sort(-fp_hits), n = 50))

    niche_cells <- cells_of(niche)
    nonniche_cells <- setdiff(all_cells, niche_cells)

    mean_fc_good_genes <- apply(fp[good_genes, setdiff(niches, niche), drop = FALSE], 2, mean)
    neighbors <- names(sort(mean_fc_good_genes, decreasing = TRUE))
    niche_plus1neighbor_cells <- c(niche_cells, cells_of(neighbors[1]))
    niche_plus2neighbor_cells <- c(niche_plus1neighbor_cells, cells_of(neighbors[2]))

    umis_in <- rowSums(mat[sub_genes, niche_cells, drop = FALSE])
    umis_out <- rowSums(mat[sub_genes, nonniche_cells, drop = FALSE])
    genes_inside_over_outside <- names(which(umis_in > umis_out))

    niche_scores <- colSums(mat_norm[good_genes, niche_cells, drop = FALSE])
    Good_genes_matrix["Frac_cells_over_quant", niche] <-
      frac_above_background(niche_scores, good_genes, nonniche_cells)
    Good_genes_matrix["Frac_cells_over_quant1", niche] <-
      frac_above_background(niche_scores, good_genes, setdiff(all_cells, niche_plus1neighbor_cells))
    Good_genes_matrix["Frac_cells_over_quant2", niche] <-
      frac_above_background(niche_scores, good_genes, setdiff(all_cells, niche_plus2neighbor_cells))

    Good_genes_matrix["N_cells", niche] <- length(niche_cells)

    if (length(genes_inside_over_outside) > 1) {
      inside_counts <- colSums(mat[genes_inside_over_outside, niche_cells, drop = FALSE])
      Good_genes_matrix["N_genes", niche] <- length(genes_inside_over_outside)
      Good_genes_matrix["Median_umis_good_genes", niche] <- median(inside_counts)
    } else {
      Good_genes_matrix["N_genes", niche] <- 0
      Good_genes_matrix["Median_umis_good_genes", niche] <- 0
    }
  }

  Good_genes_matrix
}

# Shared implementation of the two scr_select_good_clusters* functions below.
# It keeps only the cells of `good_clusters` and renumbers their clusters
# 1..n in ascending order of the old ids. It optionally re-orders the clusters
# with .reorder_knn_clusts(), then runs .scc_postprocess().
# nclust is set to the number of clusters actually present, so ids in
# `good_clusters` with no cells are not counted.
.scr_subset_clusters <- function(scl, good_clusters, reorder_clusters) {
  good_cells <- names(scl@clusts)[scl@clusts %in% good_clusters]

  scl@scmat <- scm_sub_mat(scl@scmat, cells = good_cells)
  scl@feat_mat <- scl@feat_mat[, good_cells, drop = FALSE]
  scl@clusts <- setNames(as.integer(as.factor(scl@clusts[good_cells])), good_cells)
  if (reorder_clusters) {
    scl@clusts <- .reorder_knn_clusts(scl@feat_mat, scl@clusts)
  }
  scl <- .scc_postprocess(scl)
  scl@nclust <- length(unique(scl@clusts))
  scl@knn_ordered <- matrix(nrow = 0, ncol = 0)

  scl
}

# Subset a clustering object to `good_clusters`, renumbering the clusters and
# re-ordering them with .reorder_knn_clusts().
scr_select_good_clusters <- function(scl, good_clusters) {
  .scr_subset_clusters(scl, good_clusters, reorder_clusters = TRUE)
}

# Like scr_select_good_clusters, but skips the .reorder_knn_clusts() step.
# Cluster ids are still renumbered to 1..n, in ascending order of the old ids.
scr_select_good_clusters_not_renaming <- function(scl, good_clusters) {
  .scr_subset_clusters(scl, good_clusters, reorder_clusters = FALSE)
}



