#!/usr/bin/env python
#-*- coding: utf-8 -*-

import numpy as np
# command line parsing and deeptools helper modules
import argparse
from deeptools import writeBedGraph
from deeptools import parserCommon
from deeptools import bamHandler

# module-wide verbosity flag: main() sets it to 1 when args.verbose is
# true (0 otherwise) and uses it to decide whether to print the
# computed scale factors
debug = 0

def parseArguments(args=None):
    """
    Builds the command line parser and parses the given arguments.

    The parser inherits the options defined by the three parsers from
    parserCommon (getParentArgParse, bam and output) and adds the
    options --bam, --bamIndex, --mappedReads, --normalizeTo1x and
    --normalizeUsingRPKM.

    args -- list of command line tokens; when None, argparse reads
            them from sys.argv

    Returns the parsed namespace with the extra attribute
    extendPairedEnds, which is the negation of doNotExtendPairedEnds
    (an option that is presumably defined by one of the inherited
    parsers, it is not declared here).
    """
    inheritedParsers = [parserCommon.getParentArgParse(),
                        parserCommon.bam(),
                        parserCommon.output()]

    parser = argparse.ArgumentParser(
        parents=inheritedParsers,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=(
            'Given a BAM file, generates a coverage bigwig file. '
            'The way the method works is by first calculating all the number of  '
            'reads (either extended or not) that overlap each tile (or bin) in the '
            'genome.  Tiles with counts equal to zero are skipped, i.e. not added '
            'to the output file.\n'
            'The resulting read counts can be normalized using a given scaling '
            'factor, using the RPKM formula or normalizing to get a 1x depth of '
            'coverage.\n'))

    # input file and its index
    parser.add_argument(
        '--bam', '-b',
        help='Bam file to process',
        metavar='bam file',
        required=True)

    parser.add_argument(
        '--bamIndex', '-bai',
        help=('Index for the bam file. Default is to consider a the path '
              'of the bam file adding the .bai suffix.'),
        metavar='bam file index')

    # normalization related options
    parser.add_argument(
        '--mappedReads',
        help="Use this value instead of the value found on the bam file ",
        type=int)

    parser.add_argument(
        '--normalizeTo1x',
        help=("Report normalized coverage to 1x sequencing depth. "
              "Sequencing depth is defined as the total number of mapped reads "
              "* fragment length / effective genome size. "
              "To use this option, the effective genome size has to be given. "
              "Common values are: mm9: 2150570000, hg19:2451960000, "
              "dm3:121400000 and ce10:93260000. "
              "The default is not to use any normalization. "),
        default=None,
        type=int,
        required=False)

    parser.add_argument(
        '--normalizeUsingRPKM',
        help=("Use RPKM to normalize the number of reads per bin. "
              "The formula is: RPKM (per bin) = #reads per bin / "
              "( # of mapped reads (millions) * bin length (KB) )"),
        action='store_true',
        required=False)

    parsedArgs = parser.parse_args(args)

    # expose the positive form of the flag, which is what main() passes on
    parsedArgs.extendPairedEnds = not parsedArgs.doNotExtendPairedEnds

    return parsedArgs

def scaleFactor(string):
    """
    argparse 'type' converter for a pair of scale factors.

    string -- text of the form 'factor1:factor2', where both factors
              can be converted by float()

    Returns a tuple (factor1, factor2) of floats.

    Raises argparse.ArgumentTypeError when the text does not contain
    exactly two ':' separated values, when a value is not a valid
    float or when the given object is not a string.
    """
    try:
        scaleFactor1, scaleFactor2 = string.split(":")
        scaleFactors = (float(scaleFactor1), float(scaleFactor2))
    # ValueError: wrong number of fields or a non numeric field
    # AttributeError: the value given has no split method (not a string)
    # A bare except is avoided on purpose, so that KeyboardInterrupt and
    # SystemExit are not reported as an invalid argument.
    except (ValueError, AttributeError):
        raise argparse.ArgumentTypeError("Format of scaleFactors is factor1:factor2. The value given ( {} ) is not valid".format(string))
    return scaleFactors


def getFragmentCenter(read, defaultFragmentLength, extendPairedEnds=True, 
                      maxPairedFragmentLength=None):
    
    """
    Takes a read and, if it is the forward mate of a proper pair,
    has a template length below 250 and a mapping quality above 10,
    returns a short interval around the center of the fragment.

    read -- object with the attributes is_proper_pair, is_reverse,
            tlen, mapq and pos (presumably a pysam aligned read,
            to be confirmed against the caller)
    defaultFragmentLength, extendPairedEnds, maxPairedFragmentLength --
            accepted but not used by this function; presumably kept
            so that the signature matches what writeBedGraph expects
            from a 'fragmentFromRead_func' (to be confirmed)

    Returns a tuple (fragmentStart, fragmentEnd) of integers, where
    fragmentStart = read.pos + read.tlen // 2 - 1 and the interval
    length is 2 for an even template length and 3 for an odd one.
    Returns (None, None) for any read that does not pass the filters.
    """
    # only paired forward reads are considered
    if not read.is_proper_pair or read.is_reverse:
        return (None, None)

    # fragments that are too long or poorly mapped are discarded
    if abs(read.tlen) >= 250 or read.mapq <= 10:
        return (None, None)

    # explicit floor division keeps the coordinates integral even when
    # true division is active (python 3 or 'from __future__ import division')
    fragmentStart = read.pos + read.tlen // 2 - 1
    centerWidth = 2 if read.tlen % 2 == 0 else 3

    return (fragmentStart, fragmentStart + centerWidth)

#######################################
# MAIN

def main(args):
    """
    Computes the scale factor selected on the command line and writes
    the coverage of the bam file using writeBedGraph.writeBedGraph.

    args -- namespace as returned by parseArguments(). The attributes
            read here are: bam, bamIndex, binSize, fragmentLength,
            verbose, mappedReads, normalizeTo1x, normalizeUsingRPKM,
            outFileName, region, numberOfProcessors, outFileFormat,
            extendPairedEnds, minMappingQuality and ignoreDuplicates.

    Side effects: sets the module level 'debug' flag, fills in
    args.mappedReads when it was not given and sets args.scaleFactor.

    Scale factor:
      * --normalizeTo1x:      1 / (mapped reads * fragment length /
                                   effective genome size)
      * --normalizeUsingRPKM: 1 / (mapped reads in millions *
                                   tile length in Kb)
      * otherwise:            1
    """
    bamHandle = bamHandler.openBam(args.bam, args.bamIndex)

    # non positive values are replaced by defaults
    tileSize = args.binSize if args.binSize > 0 else 50
    fragmentLength = args.fragmentLength if args.fragmentLength > 0 else 300

    global debug
    debug = 1 if args.verbose else 0

    # the number of mapped reads reported by the bam handle is used
    # unless a value was given by the user
    if not args.mappedReads:
        args.mappedReads = bamHandle.mapped

    if args.normalizeTo1x:
        current_coverage = float(args.mappedReads * fragmentLength) / args.normalizeTo1x
        # the scaling sets the coverage to match 1x
        args.scaleFactor = 1.0 / current_coverage
        if debug:
            # single argument print calls behave the same in python 2 and 3
            print("Estimated current coverage {}".format(current_coverage))
            print("Scaling factor {}".format(args.scaleFactor))

    elif args.normalizeUsingRPKM:
        # the RPKM is the # reads per tile / ( total reads (in millions) * tile length in Kb)
        millionReadsMapped = float(args.mappedReads) / 1e6
        # tileSize (and not args.binSize) is used, such that the factor
        # matches the tile length that is actually passed to writeBedGraph
        # when the given bin size is not positive
        tileLengthInKb = float(tileSize) / 1000
        args.scaleFactor = 1.0 / (millionReadsMapped * tileLengthInKb)
        if debug:
            print("scale factor using RPKM is {0}".format(args.scaleFactor))
    else:
        args.scaleFactor = 1

    funcArgs = {'scaleFactor': args.scaleFactor}

    writeBedGraph.writeBedGraph([bamHandle.filename], args.outFileName,
                                fragmentLength, writeBedGraph.scaleCoverage,
                                funcArgs, tileSize=tileSize, region=args.region,
                                numberOfProcessors=args.numberOfProcessors,
                                format=args.outFileFormat,
                                extendPairedEnds=args.extendPairedEnds,
                                minMappingQuality=args.minMappingQuality,
                                ignoreDuplicates=args.ignoreDuplicates,
                                fragmentFromRead_func=getFragmentCenter)

    
# script entry point: parse the command line (sys.argv) and run main
if __name__ == "__main__":
    main(parseArguments())
