#' Project a single-cell matrix onto marker gene sets
#'
#' Reads a marker table and, for every cell, sums the normalised expression
#' of the marker genes listed for each subtype. Expression is normalised per
#' cell by dividing each column of the matrix by its column total (taken over
#' all genes, not only the markers).
#'
#' @param mat An object whose `mat` slot holds a genes x cells expression
#'   matrix with gene names as row names.
#' @param markers_fn Path to a whitespace-delimited table with a header line
#'   and (at least) the columns `subtype` and `gene`.
#'
#' @return A list with two elements: `totmark_it`, a cells x subtypes matrix
#'   holding the summed normalised marker expression, and `rna_n`, the marker
#'   genes x cells matrix of normalised expression.
#'
#' @export
sc_proj_mat_on_marks <- function(mat, markers_fn)
{
	markers <- read.table(markers_fn, header = TRUE, stringsAsFactors = FALSE)

	# subtype x gene incidence table; an entry counts how many times the
	# (subtype, gene) pair is listed in the marker file
	istype_tg <- table(markers$subtype, markers$gene)

	good_genes <- intersect(rownames(mat@mat), colnames(istype_tg))
	if (length(good_genes) < 3) {
		stop(
			"found only ",
			length(good_genes),
			" marker genes represented in the matrix, terminating"
		)
	}

	# drop = FALSE keeps the subtype dimension when the marker file defines a
	# single subtype; without it the product below is non-conformable
	istype_tg <- as.matrix(istype_tg[, good_genes, drop = FALSE])

	rna <- as.matrix(mat@mat)
	rna_n <- t(t(rna) / colSums(rna))
	# drop = FALSE keeps a genes x 1 matrix when there is a single cell
	rna_n <- rna_n[good_genes, , drop = FALSE]

	totmark_it <- t(rna_n) %*% t(istype_tg)

	list(totmark_it = totmark_it, rna_n = rna_n)
}

#' Cluster cells by their projection on marker gene sets
#'
#' Projects `mat` on the marker sets in `markers_fn` (see
#' `sc_proj_mat_on_marks`), clusters the cells in that marker space (unless a
#' clustering is supplied through `totmark_cl`), and writes summary outputs
#' into `analysis_dir`:
#' \itemize{
#'   \item `subtype_clusts_<alg_type>.png` - heatmap of the normalised
#'     cluster centers stacked over bar plots of cluster sizes and of the
#'     per-cluster `Cell.type` composition;
#'   \item `subtype_centers_<alg_type>.txt` - cluster centers together with
#'     the largest cell-to-center distance observed in each cluster;
#'   \item `<alg_type>_<tab_clust_fp_fn>` and one
#'     `<alg_type>_<tab_clust_fp_fn>.<field>` table per metadata field (only
#'     when `tab_clust_fp_fn` is not `NA`).
#' }
#'
#' @param mat Single-cell matrix object; passed to `sc_proj_mat_on_marks` and
#'   to the clustering function.
#' @param markers_fn Path to the marker table (columns `subtype`, `gene`).
#' @param n_mark_clusts Passed as `k_knn` to `scc_cluster_knn_graph` or as `K`
#'   to `scc_cluster_kmeans`.
#' @param analysis_dir Directory that receives all output files.
#' @param alg_type Either `"knn"` or `"kmeans"`; also used in the output file
#'   names.
#' @param min_clust_size Passed to `scc_cluster_knn_graph` (knn only).
#' @param tab_clust_fp_fn File name used for the cluster footprint tables, or
#'   `NA` to skip writing them.
#' @param clust_fp_metadata_fields Cell metadata columns to cross-tabulate
#'   against the clusters; `NA` or `NULL` to skip.
#' @param ordered_subtypes Subtype names giving the heatmap row order (names
#'   absent from the marker table are ignored), or `NULL` to order subtypes
#'   and clusters by hierarchical clustering.
#' @param type_col Currently unused; kept for backward compatibility.
#' @param totmark_cl Optional precomputed clustering object (with `clusts`,
#'   `clust_fp` and `scmat` slots); when supplied no clustering is run.
#'
#' @return The clustering object (`totmark_cl`).
#'
#' @export
sc_marker_clusts <- function(mat,
	markers_fn,
	n_mark_clusts = 80,
	analysis_dir = getwd(),
	alg_type = "knn",
	min_clust_size = 20,
	tab_clust_fp_fn = "clust_fp.txt",
	clust_fp_metadata_fields = c("Patient", "Sample.Name", "Cell.type"),
	ordered_subtypes = c("ab_T_cell", "regulatory_CD4_T_cell", "CD4_T_cells", "CD8_T_cells", "gd_T_cell", "NK_cell", "cytotoxic_B_cells", "IgA_B_cell", "IgM_B_cell", "IgG_B_cell", "plasma_cells", "DC", "macrophage", "granulocytes", "erythrocyte", "osteoclast"),
	type_col = list("CD3" = "#e41a1c", "CD45" = "#377eb8", "CD3&CD45" = "#4daf4a"),
	totmark_cl = NULL)
{
	totmark_it <- sc_proj_mat_on_marks(mat, markers_fn)$totmark_it

	if (is.null(totmark_cl)) {
		if (alg_type == "knn") {
			totmark_cl <- scc_cluster_knn_graph(mat, t(totmark_it), k_knn = n_mark_clusts, min_clust_size = min_clust_size)
		} else if (alg_type == "kmeans") {
			totmark_cl <- scc_cluster_kmeans(mat, t(totmark_it), K = n_mark_clusts)
		} else {
			# previously fell through and failed later on a NULL slot access
			stop("unknown alg_type '", alg_type, "', expected \"knn\" or \"kmeans\"")
		}
	}

	clusts <- totmark_cl@clusts

	# clusters x subtypes mean projection, then scaled by the per-subtype mean
	centers <- t(.row_stats_by_factor(t(totmark_it), clusts, rowFunction = rowMeans))
	centers_norm <- t(t(centers) / colMeans(centers))

	if (!is.null(ordered_subtypes)) {
		ordered_subtypes <- intersect(ordered_subtypes, colnames(centers_norm))
		if (length(ordered_subtypes) == 0) {
			stop("none of ordered_subtypes is present in the marker table")
		}
		hc_marks <- match(ordered_subtypes, colnames(centers_norm))

		# order clusters by their strongest subtype (in the requested subtype
		# order), breaking ties by the strength of that enrichment
		ordered_norm <- centers_norm[, hc_marks, drop = FALSE]
		hc_cl <- order(apply(ordered_norm, 1, which.max) + 1e-3 * apply(ordered_norm, 1, max))
	} else {
		hc_cl <- hclust(dist(cor(t(centers_norm))), "ward.D2")$order
		hc_marks <- hclust(dist(t(centers_norm)), "ward.D2")$order
	}
	# heatmap row order; in the NULL case this is the hierarchical order
	# (levels = NULL used to turn every row label into NA)
	mark_levels <- colnames(centers_norm)[hc_marks]

	cn_m <- melt(pmin(centers_norm, 10), value.name = 'enr')

	# NOTE(review): hc_cl holds row positions of centers_norm while Var1 holds
	# its row names, so this relies on cluster ids being 1..K in row order
	cn_m$cl <- factor(cn_m$Var1, levels = hc_cl)
	cn_m$mark <- factor(cn_m$Var2, levels = mark_levels)

	tot_cl <- table(clusts)
	tot_m <- melt(tot_cl, value.name = 'count')
	tot_m$total <- "#cells"
	tot_m$cl <- factor(tot_m$clusts, levels = hc_cl)

	clust_md <- totmark_cl@scmat@cell_metadata[names(totmark_cl@clusts), , drop = FALSE]

	type_cl <- table(totmark_cl@clusts, clust_md[, 'Cell.type'])
	type_cl_n <- type_cl / rowSums(type_cl)

	type_m <- melt(type_cl_n, value.name = 'frac')
	type_m$cl <- factor(type_m$Var1, levels = hc_cl)

	# theme shared by the three stacked panels
	panel_theme <- theme(axis.text.x = element_text(size = 5, angle = 90, vjust = 0.5), axis.text.y = element_text(size = 8), legend.text = element_text(size = 8), legend.title = element_blank(), plot.margin = margin(0, 3, 1, 3))

	p_hm <- ggplot(cn_m, aes(x = factor(cl), y = factor(mark))) +
		geom_tile(aes(fill = enr)) +
		scale_fill_gradientn(colours = RColorBrewer::brewer.pal(n = 9, "YlOrRd")) +
		xlab('') + ylab('') + panel_theme

	p_tot_barplot <- ggplot(data = tot_m, aes(x = factor(cl), y = count, fill = total)) +
		geom_bar(stat = "identity") +
		xlab('') + ylab('') + panel_theme

	p_type_barplot <- ggplot(data = type_m, aes(x = factor(cl), y = frac, fill = Var2)) +
		geom_bar(stat = "identity") +
		xlab('') + ylab('') + panel_theme

	p_all <- plot_grid(p_hm, p_tot_barplot, p_type_barplot, nrow = 3, ncol = 1, align = 'v', rel_heights = c(4, 1, 1))

	# pass the grid explicitly instead of relying on ggplot2's last_plot()
	ggsave(sprintf("%s/subtype_clusts_%s.png", analysis_dir, alg_type), plot = p_all, width = 24, height = 16, units = "cm")

	dist_to_center <- sqrt(rowSums((totmark_it - centers[clusts, ]) ^ 2))
	max_center_dist <- tapply(dist_to_center, clusts, max)

	centers_data <- cbind(max_center_dist, centers)
	colnames(centers_data)[1] <- 'max_center_dist'
	write.table(
		centers_data[hc_cl, , drop = FALSE],
		sprintf("%s/subtype_centers_%s.txt", analysis_dir, alg_type),
		quote = FALSE,
		sep = "\t"
	)

	if (!is.null(tab_clust_fp_fn) && !is.na(tab_clust_fp_fn)) {
		f <- apply(totmark_cl@clust_fp, 1, max) > 1.5

		# NULL / NA mean "no metadata tables"; dropping NAs also makes the
		# check safe for NULL, where `||` on a zero-length value is an error
		md_fields <- clust_fp_metadata_fields[!is.na(clust_fp_metadata_fields)]
		for (s in md_fields) {
			write.table(table(clust_md[, s], totmark_cl@clusts), sprintf("%s/%s_%s.%s", analysis_dir, alg_type, tab_clust_fp_fn, s), sep = "\t", quote = FALSE)
		}

		write.table(
			round(totmark_cl@clust_fp[f, , drop = FALSE], 2),
			sprintf("%s/%s_%s", analysis_dir, alg_type, tab_clust_fp_fn),
			sep = "\t",
			quote = FALSE
		)
	}

	totmark_cl
}

#' Assign cells to types using reference marker-projection cluster centers
#'
#' Projects `mat` on the marker sets (see `sc_proj_mat_on_marks`), assigns
#' every cell to its nearest reference center, and labels the cell with that
#' center's type. Cells lying farther from their nearest center than the
#' maximal distance recorded for it are labelled `'None'`. The label is stored
#' in `mat@cell_metadata$assigned_type`.
#'
#' @param mat Single-cell matrix object with `mat` and `cell_metadata` slots.
#' @param markers_fn Path to the marker table (columns `subtype`, `gene`).
#' @param ref_centers_fn Tab-delimited table of reference centers: row names
#'   are numeric cluster ids, the first column is the maximal cell-to-center
#'   distance, the remaining columns are the per-subtype center coordinates.
#' @param ref_centers_type_assign_fn Tab-delimited table with numeric cluster
#'   ids as row names and a `type` column, one row per reference center.
#' @param fig_pref If not `NULL`, prefix for two PNG bar plots
#'   (`<fig_pref>_cell_type_comp.png`, `<fig_pref>_cells_per_ref_cl.png`).
#' @param out_base_dir If not `NULL`, a sub-directory per assigned type is
#'   created there holding the matrix and metadata of the cells of that type.
#'
#' @return `mat` with the `assigned_type` metadata column added.
#'
#' @export
sc_marker_split <- function(mat,
	markers_fn,
	ref_centers_fn,
	ref_centers_type_assign_fn,
	fig_pref = NULL,
	out_base_dir = NULL)
{
	totmark_it <- sc_proj_mat_on_marks(mat, markers_fn)$totmark_it

	ref_centers <- read.table(
		ref_centers_fn,
		sep = "\t",
		header = TRUE,
		stringsAsFactors = FALSE
	)
	ref_centers <- ref_centers[order(as.numeric(rownames(ref_centers))), , drop = FALSE]

	ref_centers_types <- read.table(
		ref_centers_type_assign_fn,
		sep = "\t",
		header = TRUE,
		stringsAsFactors = FALSE
	)
	ref_centers_types <- ref_centers_types[order(as.numeric(rownames(ref_centers_types))), 'type']

	cl_max_dist <- ref_centers[, 1]
	ref_centers <- as.matrix(ref_centers[, -1, drop = FALSE])
	n_ref <- nrow(ref_centers)

	if (length(ref_centers_types) != n_ref) {
		stop("expected one type per reference center, got ", length(ref_centers_types), " types for ", n_ref, " centers")
	}
	if (ncol(ref_centers) != ncol(totmark_it)) {
		stop("reference centers have ", ncol(ref_centers), " subtype columns but the marker projection has ", ncol(totmark_it))
	}
	# align subtype columns by name when the names agree; otherwise keep the
	# positional correspondence (read.table may have altered the names)
	if (!is.null(colnames(totmark_it)) &&
			anyDuplicated(colnames(totmark_it)) == 0 &&
			setequal(colnames(totmark_it), colnames(ref_centers))) {
		ref_centers <- ref_centers[, colnames(totmark_it), drop = FALSE]
	}

	dists <- as.matrix(pdist(totmark_it, ref_centers))
	clusts <- apply(dists, 1, which.min)
	dist_to_center <- sqrt(rowSums((totmark_it - ref_centers[clusts, , drop = FALSE]) ^ 2))
	valid_cells <- dist_to_center <= cl_max_dist[clusts]

	# rejected cells get the index right after the last reference center,
	# which is where 'None' is appended below. (max(clusts) + 1 pointed at a
	# real type whenever no cell was assigned to the last center.)
	clusts[!valid_cells] <- n_ref + 1

	ref_centers_types <- c(ref_centers_types, 'None')

	mat@cell_metadata$assigned_type <- ref_centers_types[clusts]

	if (!is.null(out_base_dir)) {
		# NOTE(review): assumes rows of cell_metadata are in the same order as
		# the columns of mat@mat - confirm against the matrix class
		ind <- split(seq_len(nrow(mat@cell_metadata)), mat@cell_metadata$assigned_type)
		for (st in names(ind)) {
			dir.create(sprintf("%s/%s", out_base_dir, st), showWarnings = FALSE, recursive = TRUE)
			write.table(as.matrix(mat@mat[, ind[[st]], drop = FALSE]), sprintf("%s/%s/%s_sc_mat.txt", out_base_dir, st, st), quote = FALSE, sep = "\t")
			write.table(mat@cell_metadata[ind[[st]], , drop = FALSE], sprintf("%s/%s/%s_sc_mat_md.txt", out_base_dir, st, st), quote = FALSE, sep = "\t")
		}
	}

	if (!is.null(fig_pref)) {
		tnames <- unique(ref_centers_types[-length(ref_centers_types)])
		n_types <- length(tnames)

		# Set1 offers 3 to 9 colours: clamp the request, interpolate when more
		# types are needed, and keep exactly one colour per type
		tcols <- RColorBrewer::brewer.pal(n = min(max(n_types, 3), 9), 'Set1')
		if (n_types > length(tcols)) {
			tcols <- colorRampPalette(tcols)(n_types)
		}
		tcols <- tcols[seq_len(n_types)]
		names(tcols) <- tnames

		# count through explicit factor levels so that empty types show up as
		# zero and 'None' (when present) is not undercounted
		assigned <- ref_centers_types[clusts]
		cell_comp <- table(factor(assigned, levels = sort(unique(c(assigned, tnames)))))

		tab <- table(factor(clusts, levels = seq_len(max(n_ref, clusts))))

		png(sprintf("%s_cell_type_comp.png", fig_pref),
				width = 800,
				height = 600)
		# close the device even if plotting fails
		tryCatch(
			barplot(cell_comp / sum(cell_comp),
							ylab = "Fraction",
							col = tcols[names(cell_comp)]),
			finally = dev.off()
		)

		png(sprintf("%s_cells_per_ref_cl.png", fig_pref),
				width = 800,
				height = 400)
		tryCatch({
			barplot(tab,
							ylab = "#cells",
							col = tcols[ref_centers_types],
							ylim = c(0, 1.1 * max(tab)))
			legend(
				"topleft",
				legend = names(tcols),
				fill = tcols,
				ncol = length(tcols),
				bty = 'n',
				cex = 0.9
			)
		}, finally = dev.off())
	}

	mat
}
