#include <iostream>
#include <cmath>
#include <algorithm>
#include <cfloat>
#include "HomData.h"
#include "BinQueue.hpp"
#include "DistMat.h"
#include "Options.h"

//class NodeNbrList
/*
 * Predicate used when purging a neighbor list.
 * Answers true for entries whose PairDist reports isDeleted(). As a side
 * effect it calls declRef() on every such entry (the reference-count
 * decrement for the list slot that is about to disappear).
 */
struct delCheck {
	bool operator()(PairDist* p) {
		if (! p->isDeleted()) {
			return false;
		}
		p->declRef();
		return true;
	}
};
/*
struct delObj {
	void operator()(PairDist* p) {
		p->declRef();
		cout << "DD:" << *p << " " << p->getRefcount()<<endl;
	}
};
*/
/*
 * Drop every entry flagged as deleted from this neighbor list.
 * delCheck calls declRef() on each entry it reports as deleted, so the
 * tail range left behind by remove_if only has to be erased.
 */
void NodeNbrList::checkDeleted() {
	nbrList.erase(
		std::remove_if(nbrList.begin(), nbrList.end(), delCheck()),
		nbrList.end());
}

//class DistMat
// Sentinel values: the constructor uses them as the initial minDist/maxDist
// and as the "no cutoff set" value of minCut/maxCut; definedDist() treats a
// value equal to either one as undefined.
const double DistMat::BIG_VALUE = DBL_MAX;
const double DistMat::SMALL_VALUE = -DBL_MAX;
// Tolerance ratio (>= 1) used by checkUltraMetricity() and correctDistance().
double DistMat::UM_TOL_RATIO = 5.0;

/*
 * Build a distance matrix on top of an existing vector of pair distances.
 *   d     : vector of pair distances; kept as homSet (the destructor deletes
 *           the vector object itself)
 *   _best : kept as simBest; handed to the bin queue when it is created
 *   scl   : kept as distscale; handed to the bin queue when it is created
 * Records the smallest and largest getSimValue() seen and registers every
 * pair in nodeIndex. The bin queue is not created here.
 */
DistMat::DistMat(vector<PairDist*>* d, SimValue::Best _best, int scl)
	: homSet(d), minDist(BIG_VALUE), maxDist(SMALL_VALUE),
	minCut(SMALL_VALUE), maxCut(BIG_VALUE), distscale(scl)
{
	simBest = _best;
	binQ = NULL;	// built later by createBinQueue()

	vector<PairDist*>::iterator cur = homSet->begin();
	while (cur != homSet->end()) {
		PairDist *pd = *cur;
		const double val = pd->getSimValue();
		maxDist = std::max(maxDist, val);
		minDist = std::min(minDist, val);
		nodeIndex.addData(pd);
		++cur;
	}
}
/*
 * Free the homSet vector object and the bin queue.
 * The PairDist objects referenced from homSet are not deleted here (an
 * element-wise delete loop existed at this spot but is disabled).
 */
DistMat::~DistMat() {
	delete homSet;
	delete binQ;	// no-op when the queue was never created (NULL)
}

// Number of pair distances currently stored in homSet.
size_t DistMat::size() {
	return homSet->size();
}
/*
 * Store the lower (min) and upper (max) cutoff values.
 * createBinQueue() reads them; a value equal to SMALL_VALUE or BIG_VALUE
 * counts as "not set" there (see definedDist()).
 */
void DistMat::setDistCutoff(double min, double max) {
	maxCut = max;
	minCut = min;
}
/*
 * Register one more pair distance: index it by node, append it to homSet
 * and, when the bin queue already exists, enqueue it under its
 * getSimValue().
 */
void DistMat::add(PairDist* pdist) {
	nodeIndex.addData(pdist);
	homSet->push_back(pdist);
	if (binQ == NULL) {
		// queue not built yet; createBinQueue() enqueues all of homSet
		return;
	}
	binQ->add(pdist->getSimValue(), pdist);
}
/*
void DistMat::add(Domain *dom1, Domain *dom2, DistData *dist) {
	node1 = ClusterNode::createLeafNode(dom1);
	node2 = ClusterNode::createLeafNode(dom2);
	pdist = PairDist::newInstance(node1, node2, dist);
	homSet->push_back(h);
}
*/
void DistMat::sort() {
	std::sort(homSet->begin(), homSet->end(), PairDist::cmpPairDist());
}
// Finish set-up after all initial pairs are loaded: convert the node index
// from its temporary per-node vectors to neighbor lists, then build the bin
// queue from homSet.
void DistMat::createIndices() {
	nodeIndex.convert();
	createBinQueue();
}
void DistMat::createBinQueue() {
	double mind = ( definedDist(minCut) ? minCut :
		(definedDist(minDist) ? minDist : 0) );
	double maxd = ( definedDist(maxCut) ? maxCut :
		(definedDist(maxDist) ? maxDist : 0) );
	binQ = new BinQueue<PairDist *>((int)floor(mind), (int)ceil(maxd),
			(BinQueue<PairDist*>::Best) simBest, distscale);
//cout << "mind=" << mind << " " << "maxd=" << maxd << endl;
	if (definedDist(minCut)) {
		binQ->setLowCutFlag();
	}
	if (definedDist(maxCut)) {
		binQ->setHighCutFlag();
	}

	for (vector<PairDist*>::iterator p = homSet->begin();
			p != homSet->end(); p++) {
		double d = (*p)->getSimValue();
		binQ->add(d, *p);
	}
}
void DistMat::correctDistance(SpecSetInstances *spSetInst, int corrMode) {
	map< pair<int, int>, PairDist* > distMap;
	for (vector<PairDist*>::iterator p = homSet->begin();
			p != homSet->end(); p++) {
		int nid1 = (*p)->getNode(0)->getID();
		int nid2 = (*p)->getNode(1)->getID();
		if (nid1 > nid2) {
			int tmp = nid1;
			nid1 = nid2; nid2 = tmp;
		}
//cout << "P: " << nid1 << " " << nid2 << endl;
		pair<int,int> id_pair = make_pair(nid1, nid2);
		distMap.insert( make_pair(id_pair, *p) );
	}
	for (int i = 0; i < nodeIndex.size(); i++) {
		ClusterNode *n1 = ClusterNode::getNode(i);
		string spec1 = GeneName::getSpec(n1->getName());
		bool part1 = spSetInst->checkPartial(spec1);
		if (! part1) {
			continue;
		}
		vector<DistData*>* p = nodeIndex.getTmpIdx(i);
		if (p == NULL) {
			continue;
		}
		std::sort(p->begin(), p->end(), DistData::cmpDistDataByDist());

		int distnum = (*p).size();;
		vector<double> max_r(distnum, 1.0); // max ratio
/*
		vector<double> sum_log_r(distnum, 0); // average ratio
		vector<double> num_r(distnum, 0); // average ratio
*/

		for (int j = 0; j < distnum; j++) {
			PairDist* d_12 = (*p)[j]->getPairDist();
			ClusterNode *n2;
			if (d_12->getNode(0)->getID() == i) {
				n2 = d_12->getNode(1);
			} else {
				n2 = d_12->getNode(0);
			}
			string spec2 = GeneName::getSpec(n2->getName());
			bool part2 = spSetInst->checkPartial(spec2);
			if (part2) {
				continue;
			}
			int id2 = n2->getID();
/*
			double sum_log_r = 0; //avarage ratio
			double num_r = 0;
*/

			for (int k = j+1; k < (*p).size(); k++) {
				PairDist* d_13 = (*p)[k]->getPairDist();
				ClusterNode *n3;
				if (d_12->getNode(0)->getID() == i) {
					n3 = d_13->getNode(1);
				} else {
					n3 = d_13->getNode(0);
				}
				int id3 = n3->getID();
				string spec3 = GeneName::getSpec(n3->getName());
				bool part3 = spSetInst->checkPartial(spec3);
				if (part3) {
					continue;
				}

				int i2 = id2, i3 = id3;
				if (i2 > i3) {
					int tmp = i2; i2 = i3; i3 = tmp;
				}
//cout << "P0: " << i2 << " " << i3 << endl;
				map< pair<int,int>, PairDist* >::iterator p = distMap.find( make_pair(i2, i3) );
				PairDist* d_23 = NULL;
				if (p != distMap.end()) {
					d_23 = p->second;
				} else {
					// skip if d_23 is not defined
					continue;
				}
				PairDist *od1, *od2, *od3;
				orderTriDist(d_12, d_13, d_23, od1, od2, od3);
				double rr = 1; //mod ratio
				if (od3 == d_23) {
					double mod_d1 = checkTriangle(od1, od2, od3);
					if (mod_d1 > 0) {
					    if (DEBUG::debug_flag) {
						cout << "TriError: " << *d_12 << endl;
						cout << "OrigDist:" << d_12->getDist() << " "
							<< d_13->getDist() << " " << d_23->getDist() << endl;
					    }
						rr = od3->getDist() / (od1->getDist() + od2->getDist());

						// partition ratio is set as distance ratio (with pseudo count 1)
/*
						double r1 = (od1->getDist() + 1) / (od1->getDist() + od2->getDist() + 2);
						od1->setDist( od1->getDist() + mod_d1 * r1 );
						od2->setDist( od2->getDist() + mod_d1 * (1-r1) );
					    if (DEBUG::debug_flag ) {
						cout << "Mod1_A: " << mod_d1 << 
							": " << d_12->getDist() << " " << d_13->getDist() << " "
								<< d_23->getDist() << endl;
					    }
*/
					}
					double mod_d2 = checkUltraMetricity(od1, od2, od3);
					if (mod_d2 > 0) {
						/* find x s.t. dist set (xa, xb, c) satisifies relaxed UMetric condition;
							solution is: x = (r+1) c / ( (r-1)a + (r+1)b ) */
						double x = (UM_TOL_RATIO + 1) * od3->getDist() /
						    ((UM_TOL_RATIO-1) * od1->getDist() + (UM_TOL_RATIO+1) * od2->getDist());
						if (x > rr) {
							rr = x;
					    		if (DEBUG::debug_flag ) {
								cout << "UMError: " << *d_12 << "//" << *d_13 << endl;
								cout << "OrigDist:" << d_12->getDist() << " "
									<< d_13->getDist() << " " << d_23->getDist() << endl;
								cout << "x=" << x << endl;
							}
						}

/*
						od1->setDist( od1->getDist() * x );
						od2->setDist( od2->getDist() * x );
					    if (DEBUG::debug_flag ) {
						cout << "Mod2: " << mod_d2 << 
							": " << d_12->getDist() << " " << d_13->getDist() << " "
								<< d_23->getDist() << endl;
					    }
*/
					}

					if (rr > max_r[j]) {
						max_r[j] = rr;
					}
					if (rr > max_r[k]) {
						max_r[k] = rr;
					}
/*
					sum_log_r[j] += log(rr);
					sum_log_r[k] += log(rr);
					num_r[j]++;
					num_r[k]++;
*/

				}
				
/*
if (part1 || part2) {
			cout << i << " " << j << " ";
			cout << n1->getID() << " " << n2->getID() << " ";
			cout << spec1 << " " << part1 << " ";
			cout << spec2 << " " << part2 << " ";
			cout << d_12->getNode(0)->getName() << " ";
			cout << d_12->getNode(1)->getName() << " ";
			cout << d_12->getDist() << endl;
}
*/
			}
		}
		for (int j = 0; j < distnum; j++) {
//			if (sum_log_r[j] > 0 && num_r[j] > 0) {
//				double ave_r = exp( sum_log_r[j] / num_r[j] );

				double ave_r = max_r[j];
				PairDist* d_12 = (*p)[j]->getPairDist();
				double dist0 = d_12->getDist();
				d_12->setDist( dist0 * ave_r );
				if (DEBUG::debug_flag) {
					cout << "Ave_r: " <<ave_r << endl;
					cout << "Mod1_12: " << dist0 << " >> " << d_12->getDist() << ": " << *d_12 << endl;
				}
//			}
		}
	}
}
/*
 * Order three pair distances by getDist().
 *   d1, d2, d3    : the distances to order
 *   od1, od2, od3 : outputs; receive the same three pointers arranged so
 *                   that od1 has the smallest and od3 the largest distance
 * Each input pointer appears exactly once in the output.
 */
void DistMat::orderTriDist(PairDist *d1, PairDist *d2, PairDist *d3,
		PairDist*& od1, PairDist*& od2, PairDist*& od3) {
	// order d1 and d2 first
	if (d1->getDist() < d2->getDist()) {
		od1 = d1; od2 = d2;
	} else {
		od1 = d2; od2 = d1;
	}
	// then slot d3 into its place
	if (od1->getDist() > d3->getDist()) {
		od3 = od2; od2 = od1; od1 = d3;
	} else if (od2->getDist() < d3->getDist()) {
		od3 = d3;
	} else {
		// d3 lies between the two: the larger of d1/d2 (currently od2)
		// becomes the largest. Assigning d2 here would lose d1 and
		// duplicate d2 whenever d1 >= d2.
		od3 = od2; od2 = d3;
	}
}
/*
 * Test the triangle inequality d1 + d2 >= d3.
 * Returns the shortfall d3 - (d1 + d2) when the inequality is violated,
 * and 0 when it holds.
 */
double DistMat::checkTriangle(PairDist *d1, PairDist *d2, PairDist *d3) {
	const double shortSum = d1->getDist() + d2->getDist();
	const double longSide = d3->getDist();
	return (shortSum < longSide) ? (longSide - shortSum) : 0.0;
}
/*
 * Test the relaxed ultrametric condition on a triple of distances.
 * With a = (d12 + d13 - d23) / 2 and b = (d12 + d23 - d13) / 2, the
 * condition is violated when b exceeds a * UM_TOL_RATIO; the return value
 * is then b / UM_TOL_RATIO - a. Returns 0 when there is no violation.
 */
double DistMat::checkUltraMetricity(PairDist *d12, PairDist *d13, PairDist *d23) {
	const double len12 = d12->getDist();
	const double len13 = d13->getDist();
	const double len23 = d23->getDist();

	const double a = (len12 + len13 - len23) / 2;
	const double b = (len12 + len23 - len13) / 2;

	// violation of the ultrametricity beyond the tolerance ratio UM_TOL_RATIO (>=1)
	return (b > a * UM_TOL_RATIO) ? (b / UM_TOL_RATIO - a) : 0.0;
}

/*
 * Return the neighbor list stored in nodeIndex under the node's ID.
 * Entries flagged as deleted are purged from the list before it is
 * returned.
 */
NodeNbrList *DistMat::findNeighbors(ClusterNode* node) {
	NodeNbrList *neighbors = nodeIndex[node->getID()];
	neighbors->checkDeleted();
	return neighbors;
}
// Forward to the bin queue: bestData is passed through by reference and the
// queue's return value is returned unchanged.
// NOTE(review): binQ is dereferenced without a NULL check, so this presumably
// must only be called after createBinQueue() -- confirm against callers.
int DistMat::getBestData(PairDist*& bestData) {
	return binQ->getBestData(bestData);
}
/* True unless dist equals one of the two sentinels SMALL_VALUE / BIG_VALUE. */
bool DistMat::definedDist(double dist) {
	return dist != SMALL_VALUE && dist != BIG_VALUE;
}
void DistMat::printMat() {
	for (vector<PairDist*>::iterator p = homSet->begin();
			p != homSet->end(); p++) {
		(*p)->print();
	}
}

// class NodeIndex
/*
 * Create an empty index in the INITIALIZE state.
 *   size : passed to the DistData pool and used as the initial capacity
 *          reserved for both the temporary and the final index
 */
NodeIndex::NodeIndex(int size)
	: distDataPool(size),
	status(INITIALIZE),
	indexTmp(),
	index()
{
	index.reserve(size);
	indexTmp.reserve(size);
}
/*
 * Delete the per-node vectors of the temporary index and the per-node
 * neighbor lists of the final index. Empty (NULL) slots are harmless:
 * deleting a null pointer does nothing.
 */
NodeIndex::~NodeIndex() {
	for (size_t slot = 0; slot < indexTmp.size(); ++slot) {
		delete indexTmp[slot];
	}
	for (size_t slot = 0; slot < index.size(); ++slot) {
		delete index[slot];
	}
}

/*
 * Register a pair distance under both of its node IDs.
 * While the index is in the INITIALIZE state the pair goes into the
 * temporary per-node vectors; in any other state it goes straight into the
 * per-node neighbor lists.
 */
void NodeIndex::addData(PairDist *pdist) {
	const int firstId = pdist->getNode(0)->getID();
	const int secondId = pdist->getNode(1)->getID();

	if (status != INITIALIZE) {
		addData_clustering(firstId, pdist);
		addData_clustering(secondId, pdist);
		return;
	}
	addData_preproc(firstId, secondId, pdist);
	addData_preproc(secondId, firstId, pdist);
}
/*
 * Return the temporary per-node vector stored at position idx, or NULL
 * when idx lies outside the temporary index. The stored slot itself may
 * also be NULL.
 */
vector<DistData*>* NodeIndex::getTmpIdx(int idx) {
	// compared as unsigned: a negative idx becomes a huge value and is
	// rejected as out of range
	if (static_cast<size_t>(idx) >= indexTmp.size()) {
		return NULL;
	}
	return indexTmp[idx];
}

void NodeIndex::addData_preproc(int id1, int id2, PairDist *pdist) {
	if (indexTmp.size() <= id1) {
		indexTmp.resize(id1+1);
	}
	if (indexTmp[id1] == NULL) {
		indexTmp[id1] = new vector<DistData *>();
	}
	DistData *distd = distDataPool.allocate();
	distDataPool.construct(distd, DistData(id2, pdist));
	indexTmp[id1]->push_back(distd);

}
void NodeIndex::addData_clustering(int id, PairDist *pdist) {
	NodeNbrList *nbrList;
	int idxSize = index.size();
	if (id >= idxSize) {
		do {
			nbrList = new NodeNbrList();
			index.push_back(nbrList);
		} while (id >= ++idxSize);
	} else {
		nbrList = (NodeNbrList*) index[id];
	}
	nbrList->add( pdist );
}
void NodeIndex::convert() {
//cout << "indexTmpSize:" << indexTmp.size() << endl;
	for (int i = 0; i < indexTmp.size(); i++) {
		NodeNbrList *newList = new NodeNbrList();
		index.push_back(newList);
		if (indexTmp[i]) {
			/* sort indexTmp */
			std::sort(indexTmp[i]->begin(), indexTmp[i]->end(),
					DistData::cmpDistDataByID());
//			NodeNbrList newlist = NodeNbrList();

			/* copy from indexTmp (vector) to index (list) */
			vector<DistData *>::iterator p;
			for(p = indexTmp[i]->begin(); p != indexTmp[i]->end();
						p++) {
//cout << i<<" "<< (*p)->getID()<<" "<< **p << endl;
				newList->add( (*p)->getPairDist() );
			}
		} else {
//cout << "None\n";
		}
//cout << "size:" << i << " " << index[i] << " " << index[i]->size() << " " << index.size()<< endl;
	}
	status = CLUSTERING;
}

