#!/usr/pkg/bin/python3.8

# Audio Tools, a module and set of tools for manipulating audio data
# Copyright (C) 2007-2015  Brian Langenberger

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA


import sys
import os
import os.path
import operator
import audiotools
import audiotools.text as _


def cmp_files(progress, audiofile1, audiofile2):
    """Compares two audio files and returns (path1, path2, mismatch)

    path1 and path2 are the "filename" attributes of the given files.
    mismatch is one of:

    None  - both paths refer to the same file, or the PCM data is identical
    >= 0  - the first mismatching PCM frame, as given by pcm_frame_cmp()
    -1    - sample rate, bits-per-sample or channel count differ
    -2    - an IOError, ValueError, KeyboardInterrupt or DecodingError
            occurred during the comparison

    progress is handed to audiotools.to_pcm_progress() for audiofile1."""

    def outcome(mismatch):
        return (audiofile1.filename, audiofile2.filename, mismatch)

    # compared in this order, stopping at the first difference
    stream_parameters = ("sample_rate", "bits_per_sample", "channels")

    try:
        # a file always matches itself, so skip decoding entirely
        if os.path.samefile(audiofile1.filename, audiofile2.filename):
            return outcome(None)

        if any(getter(audiofile1) != getter(audiofile2)
               for getter in map(operator.methodcaller, stream_parameters)):
            return outcome(-1)

        return outcome(
            audiotools.pcm_frame_cmp(
                audiotools.to_pcm_progress(audiofile1, progress),
                audiofile2.to_pcm()))
    except (IOError, ValueError, KeyboardInterrupt, audiotools.DecodingError):
        return outcome(-2)


def cmp_result(result, is_tty=False):
    """Formats a (path1, path2, mismatch) tuple as a line of output text

    result is a tuple in the form returned by cmp_files().
    is_tty is passed to the status text's format() method.

    The returned string is the "file1 <> file2" comparison label,
    followed by " : " and a status chosen from mismatch:
    None is the green OK text, a non-negative int is the red
    mismatch text with a 1-based frame number, -1 is the red
    parameter mismatch text and any other negative value is the
    red error text."""

    path1, path2, mismatch = result

    assert(isinstance(path1, str))
    assert(isinstance(path2, str))
    assert((mismatch is None) or isinstance(mismatch, int))

    # the label is the same regardless of the comparison's outcome
    comparison = _.LAB_TRACKCMP_CMP % {"file1": audiotools.Filename(path1),
                                       "file2": audiotools.Filename(path2)}

    if mismatch is None:
        label = _.LAB_TRACKCMP_OK
        color = "green"
    elif mismatch >= 0:
        # mismatch is 0-based but frames are displayed starting from 1
        label = _.LAB_TRACKCMP_MISMATCH % {"frame_number": mismatch + 1}
        color = "red"
    elif mismatch == -1:
        label = _.LAB_TRACKCMP_PARAM_MISMATCH
        color = "red"
    else:
        label = _.LAB_TRACKCMP_ERROR
        color = "red"

    status = audiotools.output_text(label, fg_color=color).format(is_tty)

    return comparison + u" : " + status


def cmp_result_tty(result):
    """Same as cmp_result(), but always formats for a TTY

    This is a single-argument callable, which is the form
    in which it is passed as a job's completion_output."""

    return cmp_result(result, True)


def image_compare(progress, image_audiofile, track_audiofile,
                  image_filename, track_filename,
                  pcm_frames_offset, total_pcm_frames):
    """Compares one track against its portion of a disc image

    Returns a (mismatch, image_filename, track_filename) tuple
    where mismatch is the result of audiotools.pcm_frame_cmp()
    or -2 if an IOError, ValueError, KeyboardInterrupt
    or DecodingError occurs.

    progress is passed to PCMReaderProgress for the track's reader
    image_audiofile is decoded and windowed to
    total_pcm_frames frames starting at pcm_frames_offset
    track_audiofile is decoded in full
    image_filename and track_filename are returned as-is"""

    try:
        # opening and seeking the image happen inside the "try" so that
        # a failure here is reported as -2, like any other read error
        image_pcmreader = image_audiofile.to_pcm()

        # if image_pcmreader has seek(),
        # use it to reduce the amount of frames to skip
        # (presumably seek() returns the frame position actually reached,
        #  leaving the remainder for PCMReaderWindow to skip - TODO confirm)
        if callable(getattr(image_pcmreader, "seek", None)):
            pcm_frames_offset -= image_pcmreader.seek(pcm_frames_offset)

        return (
            audiotools.pcm_frame_cmp(
                audiotools.PCMReaderWindow(image_pcmreader,
                                           pcm_frames_offset,
                                           total_pcm_frames),
                audiotools.PCMReaderProgress(track_audiofile.to_pcm(),
                                             total_pcm_frames,
                                             progress)),
            image_filename,
            track_filename)
    except (IOError, ValueError, KeyboardInterrupt, audiotools.DecodingError):
        return (-2, image_filename, track_filename)


def image_compare_raw(progress, source_filename,
                      sample_rate, channels, channel_mask, bits_per_sample,
                      track_audiofile, image_filename, track_filename,
                      pcm_frames_offset, total_pcm_frames):
    """Compares one track against its portion of a raw PCM file

    Returns a (mismatch, image_filename, track_filename) tuple
    where mismatch is the result of audiotools.pcm_frame_cmp()
    or -2 if an IOError, ValueError, KeyboardInterrupt
    or DecodingError occurs, matching image_compare().

    progress is passed to PCMReaderProgress for the track's reader
    source_filename is the path of the raw PCM file to read
    sample_rate, channels, channel_mask and bits_per_sample
    describe the raw data and are passed to PCMFileReader
    track_audiofile is decoded in full
    image_filename and track_filename are returned as-is
    pcm_frames_offset is the track's first frame in the raw file
    total_pcm_frames is the number of frames to compare"""

    try:
        with open(source_filename, "rb") as f:
            # skip initial offset
            # each PCM frame occupies one sample per channel
            f.seek(pcm_frames_offset * channels * (bits_per_sample // 8))

            return (
                audiotools.pcm_frame_cmp(
                    audiotools.PCMReaderHead(
                        audiotools.PCMFileReader(
                            file=f,
                            sample_rate=sample_rate,
                            channels=channels,
                            channel_mask=channel_mask,
                            bits_per_sample=bits_per_sample),
                        total_pcm_frames),
                    audiotools.PCMReaderProgress(track_audiofile.to_pcm(),
                                                 total_pcm_frames,
                                                 progress)),
                image_filename,
                track_filename)
    except (IOError, ValueError, KeyboardInterrupt, audiotools.DecodingError):
        # report read/decode failures the same way image_compare() does
        # instead of letting them escape from the comparison job
        return (-2, image_filename, track_filename)


def image_compare_results(result, is_tty=False):
    """Formats a (mismatch, image_name, track_name) tuple as output text

    result is a tuple in the form returned by image_compare().
    is_tty is passed to the status text's format() method.

    The returned string is the "image <> track" comparison label,
    followed by " : " and a status chosen from mismatch:
    None is the green OK text, a non-negative int is the red
    mismatch text with a 1-based frame number and a negative
    value (such as the -2 image_compare() returns on error)
    is the red error text."""

    (mismatch, image_name, track_name) = result

    if mismatch is None:
        status = audiotools.output_text(_.LAB_TRACKCMP_OK,
                                        fg_color="green")
    elif mismatch >= 0:
        # mismatch is 0-based but frames are displayed starting from 1
        status = audiotools.output_text(
            _.LAB_TRACKCMP_MISMATCH % {"frame_number": mismatch + 1},
            fg_color="red")
    else:
        # negative values are error codes rather than frame numbers,
        # so don't display them as a mismatch at "frame -1"
        status = audiotools.output_text(_.LAB_TRACKCMP_ERROR,
                                        fg_color="red")

    return u"%s : %s" % (_.LAB_TRACKCMP_CMP %
                         {"file1": audiotools.Filename(image_name),
                          "file2": audiotools.Filename(track_name)},
                         status.format(is_tty))


def image_compare_results_tty(result):
    """Same as image_compare_results(), but always formats for a TTY

    This is a single-argument callable, which is the form
    in which it is passed as a job's completion_output."""

    return image_compare_results(result, True)


if (__name__ == '__main__'):
    # Command-line entry point.  Depending on the arguments, this:
    #
    # - compares two files, staying silent if they match
    # - compares the contents of two directories, pairing files up
    #   by stream parameters and then by track/album number
    # - compares a disc image against the tracks it was split into
    #   (three or more arguments)
    #
    # The exit status is 1 on any mismatch or error.

    import argparse

    parser = argparse.ArgumentParser(description=_.DESCRIPTION_TRACKCMP)

    parser.add_argument("--version",
                        action="version",
                        version="Python Audio Tools %s" % (audiotools.VERSION))

    parser.add_argument("-V", "--verbose",
                        action="store",
                        dest="verbosity",
                        choices=audiotools.VERBOSITY_LEVELS,
                        default=audiotools.DEFAULT_VERBOSITY,
                        help=_.OPT_VERBOSE)

    parser.add_argument("-j", "--joint",
                        type=int,
                        default=audiotools.MAX_JOBS,
                        dest="max_processes",
                        help=_.OPT_JOINT)

    parser.add_argument("-S", "--no-summary",
                        action="store_true",
                        dest="no_summary",
                        help=_.OPT_NO_SUMMARY)

    parser.add_argument("filename",
                        metavar="PATH",
                        help=_.OPT_INPUT_FILENAME_OR_IMAGE)

    parser.add_argument("filenames",
                        metavar="PATH",
                        nargs="+",
                        help=_.OPT_INPUT_FILENAME_OR_DIR)

    options = parser.parse_args()

    msg = audiotools.Messenger(options.verbosity == "quiet")

    args = [options.filename] + options.filenames

    if options.max_processes < 1:
        msg.error(_.ERR_INVALID_JOINT)
        sys.exit(1)

    if len(args) == 2:
        if os.path.isfile(args[0]) and os.path.isfile(args[1]):
            # comparing two files

            opened = []

            for arg in args:
                filename = audiotools.Filename(arg)

                # any failure to open either file is fatal
                try:
                    audiofile = audiotools.open(arg)
                except audiotools.UnsupportedFile:
                    msg.error(_.ERR_UNSUPPORTED_FILE % (filename,))
                    sys.exit(1)
                except audiotools.InvalidFile as err:
                    msg.error(str(err))
                    sys.exit(1)
                except IOError:
                    msg.error(_.ERR_OPEN_IOERROR % (filename,))
                    sys.exit(1)

                if not audiofile.__class__.supports_to_pcm():
                    msg.error(_.ERR_UNSUPPORTED_TO_PCM %
                              {"filename": filename, "type": audiofile.NAME})
                    sys.exit(1)

                opened.append(audiofile)

            (audiofile_a, audiofile_b) = opened

            (path1, path2, mismatch) = cmp_files(None,
                                                 audiofile_a,
                                                 audiofile_b)

            # matching files generate no output at all
            if mismatch is not None:
                msg.output(cmp_result((path1, path2, mismatch),
                                      msg.output_isatty()))
                sys.exit(1)
        elif os.path.isdir(args[0]) and os.path.isdir(args[1]):
            # comparing two directories

            to_compare = []
            results = []

            dir1 = args[0]
            files1 = {}
            dir2 = args[1]
            files2 = {}

            # gather the decodable audio files of each directory,
            # keyed by filename
            for (files, directory) in [(files1, dir1), (files2, dir2)]:
                for name in os.listdir(directory):
                    path = os.path.join(directory, name)
                    if not os.path.isfile(path):
                        continue

                    try:
                        audiofile = audiotools.open(path)
                        if audiofile.__class__.supports_to_pcm():
                            files[audiofile.filename] = audiofile
                        else:
                            msg.warning(_.ERR_UNSUPPORTED_TO_PCM %
                                        {"filename": audiofile.filename,
                                         "type": audiofile.NAME})
                    except audiotools.UnsupportedFile:
                        # non-audio files are skipped silently
                        pass
                    except audiotools.InvalidFile as err:
                        msg.warning(str(err))
                    except IOError:
                        msg.warning(_.ERR_OPEN_IOERROR % (path,))

            # first, attempt to match files by their stream characteristics
            streams1 = {}
            streams2 = {}

            for (files, streams) in [(files1, streams1),
                                     (files2, streams2)]:
                for f in files.values():
                    streams.setdefault((f.bits_per_sample(),
                                        f.channels(),
                                        f.sample_rate(),
                                        f.total_frames()), []).append(f)

            # anything with matching specs
            # and only a single possible match per directory
            # is queued for comparison
            for specs in set(streams1.keys()) & set(streams2.keys()):
                if (((len(streams1[specs]) == 1) and
                     (len(streams2[specs]) == 1))):
                    file1 = streams1[specs][0]
                    file2 = streams2[specs][0]

                    # remove matched files from lists
                    del(files1[file1.filename])
                    del(files2[file2.filename])

                    # queue up comparison job
                    to_compare.append((file1, file2))

            # then, attempt to match leftover files by metadata
            # such as album_number and track_number
            metadatas1 = {}
            metadatas2 = {}

            for (files, metadatas) in [(files1, metadatas1),
                                       (files2, metadatas2)]:
                for f in files.values():
                    m = f.get_metadata()
                    if m is not None:
                        metadatas.setdefault((m.track_number,
                                              m.album_number), []).append(f)
                    else:
                        metadatas.setdefault((None,
                                              None), []).append(f)

            for metadata in set(metadatas1.keys()) & set(metadatas2.keys()):
                if (((len(metadatas1[metadata]) == 1) and
                     (len(metadatas2[metadata]) == 1))):
                    file1 = metadatas1[metadata][0]
                    file2 = metadatas2[metadata][0]

                    # remove matched files from lists
                    del(files1[file1.filename])
                    del(files2[file2.filename])

                    # queue up comparison job
                    to_compare.append((file1, file2))

            # anything left over is marked as a missing file
            # (a file left in one directory is missing from the other)
            for (files, directory) in [(files1, dir2), (files2, dir1)]:
                for filename in files:
                    msg.info(
                        audiotools.output_text(
                            _.LAB_TRACKCMP_MISSING %
                            {"filename": audiotools.Filename(
                             os.path.basename(filename)),
                             "directory": audiotools.Filename(directory)},
                            fg_color="red").format(msg.info_isatty()))
                    sys.stdout.flush()

                    # a non-None third field counts as a failure
                    # in the summary below
                    results.append((filename, None, 0))

            queue = audiotools.ExecProgressQueue(msg)

            for (track1, track2) in sorted(to_compare,
                                           key=lambda f: f[0].filename):
                queue.execute(
                    function=cmp_files,
                    progress_text=_.LAB_TRACKCMP_CMP %
                    {"file1": audiotools.Filename(track1.filename),
                     "file2": audiotools.Filename(track2.filename)},
                    completion_output=(cmp_result_tty
                                       if msg.output_isatty() else
                                       cmp_result),
                    audiofile1=track1,
                    audiofile2=track2)

            try:
                results.extend(queue.run(options.max_processes))
            except KeyboardInterrupt:
                msg.error(_.ERR_CANCELLED)
                sys.exit(1)

            # a result whose mismatch field is None is a success
            successes = len([r for r in results if r[2] is None])
            failures = len(results) - successes

            if not options.no_summary:
                msg.output(_.LAB_TRACKCMP_RESULTS)
                msg.output(u"")

                table = audiotools.output_table()
                row = table.row()
                row.add_column(_.LAB_TRACKCMP_HEADER_SUCCESS, "right")
                row.add_column(u" ")
                row.add_column(_.LAB_TRACKCMP_HEADER_FAILURE, "right")
                row.add_column(u" ")
                row.add_column(_.LAB_TRACKCMP_HEADER_TOTAL, "right")

                table.divider_row([_.DIV, u" ", _.DIV, u" ", _.DIV])

                row = table.row()
                row.add_column(u"%d" % (successes), "right")
                row.add_column(u" ")
                row.add_column(u"%d" % (failures), "right")
                row.add_column(u" ")
                row.add_column(u"%d" % (successes + failures), "right")

                for row in table.format(msg.output_isatty()):
                    msg.output(row)

            if failures > 0:
                sys.exit(1)
        else:
            # comparison mismatch
            # (the arguments are neither two files nor two directories)
            msg.error((_.LAB_TRACKCMP_CMP %
                       {"file1": audiotools.Filename(args[0]),
                        "file2": audiotools.Filename(args[1])}) +
                      u" : " +
                      audiotools.output_text(
                          _.LAB_TRACKCMP_TYPE_MISMATCH,
                          fg_color="red").format(msg.error_isatty()))
            sys.exit(1)
    elif len(args) > 2:
        # possibly comparing disk image against tracks
        audiofiles = []
        for arg in args:
            if os.path.isfile(arg):
                try:
                    audiofiles.append(audiotools.open(arg))
                except audiotools.UnsupportedFile:
                    pass
                except audiotools.InvalidFile as err:
                    msg.warning(str(err))
                except IOError:
                    msg.warning(_.ERR_OPEN_IOERROR % (arg,))

        # the longest file is taken to be the image
        # and all the shorter ones its tracks
        audiofiles.sort(key=lambda t: t.total_frames())

        # an image and at least one track are required,
        # and the tracks' lengths must add up to the image's length
        if ((len(audiofiles) < 2) or
            (sum(t.total_frames() for t in audiofiles[0:-1]) !=
             audiofiles[-1].total_frames())):
            msg.output(_.USAGE_TRACKCMP_CDIMAGE)
            sys.exit(1)

        cd_image = audiofiles[-1]
        tracks = audiofiles[0:-1]

        image_name = audiotools.Filename(cd_image.filename)

        # all tracks should have the same album number and track total
        tracks = audiotools.sorted_tracks(tracks)

        queue = audiotools.ExecProgressQueue(msg)

        # position of the current track's first frame within the image
        pcm_offset = 0

        if cd_image.seekable():
            for track in tracks:
                track_name = audiotools.Filename(track.filename)
                queue.execute(
                    function=image_compare,
                    progress_text=_.LAB_TRACKCMP_CMP %
                    {"file1": image_name, "file2": track_name},
                    completion_output=(image_compare_results_tty
                                       if msg.output_isatty()
                                       else image_compare_results),
                    image_audiofile=cd_image,
                    track_audiofile=track,
                    image_filename=str(image_name),
                    track_filename=str(track_name),
                    pcm_frames_offset=pcm_offset,
                    total_pcm_frames=track.total_frames())

                pcm_offset += track.total_frames()

            try:
                # every job must have returned None for its mismatch
                if ({r[0] for r in
                     queue.run(options.max_processes)} != {None}):
                    sys.exit(1)
            except KeyboardInterrupt:
                msg.error(_.ERR_CANCELLED)
                sys.exit(1)
        else:
            import tempfile

            # if file isn't seekable

            # decode it to a single PCM blob of binary data
            # which is closed when the "with" block is left by any route,
            # including sys.exit()
            with tempfile.NamedTemporaryFile() as temp_blob:
                cache_progress = audiotools.SingleProgressDisplay(
                    msg, _.LAB_CACHING_FILE)
                try:
                    audiotools.transfer_framelist_data(
                        audiotools.PCMReaderProgress(
                            cd_image.to_pcm(),
                            cd_image.total_frames(),
                            cache_progress.update),
                        temp_blob.write)
                except audiotools.DecodingError as err:
                    cache_progress.clear_rows()
                    msg.error(str(err))
                    sys.exit(1)
                except KeyboardInterrupt:
                    cache_progress.clear_rows()
                    msg.error(_.ERR_CANCELLED)
                    sys.exit(1)

                cache_progress.clear_rows()

                # the jobs re-open the blob by name,
                # so everything must be written out first
                temp_blob.flush()

                # compare the blob using multiple jobs
                for track in tracks:
                    track_name = audiotools.Filename(track.filename)
                    queue.execute(
                        function=image_compare_raw,
                        progress_text=_.LAB_TRACKCMP_CMP %
                        {"file1": image_name, "file2": track_name},
                        completion_output=(image_compare_results_tty
                                           if msg.output_isatty()
                                           else image_compare_results),
                        source_filename=temp_blob.name,
                        sample_rate=cd_image.sample_rate(),
                        channels=cd_image.channels(),
                        channel_mask=int(cd_image.channel_mask()),
                        bits_per_sample=cd_image.bits_per_sample(),
                        track_audiofile=track,
                        image_filename=str(image_name),
                        track_filename=str(track_name),
                        pcm_frames_offset=pcm_offset,
                        total_pcm_frames=track.total_frames())

                    pcm_offset += track.total_frames()

                try:
                    # every job must have returned None for its mismatch
                    if ({r[0] for r in
                         queue.run(options.max_processes)} != {None}):
                        sys.exit(1)
                except KeyboardInterrupt:
                    msg.error(_.ERR_CANCELLED)
                    sys.exit(1)
    else:
        msg.output(_.USAGE_TRACKCMP_FILES)
        sys.exit(1)
