#! /usr/bin/env python
# sga-joinedpe - join the two ends of a paired-end read together
# into one composite sequence. This is used to run rmdup on whole pairs.
import getopt
import sys

def writeJoinedRecord(out, record1, record2):
    """Write the two reads of a pair to `out` as one composite record.

    out     -- writable file-like object that receives the joined record
    record1 -- first read as a list: [header, sequence] for fasta or
               [header, sequence, meta, quality] for fastq
    record2 -- second read, same layout as record1

    The joined header is the first whitespace-delimited token of the
    read header with its last two characters removed, followed by
    " join_pos:N" where N is the length of the first read's sequence.
    writeSplitRecord uses that tag to undo the join.

    Exits with status 2 if the two reads do not share the same name once
    the last two characters have been removed.
    """
    # Strip the last two characters from the name (writeSplitRecord puts
    # "/1" and "/2" back, so these are presumably the mate suffixes).
    join_name_1 = record1[0].split()[0][:-2]
    join_name_2 = record2[0].split()[0][:-2]
    if join_name_1 != join_name_2:
        # Diagnostics go to stderr: `out` may be stdout and must only
        # ever contain sequence records.
        sys.stderr.write('Error: Reads are incorrectly paired %s %s\n'
                         % (record1[0], record2[0]))
        sys.exit(2)

    join_pos = len(record1[1])
    join_seq = record1[1] + record2[1]

    # Write ID and sequence
    out.write(join_name_1 + " join_pos:" + str(join_pos) + "\n")
    out.write(join_seq + "\n")

    if len(record1) > 2:
        # Fastq mode, output quality and meta field too. The meta line of
        # the joined record is always a bare "+".
        out.write("+\n")
        out.write(record1[3] + record2[3] + "\n")

def writeSplitRecord(out, record):
    """Split a joined record back into its two reads and write both.

    out    -- writable file-like object that receives the two records
    record -- joined record as a list: [header, sequence] for fasta or
              [header, sequence, meta, quality] for fastq. The second
              whitespace-delimited field of the header must carry a
              "join_pos:N" tag giving the length of the first read.

    The two output records are named <name>/1 and <name>/2, where <name>
    is the first field of the joined header. In fastq mode the meta line
    of the joined record is repeated for both reads.

    Exits with status 2 if the join_pos tag is missing or unparsable, or
    if it points outside the sequence or quality string.
    """
    header_fields = record[0].split()

    # Locate and parse the join position. -1 marks "not found / invalid".
    join_pos = -1
    if len(header_fields) > 1 and "join_pos" in header_fields[1]:
        try:
            join_pos = int(header_fields[1].split(":")[1])
        except (IndexError, ValueError):
            join_pos = -1

    # Explicit checks instead of assert, which is stripped under -O.
    valid = 0 <= join_pos <= len(record[1])
    if valid and len(record) == 4:
        valid = join_pos <= len(record[3])
    if not valid:
        # Diagnostics go to stderr: `out` may be stdout and must only
        # ever contain sequence records.
        sys.stderr.write('Error: joined record does not have the correct '
                         'format %s\n' % record[0])
        sys.exit(2)

    id1 = header_fields[0] + "/1"
    id2 = header_fields[0] + "/2"
    seq1 = record[1][:join_pos]
    seq2 = record[1][join_pos:]

    is_fastq = len(record) == 4
    qual1 = ""
    qual2 = ""
    if is_fastq:
        qual1 = record[3][:join_pos]
        qual2 = record[3][join_pos:]

    # Write the first record
    out.write(id1 + "\n")
    out.write(seq1 + "\n")
    if is_fastq:
        out.write(record[2] + "\n")
        out.write(qual1 + "\n")

    # Write the second record
    out.write(id2 + "\n")
    out.write(seq2 + "\n")
    if is_fastq:
        out.write(record[2] + "\n")
        out.write(qual2 + "\n")

# Parse a record out of a fasta or fastq file
def readRecord(file):
    """Read the next fasta or fastq record from an open file.

    file -- readable file-like object positioned at the start of a record

    Returns a list of right-stripped lines: [header, sequence] when the
    header starts with ">" (fasta) and [header, sequence, meta, quality]
    when it starts with "@" (fastq). Returns an empty list at end of
    input. Each record is expected to keep its sequence on a single line.
    If the input ends in the middle of a record, the missing lines come
    back as empty strings.

    Blank lines in front of a header (for example trailing newlines at
    the end of the file) are skipped. Exits with status 2 if the header
    starts with any other character.
    """
    # Skip whitespace-only lines; an empty string from readline() is EOF.
    line = file.readline()
    while line != "" and line.strip() == "":
        line = file.readline()
    if line == "":
        return []  # done read

    header = line.rstrip()
    if header[0] == ">":
        line_count = 2
    elif header[0] == "@":
        line_count = 4
    else:
        sys.stderr.write('Unknown record header: %s\n' % header)
        sys.exit(2)

    record = [header]
    for _ in range(line_count - 1):
        record.append(file.readline().rstrip())
    return record

#
def usage():
    """Print the command line help text to standard output."""
    lines = [
        # The synopsis shows the mode switch because one is mandatory.
        'sga-joinedpe [-o outfile] (--join | --split) FILE',
        'allow rmdup to be performed on entire paired end reads',
        'if the --join option is provided, two consecutive records in FILE will be concatenated into one read.',
        'if the --split option is provided, each record in FILE will be split into two',
        '',
        'Options: ',
        '          -o, --outfile=FILE    Write results to FILE [default=stdout]',
        '              --join            Join the reads together',
        '              --split           Split the reads apart',
        '              --help            Print this usage message',
    ]
    sys.stdout.write("\n".join(lines) + "\n")

# Defaults
outfilename = ""
join = False
split = False

# Parse command line arguments. gnu_getopt allows options and the
# positional FILE argument to appear in any order.
try:
    opts, args = getopt.gnu_getopt(sys.argv[1:], 'o:', ["outfile=",
                                                        "split",
                                                        "join",
                                                        "help"])
except getopt.GetoptError as err:
    # Report the parse error on stderr, then show the help text.
    sys.stderr.write(str(err) + "\n")
    usage()
    sys.exit(2)

for (oflag, oarg) in opts:
    if oflag == '-o' or oflag == '--outfile':
        outfilename = oarg
    elif oflag == '--split':
        split = True
    elif oflag == "--join":
        join = True
    elif oflag == '--help':
        # Asking for help is not an error.
        usage()
        sys.exit(0)
    else:
        # Defensive only: getopt already rejects options it was not told
        # about, so this branch should not be reachable.
        sys.stderr.write('Unrecognized argument %s\n' % oflag)
        usage()
        sys.exit(2)

# Validate the parsed command line: exactly one input file and exactly
# one of the two modes. Every failure is reported on stderr and exits
# with status 2.
if len(args) != 1:
    sys.stderr.write('Error: Exactly one input file must be provided\n')
    sys.exit(2)

if not join and not split:
    sys.stderr.write('Error: One of --join or --split must be provided\n')
    sys.exit(2)

if join and split:
    sys.stderr.write('Error: Only one of --join or --split can be provided\n')
    sys.exit(2)

filename = args[0]

# Open the input before the output so that a missing or unreadable input
# file does not leave a freshly truncated output file behind.
infile = open(filename)
try:
    outfile = sys.stdout
    if outfilename != "":
        outfile = open(outfilename, "w")

    try:
        if join:
            # Consume the input two records at a time and emit one joined
            # record per pair.
            while True:
                record1 = readRecord(infile)
                record2 = readRecord(infile)

                # Also catches an unpaired final read: record2 is then
                # empty while record1 is not.
                if len(record1) != len(record2):
                    sys.stderr.write('Error, record lengths differ\n')
                    sys.exit(2)

                # If no record was read, we are done
                if not record1:
                    break

                writeJoinedRecord(outfile, record1, record2)
        elif split:
            # Emit two records for every joined record in the input.
            while True:
                record = readRecord(infile)

                # If no record was read, we are done
                if not record:
                    break

                writeSplitRecord(outfile, record)
    finally:
        # Flush and close a real output file even when exiting on an
        # error; never close stdout.
        if outfile is not sys.stdout:
            outfile.close()
finally:
    infile.close()
