######################################################################
#
#  Clustering sequences into trees using similarity functions
#
#  A starting point for computing your own phylogenetic trees
#  
#  Implemented By Bienvenido Velez UPR Mayaguez
#
#  Funded by:
#   NIH MARC Assisting Bioinformatics Efforts at Minority Institutions
#
######################################################################

def simpleMismatchPercentDistance(seq1, seq2):
    """Return the mismatch distance between two sequences (0.0 means identical).

    Positions are compared pairwise up to the length of the shorter sequence,
    and every trailing position of the longer sequence that has no
    counterpart also counts as a mismatch.  The mismatch count is divided by
    the mean length of the two sequences.  For equal-length sequences the
    result is therefore the fraction of mismatched positions.

    Two empty sequences are considered identical and yield 0.0.
    """
    totalLength = len(seq1) + len(seq2)
    if totalLength == 0:
        # Avoid ZeroDivisionError: nothing differs between two empty sequences.
        return 0.0
    mismatches = sum(1 for a, b in zip(seq1, seq2) if a != b)
    # Residues of the longer sequence without a partner are mismatches too.
    mismatches += abs(len(seq1) - len(seq2))
    return (mismatches * 2.0) / totalLength

def getCentroid(cluster):
    """Return the sequence used to represent a cluster.

    For now this is simply the sequence of the leftmost leaf of the cluster
    tree.  A cluster is either a leaf string of the form "ID:SEQUENCE" or a
    (possibly nested) list of clusters.  For the leaf, the text after the
    first ':' is returned (the whole string if it contains no ':').

    Raises ValueError for an empty list and TypeError for anything that is
    neither a list nor a string.  Both are subclasses of Exception, which is
    what this function raised previously.
    """
    # TODO: select the most representative sequence as the centroid.
    node = cluster
    while isinstance(node, list):
        if not node:
            raise ValueError('cannot take the centroid of an empty cluster')
        node = node[0]
    if not isinstance(node, str):
        raise TypeError('cluster must be a list or a string, got %r' % (node,))
    # find() returns -1 when there is no ':', so the whole string is kept.
    return node[node.find(':') + 1:]

def findMostSimilarClusters(clusters, distanceFunction):
    """Return [i, j] with i < j: the indices of the two closest clusters.

    Closeness is distanceFunction applied to the centroids of the clusters
    (see getCentroid).  On ties, the first pair found in row-major order
    (smallest i, then smallest j) wins.

    Raises ValueError (a subclass of Exception) if fewer than two clusters
    are given.
    """
    if len(clusters) < 2:
        raise ValueError('need at least two clusters, got %d' % len(clusters))
    # Compute every centroid once instead of once per pair.
    centroids = [getCentroid(c) for c in clusters]
    minClusters = [0, 1]
    minDistance = distanceFunction(centroids[0], centroids[1])
    for i in range(len(centroids)):
        for j in range(i + 1, len(centroids)):
            nextDistance = distanceFunction(centroids[i], centroids[j])
            if nextDistance < minDistance:
                # Track the best distance so far; without this the search
                # would return the last pair that beat pair (0, 1).
                minDistance = nextDistance
                minClusters = [i, j]
    return minClusters
    
def collapseClusters(cluster1, cluster2):
    """Join two clusters as the two children of a new cluster node.

    The result is a fresh two-element list holding cluster1 then cluster2.
    Neither argument is copied or modified.
    """
    node = list()
    node.append(cluster1)
    node.append(cluster2)
    return node

def mergeClusters(clusters, clusterIndex1, clusterIndex2):
    """Return a new cluster list with the two indicated clusters collapsed.

    The collapsed cluster (built by collapseClusters, lower-indexed cluster
    first) takes the place of the lower-indexed cluster, and the other
    cluster is removed.  All remaining clusters keep their relative order,
    and the input list is not modified.

    Raises ValueError if the two indices are equal, and IndexError if either
    index lies outside range(len(clusters)).
    """
    low, high = sorted((clusterIndex1, clusterIndex2))
    if low == high:
        raise ValueError('cannot merge cluster %d with itself' % low)
    if low < 0 or high >= len(clusters):
        raise IndexError('cluster index out of range: %d, %d'
                         % (clusterIndex1, clusterIndex2))
    merged = collapseClusters(clusters[low], clusters[high])
    return clusters[:low] + [merged] + clusters[low + 1:high] + clusters[high + 1:]

def buildTree(sequences, distanceFunction=simpleMismatchPercentDistance):
    """Build a cluster tree from "ID:SEQUENCE" strings by agglomerative clustering.

    Each sequence starts as its own cluster.  The two most similar clusters
    (per findMostSimilarClusters with distanceFunction) are repeatedly
    collapsed until one cluster remains.  That remaining cluster is the root
    of the tree and is returned.  It is a nested two-element list whose
    leaves are the input strings, or the single input string if only one
    was given.

    Raises ValueError if sequences is empty.
    """
    clusters = list(sequences)
    if not clusters:
        raise ValueError('cannot build a tree from an empty sequence list')
    while len(clusters) > 1:
        first, second = findMostSimilarClusters(clusters, distanceFunction)
        clusters = mergeClusters(clusters, first, second)
    # The loop leaves a one-element list; return the root itself, which is
    # what printTree expects (it indexes tree[0] and tree[1]).
    return clusters[0]

def printTree(tree, level=0):
    """Print a cluster tree sideways, one leaf name per line.

    A list node is printed in three steps:
      1. its second child,
      2. a "(*)" marker at the node's own indentation,
      3. its first child.
    Each child is indented two spaces more than its parent.  A leaf
    "ID:SEQUENCE" is printed as its ID: the text before the first ':', or
    the whole string when it contains no ':'.
    """
    indent = ' ' * level
    if isinstance(tree, list):
        printTree(tree[1], level + 2)
        print(indent + '(*)')
        printTree(tree[0], level + 2)
    else:
        # partition() keeps the whole leaf when ':' is absent, whereas
        # slicing with find() == -1 would drop the last character.
        print(indent + tree.partition(':')[0])

def extractSequence(fileLine):
    """Return the residue part of one alignment line of the form "ID   RESIDUES".

    Everything after the first space is taken as the sequence, with all
    whitespace (spaces, tabs, line terminators) removed.  A line with no
    space has no sequence part and yields ''.
    """
    rest = fileLine.partition(' ')[2]
    return ''.join(rest.split())

def extractID(fileLine):
    """Return the sequence ID of one alignment line: the text before the first space.

    A line with no space is returned whole, and a line that starts with a
    space yields ''.
    """
    return fileLine.partition(' ')[0]

def loadClustalwAlignment(pathname):
    """Load a ClustalW multiple sequence alignment file.

    The first line (the CLUSTAL header) is skipped.  Every other non-blank
    line that does not start with whitespace is read as "ID   RESIDUES", and
    the residue blocks of each ID are concatenated in file order.  Lines that
    start with whitespace, such as the conservation line under each block,
    are ignored.

    Returns a list of "ID:SEQUENCE" strings (the format buildTree expects),
    ordered by the first appearance of each ID in the file.
    """
    id2Blocks = dict()
    ids = []  # remembers first-appearance order; plain dicts may not
    with open(pathname, 'r') as f:
        f.readline()  # skip the "CLUSTAL ..." header line
        for line in f:
            line = line.rstrip('\r\n')
            if not line.strip() or line[0].isspace():
                continue
            sequenceID = extractID(line)
            if sequenceID not in id2Blocks:
                ids.append(sequenceID)
                id2Blocks[sequenceID] = []
            id2Blocks[sequenceID].append(extractSequence(line))
    # Join each ID's blocks once rather than growing a string per line.
    return [sequenceID + ':' + ''.join(id2Blocks[sequenceID]) for sequenceID in ids]

# The following test sequences show the "ID:SEQUENCE" format expected by the tree-building functions
testSequences = [
    # Each entry is "ID:SEQUENCE"; all four sequences have the same length.
    '0:aaaacccctgggtttt',
    '1:aaaaccccggggtttc',
    '2:ggggccccggggttcc',
    '3:aaaaccccggggtccc',
]

def doIt(pathname=None):
    """Demonstrate tree building on testSequences and, optionally, on an alignment file.

    pathname -- path of a ClustalW alignment file.  If it is omitted and a
    module-level variable named "pathname" exists, that variable is used
    (the original script read such a global).  If neither is available,
    only the small test tree is built and printed.
    """
    print('Building a small test tree')
    tree1 = buildTree(testSequences, simpleMismatchPercentDistance)
    print('Printing the small tree')
    printTree(tree1)
    if pathname is None:
        # Backward compatibility with callers that set a global "pathname".
        pathname = globals().get('pathname')
    if pathname is None:
        print('No alignment file given; skipping the MSA tree')
        return
    print('Loading the multiple sequence alignment file')
    alignment = loadClustalwAlignment(pathname)
    print('Building a tree from a multiple sequence alignment')
    # Assignment, not subtraction: the original "tree2-buildTree(...)" left
    # tree2 undefined.
    tree2 = buildTree(alignment, simpleMismatchPercentDistance)
    print('Printing the MSA tree')
    printTree(tree2)
