#!/usr/local/bin/perl -w
use strict;
use File::Basename;
use Getopt::Std;
my $PROGRAM = basename $0;
# Help text shown when the script is started with a terminal on STDIN.
# It lists every switch accepted by getopts('rRTlv') and the optional
# positional node index read from @ARGV.
my $USAGE=
"Usage: $PROGRAM [-rRTlv] [NODE_INDEX] < INPUT
-r: align using region
-R: align only region
-T: output rooted tree
-l: output tree with node labels and exit
-v: print leaves, clusters and species of each division to STDERR
NODE_INDEX: divide at this node instead of searching for the best division
";

use DomRefine::General;
use DomRefine::Read;
use DomRefine::Tree;

### Settings ###
# NOTE(review): nothing in the visible part of this script reads this
# file-scoped lexical -- confirm whether it is still needed.
my $SUPPORT_VALUE_THRESHOLD = 0.95;

# Command-line switches; all five are boolean flags (none takes a value).
my %OPT;
getopts('rRTlv', \%OPT);

# Names of the temporary work files, each derived from the program name.
# define_tmp_file() comes from an imported module; presumably it returns a
# path -- whether it also creates the file is not visible here.
my $TMP_INPUT = define_tmp_file("$PROGRAM.input");
my $TMP_ALIGNMENT = define_tmp_file("$PROGRAM.alignment");
# NOTE(review): $TMP_ALIGNMENT_ERR is defined and removed but never passed to
# anything in the visible script.
my $TMP_ALIGNMENT_ERR = define_tmp_file("$PROGRAM.alignment.err");
my $TMP_TREE = define_tmp_file("$PROGRAM.ph");
my $TMP_TREE_LOG = define_tmp_file("$PROGRAM.tree_log");
my $TMP_DIVIDE = define_tmp_file("$PROGRAM.divide");
# END blocks run when the interpreter exits, including after die() or exit(),
# so the work files are cleaned up on every exit path.
END {
    remove_tmp_file($TMP_INPUT);
    remove_tmp_file($TMP_ALIGNMENT);
    remove_tmp_file($TMP_ALIGNMENT_ERR);
    remove_tmp_file($TMP_TREE);
    remove_tmp_file($TMP_TREE_LOG);
    remove_tmp_file($TMP_DIVIDE);
}

### Main ###
# A bare -t tests STDIN: if it is a terminal (nothing piped in), show the
# usage text and stop.
-t and die $USAGE;
# Copy STDIN into $TMP_INPUT; the return value is handed to get_rep_cluster()
# later (presumably the input text itself -- confirm in DomRefine).
my $DCLST = save_stdin($TMP_INPUT);

# Filled in by get_dclst_structure(). %DOMAIN is later read as
# $DOMAIN{gene}{domain}{cluster} by add_cluster_number().
my %CLUSTER = ();
my %DOMAIN = ();
get_dclst_structure($TMP_INPUT, \%CLUSTER, \%DOMAIN);

# Build the alignment file; -r / -R are forwarded as the region / REGION
# named options.
my @gene = ();
create_alignment($TMP_INPUT, \@gene, $TMP_ALIGNMENT, region => $OPT{r}, REGION => $OPT{R});
# Rewrite the alignment headers in place so each carries its cluster id(s).
add_cluster_number($TMP_INPUT, \%DOMAIN, $TMP_ALIGNMENT);

# Infer a tree from the alignment into $TMP_TREE (log goes to $TMP_TREE_LOG).
phylogenetic_tree($TMP_ALIGNMENT, $TMP_TREE, $TMP_TREE_LOG);

# @NODES holds every descendant of the root; positions in this array are the
# node indices used by find_best_division() and by the optional command-line
# argument.
my $TREE = create_tree($TMP_TREE);
my @NODES = $TREE->get_root_node->get_all_Descendents;

# Saved here and written back by put_support_values() in the -T branch,
# after the tree has been rerooted.
my @SUPPORT_VALUE;
save_support_values(\$TREE, \@SUPPORT_VALUE);

# -l: print the tree with node labels and stop; no division is attempted.
if ($OPT{l}) {
    put_node_labels(\$TREE);
    print_tree(\$TREE);
    exit;
}

### Set point to divide
# The node index is taken from the first command-line argument when one is
# given; otherwise find_best_division() chooses it.
my $IDX;
if (@ARGV) {
    $IDX = $ARGV[0];
    # Reject anything that is not a valid position in @NODES here, instead of
    # failing later with a method call on an undefined node.
    if ($IDX !~ /^\d+$/ or $IDX >= @NODES) {
        die "$PROGRAM: invalid node index '$IDX' (the tree has "
            . scalar(@NODES) . " nodes, numbered from 0)\n";
    }
} else {
    $IDX = find_best_division();
    if (! defined $IDX) {
        print STDERR "cannot divide\n";
        exit;
    }
    print STDERR "best_i=$IDX\n";
}
if ($OPT{T}) {
    # output tree, rerooted at the chosen node, with the saved support values
    $TREE->reroot_at_midpoint($NODES[$IDX]);
    put_support_values(\$TREE, \@SUPPORT_VALUE);
    print_tree(\$TREE);
} else {
    # members of a sub-cluster: the leaves of the second subtree after
    # rerooting at the chosen node
    my ($diff, $r_leaves1, $r_leaves2, $detail) = check_domain_pattern_diff($TMP_TREE, $IDX);
    my %sub_cluster;
    for my $leaf_name (@{$r_leaves2}) {
        # Leaf names look like "<clusters>_<species>_<gene>"; input lines are
        # matched on "<species>:<gene>".
        if ($leaf_name =~ /^(\S+?)_([A-Za-z0-9]+?)_(\S+)$/) {
            my ($sp, $gene) = ($2, $3);
            $sub_cluster{"$sp:$gene"} = 1;
        } else {
            die $leaf_name;
        }
    }
    # output sub-clusters: every input line gets the representative cluster
    # id, suffixed "m" for sub-cluster members and "M" for all other lines
    my $rep_cluster = get_rep_cluster($DCLST);
    open(my $input_fh, '<', $TMP_INPUT)
        or die "$PROGRAM: cannot open $TMP_INPUT: $!\n";
    while (my $line = <$input_fh>) {
        my ($cluster, $gene, $domain) = split ' ', $line;
        # A blank line has no gene column; it is printed unchanged because
        # the substitution below cannot match it.
        my $suffix = (defined $gene and $sub_cluster{$gene}) ? 'm' : 'M';
        $line =~ s/^\S+/${rep_cluster}${suffix}/;
        print $line;
    }
    close($input_fh);
}

################################################################################
### Function ###################################################################
################################################################################

# find_best_division()
#
# Evaluates a division of the tree at every node in @NODES whose branch
# length is positive, writes one report line per candidate to $TMP_DIVIDE,
# sorts those lines with the external sort(1) command and returns the node
# index of the first sorted line (undef when there is no candidate).
# All ranked candidate lines are echoed to STDERR.
#
# Reads the file-level @NODES, $TMP_TREE and $TMP_DIVIDE.
# Dies when the report file cannot be written or the sort command fails.
sub find_best_division {

    my $i_best;

    print STDERR "finding best division..\n";
    my @height = ();
    open(my $divide_fh, '>', $TMP_DIVIDE) or die "cannot write $TMP_DIVIDE: $!";
    for my $i (0 .. $#NODES) {
        # Nodes without a positive branch length are not candidates.
        next unless $NODES[$i]->branch_length > 0;
        my $node1 = $NODES[$i]->internal_id;
        my $node2 = $NODES[$i]->ancestor->internal_id;
        my ($diff, $r_leaves1, $r_leaves2, $detail, $height) = check_domain_pattern_diff($TMP_TREE, $i, $node1, $node2);
        print {$divide_fh} $detail;
        push @height, $height;
    }
    # The file is read by sort right below, so a failed flush must not go
    # unnoticed.
    close($divide_fh) or die "cannot close $TMP_DIVIDE: $!";
    # Only needed by the disabled relative-height report further down.
    my $min_height = min(@height);
    print STDERR "done.\n";

    # my @line = `cat $TMP_DIVIDE | sort -t '\t' -k4,4 -k5,5gr -k3,3r`; # h,o_sp,root_len
    my @line = `cat $TMP_DIVIDE | sort -t '\t' -k2,2nr -k4,4r -k5,5 -k6,6gr`; # diff,root_len,h,o_sp
    $? == 0 or die "sorting $TMP_DIVIDE failed (wait status $?)";
    for my $line (@line) {
        if ($line =~ /^i=(\d+).*?diff=\s+\d+/) {
            # The first line in sort order is the best division.
            $i_best = $1 unless defined $i_best;
            print STDERR $line;
        }
        # if ($line =~ /^i=(\d+).*b=(\S+?),\sl= (\S+), .*l1= (\S+), l2= (\S+), .*.log2.l1.l2..= (\S+),.*h= (\S+), .*o_sp=.*?=\s*(\S+?)\s*, o_sp_part=\s*(.*?),\s/) {
        #     my ($i, $support_value, $len, $len1, $len2, $log_ratio_len, $height, $sp_overlap, $sp_overlap_part) = ($1, $2, $3, $4, $5, $6, $7, $8, $9);
        #     my $h_relative = 0;
        #     if ($min_height) {
        #         $h_relative = $height/$min_height;
        #     }
        #     $h_relative = sprintf("%.5f", $h_relative);
        #     $line =~ s/, n= /, h_rel= $h_relative, n= /;
        #     if (! defined $i_best) {
        #         $i_best = $i;
        #     }
        #     print STDERR $line;
        # }
    }

    return $i_best;
}

# add_cluster_number($tmp_input, $r_domain, $tmp_alignment)
#
# Rewrites the file $tmp_alignment in place: every header line ">GENE..."
# becomes ">CLUSTERS_GENE...", where CLUSTERS is the sorted, de-duplicated,
# "-"-joined list of the {cluster} values found under $r_domain->{GENE}.
# All other lines are copied unchanged.
#
#   $tmp_input      - not used; kept so existing callers need no change
#   $r_domain       - hash ref read as $r_domain->{gene}{domain}{cluster}
#   $tmp_alignment  - path of the alignment file to rewrite
#
# Dies with the offending header line when a gene has no entry in $r_domain,
# and when the file cannot be read or written.
sub add_cluster_number {
    my ($tmp_input, $r_domain, $tmp_alignment) = @_;

    # Read the file directly rather than through `cat`, so that a read
    # failure is reported instead of silently producing an empty alignment.
    open(my $in_fh, '<', $tmp_alignment) or die "cannot read $tmp_alignment: $!";
    my @alignment = <$in_fh>;
    close($in_fh);
    chomp(@alignment);

    for my $line (@alignment) {    # $line aliases the array element
        next unless $line =~ /^>(\S+)/;
        my $gene = $1;
        my $r_gene_domains = ${$r_domain}{$gene} or die $line;
        my @cluster = map { $_->{cluster} } values %{$r_gene_domains};
        my $clusters = join("-", sort {$a cmp $b} uniq(@cluster));
        $line =~ s/^>(\S+)/>${clusters}_$1/;
    }

    open(my $out_fh, '>', $tmp_alignment) or die "cannot write $tmp_alignment: $!";
    print {$out_fh} join("\n", @alignment), "\n";
    close($out_fh) or die "cannot close $tmp_alignment: $!";
}


# check_domain_pattern_diff($tree_file, $i, $node1, $node2)
#
# Re-reads the tree from $tree_file, reroots it with reroot_at_midpoint() on
# the node at position $i of the root's descendant list, and compares the two
# subtrees below the new root.
#
#   $tree_file      - tree file, parsed by create_tree()
#   $i              - position of the node in get_all_Descendents of the root
#   $node1, $node2  - optional ids, only echoed in the report line
#
# Returns a list:
#   $diff       - score from domain_pattern_diff() for the two sides
#   \@leaves1   - leaf names of the first subtree
#   \@leaves2   - leaf names of the second subtree
#   $detail     - one tab-separated report line (newline-terminated)
#   $min_height - (height1 + height2 + branch length)/2, raised to the larger
#                 subtree height when it falls below it
#
# Dies unless the leaf names carry exactly two distinct cluster ids.
sub check_domain_pattern_diff {
    my ($tree_file, $i, $node1, $node2) = @_;

    my $tree = create_tree($tree_file);
    my @nodes = $tree->get_root_node->get_all_Descendents;
    my $n_seq = grep {$_->is_Leaf} @nodes;
    # Read the branch length before rerooting changes the tree.
    my $branch_length = $nodes[$i]->branch_length;
    $tree->move_id_to_bootstrap;
    my $boot = $nodes[$i]->bootstrap || "";

    # calculation
    $tree->reroot_at_midpoint($nodes[$i]);
    my $root_node = $tree->get_root_node;
    my ($sub_tree1_node, $sub_tree2_node) = $root_node->each_Descendent;
    my @leaves1 = get_sub_tree_leaves($sub_tree1_node);
    my @leaves2 = get_sub_tree_leaves($sub_tree2_node);
    my @clusters1 = extract_clusters_from_leaves(@leaves1);
    my @clusters2 = extract_clusters_from_leaves(@leaves2);
    my @species1 = extract_species_from_leaves(@leaves1);
    my @species2 = extract_species_from_leaves(@leaves2);
    if ($OPT{v}) {
        print STDERR "leaves1 @leaves1\n";
        print STDERR "leaves2 @leaves2\n";
        print STDERR "clusters1 @clusters1\n";
        print STDERR "clusters2 @clusters2\n";
        print STDERR "species1 @species1\n";
        print STDERR "species2 @species2\n";
    }
    my @all_clusters = uniq_clusters(@clusters1, @clusters2);
    if (@all_clusters != 2) {
        die "expected exactly 2 clusters, found " . scalar(@all_clusters) . ": @all_clusters";
    }
    my ($cluster1, $cluster2) = @all_clusters;
    my ($diff, $diff1, $diff2, $diff3) = domain_pattern_diff($cluster1, $cluster2, \@clusters1, \@clusters2);
    my @all_species = uniq(@species1, @species2);
    # presumably the species present on both sides -- confirm check_redundancy()
    my @common_species = check_redundancy(@species1, @species2);

    my $sum_of_duplication = 0;
    my $sum_of_sp_disappeared = 0;
    # ($sum_of_duplication, $sum_of_sp_disappeared) = sum_of_duplication(\$tree);

    # print
    $node1 ||= "";
    $node2 ||= "";
    my $height = sprintf("%.5f", $root_node->height);
    my @len1 = get_sub_tree_branch_length($sub_tree1_node);
    my @len2 = get_sub_tree_branch_length($sub_tree2_node);
    my $len1 = mean(@len1) || 0;
    my $len2 = mean(@len2) || 0;
    $len1 = sprintf("%.5f", $len1);
    $len2 = sprintf("%.5f", $len2);
    my $len12 = mean(@len1, @len2);
    my $len_relative = 0;
    if ($len12) {
        $len_relative = $branch_length / $len12;
    }
    $len_relative = sprintf("%.5f", $len_relative);
    my $log_ratio_len = "";
    # $len1 and $len2 are formatted strings by now, and "0.00000" is a TRUE
    # string in Perl, so the test must be numeric to avoid dividing by zero.
    # The ratio must also be positive, because log() dies on zero or negative
    # input. In both cases the field is left empty.
    if ($len2 != 0 and $len1 / $len2 > 0) {
        $log_ratio_len = log($len1/$len2)/log(2);
        $log_ratio_len = abs($log_ratio_len);
        $log_ratio_len = sprintf("%.5f", $log_ratio_len);
    }
    my $height1 = sprintf("%.5f", $sub_tree1_node->height);
    my $height2 = sprintf("%.5f", $sub_tree2_node->height);
    my $n1 = scalar(@leaves1);
    my $n2 = scalar(@leaves2);
    my $n_sp = scalar(@all_species);
    my $n_sp_common = scalar(@common_species);
    my $sp_overlap = 0;
    if (@all_species) {
        $sp_overlap = scalar(@common_species) / scalar(@all_species);
    }
    $sp_overlap = sprintf("%.5f", $sp_overlap);
    my $sp_overlap1 = 0;
    if (@species1) {
        $sp_overlap1 =  @common_species/@species1;
    }
    my $sp_overlap2 = 0;
    if (@species2) {
        $sp_overlap2 =  @common_species/@species2;
    }
    my $max_sp_overlap_part = max($sp_overlap1, $sp_overlap2);

    # Half of the path through the new root, but never less than the taller
    # of the two subtrees.
    my $min_height = ($height1 + $height2 + $branch_length) / 2;
    if ($min_height < $height1 and $min_height < $height2) {
        die "inconsistent heights at i=$i: h1=$height1, h2=$height2, l=$branch_length";
    } elsif ($min_height < $height1) {
        $min_height = $height1;
    } elsif ($min_height < $height2) {
        $min_height = $height2;
    }

    # The tab positions matter: find_best_division() sorts on these fields.
    my $detail = "i=$i($node1,$node2), diff=\t$diff($diff1,$diff2,$diff3)\tb=$boot,\tl= $branch_length"
        . ", l_rel= $len_relative"
        . ", l1= $len1, l2= $len2"
        . ", |log2(l1/l2)|= $log_ratio_len,"
        . "\th= $min_height"
        # . "\th= $height"
        # . ", h1= $height1, h2= $height2"
        . ", n= $n_seq = $n1 + $n2"
        . ", o_sp=$n_sp_common/$n_sp=\t$sp_overlap\t, o_sp_part=\t$max_sp_overlap_part"
        # . ", n_dup=$sum_of_duplication, n_dis=$sum_of_sp_disappeared"
        . "\n";

    return $diff, \@leaves1, \@leaves2, $detail
        # , $height
        , $min_height
        ;
}

# domain_pattern_diff($cluster1, $cluster2, \@clusters1, \@clusters2)
#
# Scores how differently two lists of cluster labels are composed.
# For each list the "balance" is the number of labels equal to $cluster1
# minus the number equal to $cluster2 (a label equal to both counts once,
# as $cluster1). The "fusion difference" is the number of labels containing
# "-" in the first list minus that number in the second list.
#
# Returns (|balance2 - balance1| + |fusion difference|,
#          balance1, balance2, fusion difference).
sub domain_pattern_diff {
    my ($cluster1, $cluster2, $r_clusters1, $r_clusters2) = @_;

    my (@balance, @n_fusion);
    for my $r_side ($r_clusters1, $r_clusters2) {
        my $n_first  = grep { $_ eq $cluster1 } @{$r_side};
        my $n_second = grep { $_ ne $cluster1 and $_ eq $cluster2 } @{$r_side};
        push @balance, $n_first - $n_second;
        push @n_fusion, scalar(grep { $_ =~ /-/ } @{$r_side});
    }

    my $fusion_diff = $n_fusion[0] - $n_fusion[1];
    my $total = abs($balance[1] - $balance[0]) + abs($fusion_diff);

    return ($total, $balance[0], $balance[1], $fusion_diff);
}

# extract_clusters_from_leaves(@leaves)
#
# Returns, in input order, the leading cluster part of every leaf name, i.e.
# the text before the first "_<alphanumerics>_" in the name.
# Dies when a name does not have that shape.
sub extract_clusters_from_leaves {
    my @leaves = @_;

    my @cluster = map {
        $_ =~ /^(\S+?)_([A-Za-z0-9]+?)_/ or die;
        $1;
    } @leaves;

    return @cluster;
}

# extract_species_from_leaves(@leaves)
#
# Returns the distinct species codes found in the leaf names, in no
# particular order. The species code is the alphanumeric field between the
# first two "_" separators recognised by the pattern below.
# Dies when a name does not have that shape.
sub extract_species_from_leaves {
    my @leaves = @_;

    my %seen_species;
    for my $name (@leaves) {
        die unless $name =~ /^(\S+?)_([A-Za-z0-9]+?)_/;
        $seen_species{$2} = 1;
    }

    return keys %seen_species;
}

# uniq_clusters(@clusters)
#
# Splits every label on "-" and returns the distinct parts, in no particular
# order.
sub uniq_clusters {
    my (@clusters) = @_;

    my %seen;
    $seen{$_} = 1 for map { split("-", $_) } @clusters;

    return keys %seen;
}
