#!/usr/bin/env python
#-*- coding: utf-8 -*-

import os
import random
import sys
import argparse
import numpy as np
from matplotlib import use as mplt_use
mplt_use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

import deeptools.countReadsPerBin as countR
from deeptools import parserCommon
from deeptools.correlation_heatmap import plot_correlation
from deeptools._version import __version__


def parseArguments(args=None):
    """
    Build the bamCorrelate command line parser and parse ``args``.

    Two sub-commands are registered, 'bins' and 'BED-file'; the chosen one
    is stored in ``args.command``. Each sub-command combines the options
    from bamCorrelateArgs() with the shared parsers obtained from
    parserCommon and heatmap_options().

    After parsing:
      * ``args.extendPairedEnds`` is set to the negation of
        ``args.doNotExtendPairedEnds``;
      * ``args.labels`` defaults to the base names of the BAM files when
        no labels were given.

    :param args: list of command line tokens; ``None`` lets argparse
                 read ``sys.argv``.
    :return: the populated argparse.Namespace.

    The program terminates with a non-zero exit status if the number of
    labels differs from the number of BAM files.
    """
    parser = \
        argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="""

bamCorrelate computes the overall similarity between two or more BAM
files based on read coverage (number of reads) within genomic regions.
The correlation analysis is performed for the entire genome by running
the program in 'bins' mode, or for certain user selected regions in 'BED-file'
mode. Pearson or Spearman methods are available to compute correlation
coefficients. Results are saved into a heat map image and further output files
are optional. NOTE: bamCorrelate usually takes a long time to finish, thus it
is advisable to run the tool for a tiny region (using the --region option)
to adjust plotting parameters like colors and labels before running the
whole computation.

detailed help:

  %(prog)s bins -h
  %(prog)s BED-file -h

""",
            epilog='example usages:\n%(prog)s bins '
                   '--bamfiles file1.bam file2.bam --corMethod pearson '
                   '-o heatmap.png\n\n'
                   '%(prog)s BED-file --BED selection.bed \n'
                   '--bamfiles file1.bam file2.bam --corMethod spearman\n'
                   '-o heatmap.png'
                   ' \n\n',
            conflict_handler='resolve')

    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))
    subparsers = parser.add_subparsers(
        title="commands",
        dest='command',
        metavar='')

    # parsers shared by both sub-commands
    parent_parser = parserCommon.getParentArgParse(binSize=False)
    read_options_parser = parserCommon.read_options()
    heatmap_parser = heatmap_options()

    # bins mode options
    subparsers.add_parser(
        'bins',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='bins'),
                 parent_parser, read_options_parser,
                 heatmap_parser
                 ],
        help="The correlation is based on read coverage over "
             "consecutive bins of equal "
             "size (10k bp by default). This mode is useful to assess the "
             "overall similarity of BAM files. The bin size and the "
             "distance between bins can be adjusted.",
        add_help=False,
        usage='%(prog)s '
              '--bamfiles file1.bam file2.bam '
              '--corMethod spearman -o heatmap.png \n')

    # BED file arguments
    subparsers.add_parser(
        'BED-file',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='BED-file'),
                 parent_parser, read_options_parser,
                 heatmap_parser
                 ],
        help="The user provides a BED file that contains all regions "
             "that should be considered for the correlation analysis. A "
             "common use is to compare ChIP-seq coverages between two "
             "different samples for a set of peak regions.",
        usage='%(prog)s --BED selection.bed '
              '--bamfiles file1.bam file2.bam '
              '-o heatmap.png --corMethod pearson\n',
        add_help=False)

    args = parser.parse_args(args)
    args.extendPairedEnds = not args.doNotExtendPairedEnds

    if args.labels and len(args.bamfiles) != len(args.labels):
        # sys.exit with a string writes it to stderr and exits with
        # status 1, so callers can detect the failure.
        sys.exit("The number of labels does not match "
                 "the number of bam files.")
    if not args.labels:
        args.labels = [os.path.basename(path) for path in args.bamfiles]

    return args


def bamCorrelateArgs(case='bins'):
    """
    Build the parent parser holding the bamCorrelate specific options.

    The returned parser is created with ``add_help=False`` so that it can
    be used in the ``parents`` list of a sub-command parser.

    :param case: 'bins' exposes --binSize, --distanceBetweenBins and
                 --doNotRemoveOutliers and hides --BED (default None).
                 Any other value hides --binSize and
                 --distanceBetweenBins (keeping their defaults) and makes
                 --BED a required, readable file.
    :return: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(add_help=False)
    required = parser.add_argument_group('Required arguments')

    # define the arguments
    required.add_argument('--bamfiles', '-b',
                          metavar='FILE1 FILE2',
                          help='List of indexed bam files separated by '
                               'spaces.',
                          nargs='+',
                          required=True)

    required.add_argument('--plotFile', '-o',
                          help='File name to save the file containing the '
                               'heatmap of the correlation. The given file '
                               'ending will be used to determine the image '
                               'format, for example: heatmap.pdf will save '
                               'the heatmap in PDF format. The available '
                               'options are: .png, .emf, .eps, .pdf and '
                               '.svg.',
                          type=argparse.FileType('w'),
                          metavar='FILE',
                          required=True)

    required.add_argument('--corMethod', '-c',
                          help="Correlation method.",
                          choices=['spearman', 'pearson'],
                          required=True)

    optional = parser.add_argument_group('Optional arguments')

    optional.add_argument("--help", "-h", action="help",
                          help="show this help message and exit")
    optional.add_argument('--labels', '-l',
                          metavar='sample1 sample2',
                          help='User defined labels instead of default '
                               'labels from file names. '
                               'Multiple labels have to be separated by '
                               'space, e.g. '
                               '--labels sample1 sample2 sample3',
                          nargs='+')

    if case == 'bins':
        optional.add_argument('--binSize', '-bs',
                              metavar='INT',
                              help='Length in base pairs for a window used '
                                   'to sample the genome.',
                              default=10000,
                              type=int)

        optional.add_argument('--distanceBetweenBins', '-n',
                              metavar='INT',
                              help='By default, bamCorrelate considers '
                                   'consecutive bins of the specified '
                                   '--binSize. However, to reduce the '
                                   'computation time, a larger distance '
                                   'between bins can be given. Larger '
                                   'distances result in fewer bins being '
                                   'considered.',
                              default=0,
                              type=int)

        optional.add_argument(
            '--doNotRemoveOutliers',
            help='By default, bins with very large counts are removed. '
            'By setting this option outliers will not be removed. '
            'Bins with abnormally high reads counts artificially increase '
            'pearson correlation; that\'s why, by default, bamCorrelate tries '
            'to remove outliers using the median absolute deviation (MAD) '
            'method applying a threshold of 200 to only consider extremely '
            'large deviations from the median. ENCODE blacklist page '
            '(https://sites.google.com/site/anshulkundaje/projects/blacklists) '
            'contains useful information about regions with unusually high '
            'counts.',
            action='store_true')

        # hidden in this mode; only defined so that args.BED exists
        required.add_argument('--BED',
                              help=argparse.SUPPRESS,
                              default=None)
    else:
        # bin options are hidden but keep their defaults so that
        # args.binSize and args.distanceBetweenBins are always defined
        optional.add_argument('--binSize', '-bs',
                              help=argparse.SUPPRESS,
                              default=10000,
                              type=int)

        optional.add_argument('--distanceBetweenBins', '-n',
                              help=argparse.SUPPRESS,
                              metavar='INT',
                              default=0,
                              type=int)

        required.add_argument('--BED',
                              help='Limits the correlation analysis to '
                                   'the regions specified in this file.',
                              metavar='bedfile',
                              type=argparse.FileType('r'),
                              required=True)

    optional.add_argument('--includeZeros',
                          help='Genomic regions that have zero values only '
                               'are included. The default behavior is to '
                               'ignore these regions.',
                          action='store_true',
                          required=False)

    group = parser.add_argument_group('Output optional options')

    group.add_argument('--outFileCorMatrix',
                       help='Save correlation matrix to file.',
                       metavar='FILE',
                       type=argparse.FileType('w'))

    group.add_argument('--outRawCounts',
                       help='Save raw counts (coverages) to file.',
                       metavar='FILE',
                       type=argparse.FileType('w'))

    group.add_argument('--plotNumbers',
                       help='If set, then the correlation number is plotted '
                            'on top of the heatmap.',
                       action='store_true',
                       required=False)

    group.add_argument('--plotFileFormat',
                       metavar='FILETYPE',
                       help='Image format type. If given, this option '
                            'overrides the image format based on the '
                            'plotFile ending. The available options are: '
                            'png, emf, eps, pdf and svg.',
                       choices=['png', 'pdf', 'svg', 'eps', 'emf'])

    optional.add_argument('--plotTitle', '-T',
                          help='Title of the plot, to be printed on top of '
                               'the generated image. Leave blank for no '
                               'title.',
                          default='')

    return parser


def heatmap_options():
    """
    Options for generating the correlation heat map.

    Returns a parent parser (``add_help=False``) defining --zMin, --zMax
    and --colorMap. --zMin and --zMax have no ``type`` and therefore reach
    the caller as strings (or None when not given).
    """
    parser = argparse.ArgumentParser(add_help=False)
    heatmap = parser.add_argument_group('Heat map options')

    heatmap.add_argument('--zMin', '-min',
                         default=None,
                         help='Minimum value for the heat map intensities. '
                              'If not specified the value is set '
                              'automatically')
    heatmap.add_argument('--zMax', '-max',
                         default=None,
                         help='Maximum value for the heat map intensities. '
                              'If not specified the value is set '
                              'automatically')

    from matplotlib import cm
    # List every color map once: the reversed '_r' variants are left out.
    # Sorted so that the help text does not depend on dict ordering.
    color_names = sorted((name for name in cm.datad
                          if not name.endswith('_r')),
                         key=lambda name: name.lower())
    color_options = "', '".join(color_names)

    heatmap.add_argument(
        '--colorMap', default='jet',
        metavar='',
        help='Color map to use for the heatmap. Available values can be '
             'seen here: '
             'http://www.astro.lsa.umich.edu/~msshin/science/code/'
             'matplotlib_cm/ The available options are: \'' +
             color_options + '\'')

    return parser


def getOutlierIndices(data, max_deviation=200):
    """
    Return the indices of the outliers found in ``data``.

    The method is based on the median absolute deviation. See
    Boris Iglewicz and David Hoaglin (1993),
    "Volume 16: How to Detect and Handle Outliers",
    The ASQC Basic References in Quality Control:
    Statistical Techniques, Edward F. Mykytka, Ph.D., Editor.

    The max_deviation=200 is like selecting a z-score
    larger than 200, just that it is based on the median
    and the median absolute deviation instead of the
    mean and the standard deviation.

    :param data: one dimensional array-like of numbers.
    :param max_deviation: values whose scaled distance to the median is
                          strictly larger than this are reported.
    :return: numpy array with the positions (into the flattened input)
             of the outliers. The array is empty when the median absolute
             deviation is not positive, because the deviation can not be
             scaled in that case.
    """
    data = np.asarray(data)
    abs_deviation = np.abs(data - np.median(data))
    b_value = 1.4826  # value set for a normal distribution
    mad = b_value * np.median(abs_deviation)
    if mad > 0:
        return np.flatnonzero(abs_deviation / mad > max_deviation)
    # same return type as the branch above
    return np.array([], dtype=np.intp)


def main(args):
    """
    Compute and plot the pairwise correlation of read counts.

    1. get read counts at different positions either
       all of same length or from genomic regions from the BED file
    2. optionally save the raw counts (before any outlier removal)
    3. in 'bins' mode, unless --doNotRemoveOutliers was given, drop the
       bins that are outliers in every sample
    4. compute the all vs all correlation matrix
    5. optionally save the matrix and plot the heat map

    :param args: argparse.Namespace as returned by parseArguments()

    The program exits with a non-zero status when fewer than two BAM
    files are given, when the color map is not known to matplotlib or
    when fewer than two bins are returned.
    """
    if len(args.bamfiles) < 2:
        sys.exit("Please input at least two bam files to compare")

    skip_zeros = not args.includeZeros

    if args.colorMap:
        try:
            plt.get_cmap(args.colorMap)
        except ValueError as error:
            sys.stderr.write(
                "A problem was found. Message: {}\n".format(error))
            # this is an error: do not report success to the caller
            sys.exit(1)

    bed_regions = getattr(args, 'BED', None)

    stepsize = args.binSize + args.distanceBetweenBins
    num_reads_per_bin = countR.getNumReadsPerBin(
        args.bamfiles,
        args.binSize,
        0,
        args.fragmentLength,
        numberOfProcessors=args.numberOfProcessors,
        skipZeros=skip_zeros,
        verbose=args.verbose,
        region=args.region,
        bedFile=bed_regions,
        extendPairedEnds=args.extendPairedEnds,
        minMappingQuality=args.minMappingQuality,
        ignoreDuplicates=args.ignoreDuplicates,
        stepSize=stepsize)

    sys.stderr.write("Number of non zero bins "
                     "found: {}\n".format(num_reads_per_bin.shape[0]))

    if num_reads_per_bin.shape[0] < 2:
        sys.exit("ERROR: too few non zero bins found.\n"
                 "If using --region please check that this "
                 "region is covered by reads.\n")

    # num_reads_per_bin: rows correspond to bins, cols to samples
    num_files = len(args.bamfiles)

    if args.outRawCounts:
        args.outRawCounts.write("'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(['%d'] * num_reads_per_bin.shape[1]) + "\n"
        for row in num_reads_per_bin:
            args.outRawCounts.write(fmt % tuple(row))

    # remove outliers, otherwise outliers will produce a very
    # high pearson correlation. Only for the bins option.
    # The sub-command name is stored in args.command; the previous test
    # ("'bins' in args") looked for an attribute called 'bins', which is
    # never set, so the outliers were never removed.
    remove_outliers = (getattr(args, 'command', None) == 'bins' and
                       not getattr(args, 'doNotRemoveOutliers', False))
    if remove_outliers:
        unfiltered = num_reads_per_bin.shape[0]
        to_remove = None
        for col in num_reads_per_bin.T:
            # get the outliers per column using the median absolute
            # deviation method
            outliers = set(getOutlierIndices(col))
            # only remove those bins in which the outliers are present
            # in all cases (columns), that's why the intersection is used
            if to_remove is None:
                to_remove = outliers
            else:
                to_remove &= outliers
        if to_remove:
            keep = np.ones(unfiltered, dtype=bool)
            keep[sorted(to_remove)] = False
            num_reads_per_bin = num_reads_per_bin[keep, :]
            if args.verbose:
                left = num_reads_per_bin.shape[0]
                sys.stderr.write(
                    "total/filtered/left: "
                    "{}/{}/{}\n".format(unfiltered,
                                        unfiltered - left, left))

    # initialize correlation matrix
    corr_matrix = np.zeros((num_files, num_files), dtype='float')
    options = {'spearman': spearmanr,
               'pearson': pearsonr}
    correlate = options[args.corMethod]
    # do an all vs all correlation using the
    # indices of the upper triangle (diagonal included)
    for row, col in zip(*np.triu_indices(num_files)):
        # both functions return (coefficient, p-value)
        corr_matrix[row, col] = correlate(
            num_reads_per_bin[:, row],
            num_reads_per_bin[:, col])[0]
    # make the matrix symmetric
    corr_matrix = corr_matrix + np.triu(corr_matrix, 1).T

    if args.outFileCorMatrix:
        args.outFileCorMatrix.write("\t'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(['%.4f'] * num_files) + "\n"
        for label, row in zip(args.labels, corr_matrix):
            args.outFileCorMatrix.write(
                "'%s'\t" % label + fmt % tuple(row))

    # argparse already opened the plot file; only its name is needed
    plot_file_name = args.plotFile.name
    args.plotFile.close()
    plot_correlation(corr_matrix,
                     args.labels,
                     plot_file_name,
                     args.zMax,
                     args.zMin,
                     args.colorMap,
                     image_format=args.plotFileFormat,
                     plot_numbers=args.plotNumbers,
                     plot_title=args.plotTitle)

if __name__ == "__main__":
    # Script entry point: parse the command line (sys.argv) and hand the
    # resulting namespace straight to main().
    main(parseArguments())
