# MARC Summer Workshop 2008
#
# Sample bioinformatics functions in the Python programming language
# 
# Written by: Bienvenido Velez (UPR Mayaguez)
#

from string import *

######################################################################
#
#  Generally useful utility functions for handling sequences
#
######################################################################

def stringToList(theString):
    """Return the input string as a list of single characters.

    list() iterates the string once; the previous `result = result + [c]`
    loop rebuilt the list at every step (quadratic time).
    """
    return list(theString)

def listToString(theList):
    """Return the input list of characters (or strings) joined into one string.

    str.join builds the result in one pass instead of re-copying the partial
    string on every concatenation.
    """
    return "".join(theList)

# Parallel lookup strings: DNAComplements[i] is the complement of
# DNANucleotides[i] (getComplementDNANucleotide looks bases up by index).
# Only lower-case letters are listed; callers lower-case their input first.
DNANucleotides='acgt'
DNAComplements='tgca'

# Same layout for RNA, where 'u' takes the place of 't'.
RNANucleotides='acgu'
RNAComplements='ugca'

# Another alternative is to use a dictionary to map nucleotides to their complements.
# NOTE(review): this table is not used by any function in this file, and its
# keys are upper case, unlike the lower-case strings above. 'R'/'Y' look like
# the IUPAC purine/pyrimidine codes and 'X' an unknown base (assumption).
NucleotideComplement = {
    'A':'T', 'C':'G', 'R':'Y', 'X':'X',
    'T':'A', 'G':'C', 'Y':'R' 
    }

def isDNANucleotide(n):
    """Return True if n is a one-character string found in DNANucleotides (any case)."""
    if type(n) is not str or len(n) != 1:
        return False
    return n.lower() in DNANucleotides

def isRNANucleotide(n):
    """Return True if n is a one-character string found in RNANucleotides (any case)."""
    if type(n) is not str or len(n) != 1:
        return False
    return n.lower() in RNANucleotides

def isDNASequence(sequence):
    """Return True if sequence is a string whose every character is a DNA nucleotide.

    Case is ignored, and the empty string counts as a (trivially) valid sequence.
    """
    return type(sequence) is str and all(
        isDNANucleotide(base.lower()) for base in sequence)

def isRNASequence(sequence):
    """Return True if sequence is a string whose every character is an RNA nucleotide.

    Case is ignored, and the empty string counts as a (trivially) valid sequence.
    """
    return type(sequence) is str and all(
        isRNANucleotide(base.lower()) for base in sequence)

######################################################################
#
#  Finding patterns in DNA/RNA sequences
#
######################################################################

def searchPattern(dna, pattern):
    """Print every start position of pattern inside dna, overlapping hits included.

    Matching is exact, so unknown bases are not treated as wildcards.
    """
    position = dna.find(pattern)
    while position >= 0:
        print('pattern %s at position %d' % (pattern, position))
        # Resume one character later so overlapping occurrences are reported.
        position = dna.find(pattern, position + 1)

def searchDNAPattern(dna, pattern):
    """Print every start position of a DNA pattern inside a DNA sequence.

    Uses findDNAPattern, so matching is case-insensitive and 'x' acts as a
    wildcard. Overlapping occurrences are reported.
    """
    nextStart = 0
    while True:
        hit = findDNAPattern(dna, pattern, nextStart)
        if hit == -1:
            break
        print('pattern %s at position %d' % (pattern, hit))
        nextStart = hit + 1

# The searchPattern function will not work properly if the sequences have unknown bases ('x').
# The following functions overcome this limitation.

def matchDNANucleotides(base1, base2):
    """Return True if the bases are equal DNA nucleotides or either one is unknown.

    Only the lower-case 'x' counts as unknown; findDNAPattern lower-cases its
    inputs before comparing.
    """
    if 'x' in (base1, base2):
        return True
    return isDNANucleotide(base1) and base1 == base2

def matchDNAPattern(sequence, pattern):
    """Return True if the DNA pattern matches the beginning of sequence.

    Bases are compared pairwise with matchDNANucleotides, so 'x' in either
    string matches anything. A pattern longer than the sequence never
    matches; an empty pattern always does.
    """
    if len(pattern) > len(sequence):
        return False
    for seqBase, patBase in zip(sequence, pattern):
        if not matchDNANucleotides(seqBase, patBase):
            return False
    return True

def findDNAPattern(dna, pattern, startPosition=0, endPosition=None):
    """Return the index of the first occurrence of a DNA pattern, or -1.

    Candidate start positions run from startPosition up to (not including)
    endPosition, which defaults to len(dna). A match that starts before
    endPosition may extend past it. Matching is case-insensitive, and 'x' in
    either string matches any base (see matchDNANucleotides).

    A negative startPosition counts from the end of dna, as in str.find.
    """
    if endPosition is None:
        endPosition = len(dna)
    if startPosition < 0:
        # Normalize so the windowed slice below is well defined.
        startPosition = max(0, startPosition + len(dna))
    dna = dna.lower()  # Force sequence and pattern to lower case
    pattern = pattern.lower()
    patternLength = len(pattern)
    for i in range(startPosition, endPosition):
        # Attempt to match the pattern starting at position i. Only a
        # pattern-sized window is sliced; slicing dna[i:] copied the whole
        # remainder at every position (quadratic in len(dna)).
        if matchDNAPattern(dna[i:i + patternLength], pattern):
            return i
    return -1

######################################################################
#
#  Complementing and Reversing DNA/RNA Sequences
#
######################################################################

def reverse(sequence):
    """Return the argument sequence reversed.

    Uses an extended slice rather than recursion. The recursive version
    exceeded Python's recursion limit on sequences of about a thousand
    characters, and it copied the string at every level.
    """
    return sequence[::-1]

def getComplementDNANucleotide(n):
    """Return the (lower-case) DNA complement of nucleotide n.

    Raises:
        ValueError (a subclass of Exception): if n is not a DNA nucleotide.
    """
    if not isDNANucleotide(n):
        raise ValueError('getComplementDNANucleotide: Invalid DNA nucleotide: %r' % (n,))
    # DNAComplements is index-aligned with DNANucleotides.
    return DNAComplements[DNANucleotides.index(n.lower())]

def getComplementRNANucleotide(n):
    """Return the (lower-case) RNA complement of nucleotide n.

    Raises:
        ValueError (a subclass of Exception): if n is not an RNA nucleotide.
    """
    if not isRNANucleotide(n):
        raise ValueError('getComplementRNANucleotide: Invalid RNA nucleotide: %r' % (n,))
    # RNAComplements is index-aligned with RNANucleotides.
    return RNAComplements[RNANucleotides.index(n.lower())]

def getComplementDNASequence(sequence):
    """Return the complementary DNA sequence (lower case).

    Raises:
        ValueError (a subclass of Exception): if sequence is not a DNA sequence.
    """
    if not isDNASequence(sequence):
        # %s formatting instead of '+', so a non-string argument still reports
        # this error rather than a TypeError from string concatenation.
        raise ValueError("getComplementDNASequence: Invalid DNA sequence: %s" % (sequence,))
    return "".join([getComplementDNANucleotide(base) for base in sequence])
        
def getComplementRNASequence(sequence):
    """Return the complementary RNA sequence (lower case).

    Raises:
        ValueError (a subclass of Exception): if sequence is not an RNA sequence.
    """
    if not isRNASequence(sequence):
        # %s formatting instead of '+', so a non-string argument still reports
        # this error rather than a TypeError from string concatenation.
        raise ValueError("getComplementRNASequence: Invalid RNA sequence: %s" % (sequence,))
    return "".join([getComplementRNANucleotide(base) for base in sequence])

def getComplementDNASequences(sequences):
    """Return a list with the complement of each DNA sequence in sequences.

    The first invalid sequence raises the same error as getComplementDNASequence.
    """
    return [getComplementDNASequence(sequence) for sequence in sequences]

def getComplementRNASequences(sequences):
    """Return a list with the complement of each RNA sequence in sequences.

    The first invalid sequence raises the same error as getComplementRNASequence.
    """
    return [getComplementRNASequence(sequence) for sequence in sequences]

def getReverseComplementDNASequence(sequence):
    """Return the reverse complement of a DNA sequence: complement first, then reverse."""
    complement = getComplementDNASequence(sequence)
    return reverse(complement)

######################################################################
#
#  Finding open reading frames (ORFs) in DNA sequences
#
######################################################################

def findDNAORFPosition(sequence, minLen, startCodon, stopCodon, startPosition, endPosition):
    """Find the position and length of the first ORF between startPosition and endPosition.

    Searches sequence[startPosition:endPosition] for startCodon and then,
    after it, for the next occurrence of stopCodon. A candidate is accepted
    when the distance from the start of startCodon to the start of
    stopCodon is greater than minLen. Otherwise the search resumes at the
    next start codon.

    Returns:
        [orfStart, orfLength], where orfStart is the index just past the
        start codon and sequence[orfStart:orfStart + orfLength] is the region
        between the start and stop codons (both excluded). Returns [-1, 0]
        when no acceptable ORF exists in the range.

    NOTE(review): the stop codon is not required to be in the same reading
    frame as the start codon, and only the single given stopCodon is looked
    for. A biologically strict ORF finder would enforce both.
    """
    startLength = len(startCodon)
    while startPosition < endPosition:
        startCodonPosition = sequence.find(startCodon, startPosition, endPosition)
        if startCodonPosition < 0:
            return [-1, 0]  # Could not find any more start codons
        # Look for the stop codon only after the start codon, so a stop codon
        # overlapping it (e.g. 'tga' inside 'atga') is not taken as the end.
        stopCodonPosition = sequence.find(stopCodon, startCodonPosition + startLength, endPosition)
        if stopCodonPosition < 0:
            return [-1, 0]  # Finished the sequence without finding stop codon
        if (stopCodonPosition - startCodonPosition) > minLen:
            return [startCodonPosition + startLength,
                    (stopCodonPosition - startCodonPosition) - startLength]
        # Too short: move past this start codon instead of creeping forward
        # from the old startPosition and re-finding the same codon.
        startPosition = startCodonPosition + 1
    # Range empty or exhausted; previously this fell through and returned None.
    return [-1, 0]
    
def extractDNAORF(sequence, minLen, startCodon, stopCodon, startPosition, endPosition):
    """Return the first ORF found in sequence between startPosition and endPosition.

    The ORF is located with findDNAORFPosition (same arguments; see it for
    how minLen is applied). It is returned as the substring between the
    start codon and the stop codon, both excluded. Returns "" when no ORF
    is found.
    """
    ORFPosition = findDNAORFPosition(sequence, minLen, startCodon, stopCodon, startPosition, endPosition)
    # Treat a missing result (None) the same as "not found", so callers
    # always get a string back.
    if not ORFPosition or ORFPosition[0] < 0:
        return ""
    orfStart, orfLength = ORFPosition[0], ORFPosition[1]
    return sequence[orfStart:orfStart + orfLength]

######################################################################
#
#  Translating DNA/RNA into proteins
#
######################################################################

# Codon -> one-letter amino-acid table (the standard genetic code), keyed by
# lower-case DNA codons; translateDNASequence looks codons up here.
# The stop codons taa, tag and tga map to '*'.
GeneticCode =   {   'ttt': 'F', 'tct': 'S', 'tat': 'Y', 'tgt': 'C',
                    'ttc': 'F', 'tcc': 'S', 'tac': 'Y', 'tgc': 'C',
                    'tta': 'L', 'tca': 'S', 'taa': '*', 'tga': '*',
                    'ttg': 'L', 'tcg': 'S', 'tag': '*', 'tgg': 'W',
                    'ctt': 'L', 'cct': 'P', 'cat': 'H', 'cgt': 'R',
                    'ctc': 'L', 'ccc': 'P', 'cac': 'H', 'cgc': 'R',
                    'cta': 'L', 'cca': 'P', 'caa': 'Q', 'cga': 'R',
                    'ctg': 'L', 'ccg': 'P', 'cag': 'Q', 'cgg': 'R',
                    'att': 'I', 'act': 'T', 'aat': 'N', 'agt': 'S',
                    'atc': 'I', 'acc': 'T', 'aac': 'N', 'agc': 'S',
                    'ata': 'I', 'aca': 'T', 'aaa': 'K', 'aga': 'R',
                    'atg': 'M', 'acg': 'T', 'aag': 'K', 'agg': 'R',
                    'gtt': 'V', 'gct': 'A', 'gat': 'D', 'ggt': 'G',
                    'gtc': 'V', 'gcc': 'A', 'gac': 'D', 'ggc': 'G',
                    'gta': 'V', 'gca': 'A', 'gaa': 'E', 'gga': 'G',
                    'gtg': 'V', 'gcg': 'A', 'gag': 'E', 'ggg': 'G'
                }

def translateDNASequence(dna):
    """Translate a DNA sequence into its one-letter protein sequence.

    Codons are read from position 0 in steps of three and looked up in
    GeneticCode; stop codons become '*'. Input of either case is accepted.
    Up to two trailing bases that do not form a complete codon are ignored.

    Raises:
        ValueError (a subclass of Exception): if dna is not a DNA sequence.
    """
    if not isDNASequence(dna):
        raise ValueError('translateDNASequence: Invalid DNA sequence')
    # isDNASequence accepts upper case, but GeneticCode keys are lower case.
    dna = dna.lower()
    # Stop before an incomplete final codon, which has no table entry.
    usableLength = len(dna) - len(dna) % 3
    return "".join([GeneticCode[dna[i:i + 3]] for i in range(0, usableLength, 3)])

# Sample coding DNA sequence (lower case; it begins with the start codon atg
# and ends with the stop codon taa). The literal is wrapped over several lines
# for readability, and the newlines are removed so cds is one continuous string.
cds ='''atgagtgaacgtctgagcattaccccgctggggccgtatatcggcgcacaaa
tttcgggtgccgacctgacgcgcccgttaagcgataatcagtttgaacagctttaccatgcggtg
ctgcgccatcaggtggtgtttctacgcgatcaagctattacgccgcagcagcaacgcgcgctggc
ccagcgttttggcgaattgcatattcaccctgtttacccgcatgccgaaggggttgacgagatca
tcgtgctggatacccataacgataatccgccagataacgacaactggcataccgatgtgacattt
attgaaacgccacccgcaggggcgattctggcagctaaagagttaccttcgaccggcggtgatac
gctctggaccagcggtattgcggcctatgaggcgctctctgttcccttccgccagctgctgagtg
ggctgcgtgcggagcatgatttccgtaaatcgttcccggaatacaaataccgcaaaaccgaggag
gaacatcaacgctggcgcgaggcggtcgcgaaaaacccgccgttgctacatccggtggtgcgaac
gcatccggtgagcggtaaacaggcgctgtttgtgaatgaaggctttactacgcgaattgttgatg
tgagcgagaaagagagcgaagccttgttaagttttttgtttgcccatatcaccaaaccggagttt
caggtgcgctggcgctggcaaccaaatgatattgcgatttgggataaccgcgtgacccagcacta
tgccaatgccgattacctgccacagcgacggataatgcatcgggcgacgatccttggggataaac
cgttttatcgggcggggtaa'''.replace('\n','')


######################################################################
#
#  Clustering sequences into trees using similarity functions
#
#  A Starting point towards computing your own Phylogenetic Trees
#
######################################################################

def simpleMismatchPercentDistance(seq1, seq2):
    """Return a normalized mismatch distance between two sequences.

    Positions are compared pairwise over the length of the shorter sequence;
    each differing position counts as one mismatch. Every extra position of
    the longer sequence also counts as a mismatch. The total is divided by
    the mean length of the two sequences.

    Despite the name, the result is a fraction, not a percentage, and it is
    a distance: 0.0 for identical sequences, larger for less similar ones.
    It can exceed 1.0 when the lengths differ greatly (e.g. 'a' vs '' -> 2.0).
    Two empty sequences are identical and give 0.0, rather than dividing by zero.
    """
    totalLength = len(seq1) + len(seq2)
    if totalLength == 0:
        return 0.0
    mismatches = sum([1 for a, b in zip(seq1, seq2) if a != b])
    mismatches += abs(len(seq1) - len(seq2))
    return (mismatches * 2.0) / totalLength

def getCentroid(cluster):
    """Return the sequence that represents a cluster when measuring distances.

    For now this is simply the sequence of the leftmost leaf of the cluster
    tree. Clusters are nested lists whose leaves are 'id:sequence' strings.
    The text after the first ':' is returned, or the whole leaf if it has
    no ':'.

    Raises:
        TypeError (a subclass of Exception): if a node is neither a list nor a string.
        IndexError: if an empty list is encountered.
    """
    # Later select the most representative sequence as the centroid.
    node = cluster
    # Walk down iteratively so very deep trees cannot hit the recursion limit.
    while type(node) == type([]):
        node = node[0]
    if type(node) != type(""):
        raise TypeError('getCentroid: unexpected cluster node: %r' % (node,))
    return node[node.find(':') + 1:]

def findMostSimilarClusters(clusters, distanceFunction):
    """Return [i, j] (i < j), the indices of the two closest clusters.

    Distances are computed with distanceFunction between the clusters'
    centroids (see getCentroid). When several pairs share the minimum
    distance, the first one in (i, j) scan order is returned.

    Raises:
        ValueError (a subclass of Exception): if fewer than two clusters are given.
    """
    if len(clusters) < 2:
        raise ValueError('findMostSimilarClusters: need at least two clusters, got %d'
                         % len(clusters))
    # Compute each centroid once instead of once per pair.
    centroids = [getCentroid(cluster) for cluster in clusters]
    minClusters = [0, 1]
    minDistance = distanceFunction(centroids[0], centroids[1])
    for i in range(len(centroids)):
        for j in range(i + 1, len(centroids)):
            nextDistance = distanceFunction(centroids[i], centroids[j])
            if nextDistance < minDistance:
                # Record the new best distance as well as the pair. Otherwise
                # every later pair would be compared only against the (0, 1)
                # distance, and the last pair beating it would win.
                minDistance = nextDistance
                minClusters = [i, j]
    return minClusters
    
def collapseClusters(cluster1, cluster2):
    """Join two clusters as the two children of a new tree node (a fresh 2-element list)."""
    node = [cluster1]
    node.append(cluster2)
    return node

def mergeClusters(clusters, clusterIndex1, clusterIndex2):
    """Return a new cluster list in which the two indicated clusters are collapsed.

    The node built by collapseClusters (lower index first) takes the position
    of the lower index; all other clusters keep their relative order. The
    input list is not modified. Negative indices count from the end, as in
    normal list indexing.

    Raises:
        IndexError: if an index is out of range.
        ValueError: if both indices refer to the same cluster, which would
            otherwise duplicate that cluster in the result.
    """
    count = len(clusters)
    first = clusterIndex1 + count if clusterIndex1 < 0 else clusterIndex1
    second = clusterIndex2 + count if clusterIndex2 < 0 else clusterIndex2
    if not (0 <= first < count and 0 <= second < count):
        raise IndexError('mergeClusters: cluster index out of range')
    if first == second:
        raise ValueError('mergeClusters: cannot merge a cluster with itself')
    if first > second:
        first, second = second, first
    nc = collapseClusters(clusters[first], clusters[second])
    return clusters[:first] + [nc] + clusters[first + 1:second] + clusters[second + 1:]

def buildTree(sequences, distanceFunction=simpleMismatchPercentDistance):
    """Cluster sequences agglomeratively into a binary tree.

    Starting with one cluster per 'id:sequence' string, the two most similar
    clusters (findMostSimilarClusters) are repeatedly merged
    (mergeClusters) until at most one remains. The remaining list is
    returned: for non-empty input it holds exactly one element, the root of
    the tree, built from nested 2-element lists.
    """
    forest = sequences
    while len(forest) > 1:
        first, second = findMostSimilarClusters(forest, distanceFunction)
        forest = mergeClusters(forest, first, second)
    return forest

def printTree(tree, level=0):
    """Print a cluster tree sideways, with sequence names at the leaves.

    An internal node (a list) is printed as '(*)' at its depth. Its second
    child is printed above and its first child below, each indented two
    more spaces (any further children are printed above the second, last
    first). A leaf is an 'id:sequence' string, of which only the id (the
    text before the first ':') is printed; a leaf without ':' is printed whole.

    A single-element list, such as the list returned by buildTree, is
    printed as its only element; an empty list prints nothing.
    """
    if type(tree) == type([]):
        if len(tree) == 0:
            return
        if len(tree) == 1:
            # Unwrap instead of indexing tree[1], which raised IndexError.
            printTree(tree[0], level)
            return
        for child in reversed(tree[1:]):
            printTree(child, level + 2)
        print(' ' * level + '(*)')
        printTree(tree[0], level + 2)
    else:
        # partition() keeps the whole leaf when there is no ':'; slicing up to
        # find() == -1 used to drop the last character.
        print(' ' * level + tree.partition(':')[0])

def extractSequence(fileLine):
    """Return the sequence part of an alignment line.

    The sequence part is everything after the first space, with all spaces,
    carriage returns and newlines removed. A line without any space has no
    sequence part and yields ''. Slicing from find() == -1 used to return
    the line's last character instead.
    """
    sequence = fileLine.partition(' ')[2].replace(' ', '')
    return sequence.replace('\r', '').replace('\n', '')

def extractID(fileLine):
    """Return the sequence identifier of an alignment line: the text before the first space.

    A line that starts with a space yields '' (in ClustalW files this is
    typically the conservation line). A line without any space is returned
    whole; slicing to find() == -1 used to drop its last character.
    """
    return fileLine.partition(' ')[0]

def loadClustalwAlignment(pathname):
    """Load a ClustalW multiple sequence alignment file.

    The first line (the file header) is skipped. Every other line longer
    than one character is split into an identifier and a sequence chunk
    with extractID / extractSequence. The chunks of each identifier are
    concatenated across the interleaved blocks of the file. Lines with an
    empty identifier (lines starting with a space, such as the conservation
    line) are skipped rather than collected as a nameless sequence.

    Returns:
        A list of 'id:sequence' strings (the format buildTree expects), in
        the order in which the identifiers first appear in the file.
    """
    id2SequenceMap = dict()
    ids = []  # identifiers in order of first appearance
    f = open(pathname, 'r')
    try:
        f.readline()  # skip the header line
        for line in f:
            line = line.replace('\n', '')
            if len(line) <= 1:
                continue
            sequenceID = extractID(line)
            if not sequenceID:
                continue
            if sequenceID not in id2SequenceMap:
                ids.append(sequenceID)
                # Start from '' (not ' ') so no stray space enters the sequence.
                id2SequenceMap[sequenceID] = ''
            id2SequenceMap[sequenceID] += extractSequence(line)
    finally:
        f.close()
    return [key + ":" + id2SequenceMap[key] for key in ids]

# The following test sequence set shows the 'id:sequence' format expected by the
# tree-building functions (the text after the first ':' is the sequence).
testSequences=['0:aaaacccctgggtttt', '1:aaaaccccggggtttc', '2:ggggccccggggttcc', '3:aaaaccccggggtccc']

def doIt(pathname=None):
    """Demonstrate tree building on the built-in test set and, optionally, an alignment.

    Args:
        pathname: path of a ClustalW alignment file to load and cluster. When
            None (the default), only the small built-in test tree is built,
            since there is no file to load.

    buildTree returns a list wrapping the root, so the root is unwrapped
    before printing.
    """
    print('Building a small test tree')
    tree1 = buildTree(testSequences, simpleMismatchPercentDistance)
    print('Printing the small tree')
    printTree(tree1[0])
    if pathname is None:
        return
    print('Loading the multiple sequence alignment file')
    alignment = loadClustalwAlignment(pathname)
    print('Building a tree from a multiple sequence alignment')
    tree2 = buildTree(alignment, simpleMismatchPercentDistance)
    print('Printing the MSA tree')
    if tree2:  # an alignment with no sequences yields an empty list
        printTree(tree2[0])

    
    
