#!/usr/bin/env python

import os
import sys
import argparse
from operator import itemgetter

# Values of the GFF3 type column (3rd field) that are treated as gene records.
type_gene = [
	"gene",
	"snRNA_gene",
	"transposable_element_gene",
	"ncRNA_gene",
	"telomerase_RNA_gene",
	"rRNA_gene",
	"tRNA_gene",
	"snoRNA_gene",
	"mt_gene",
	"miRNA_gene",
	"lincRNA_gene",
	"RNA",
	"VD_gene_segment",
]

# Values of the type column that are treated as transcript records.
type_transcript = [
	"transcript",
	"primary_transcript",
	"mRNA",
	"ncRNA",
	"tRNA",
	"rRNA",
	"snRNA",
	"snoRNA",
	"miRNA",
	"pseudogenic_transcript",
	"lincRNA",
	"NMD_transcript_variant",
	"aberrant_processed_transcript",
	"nc_primary_transcript",
	"processed_pseudogene",
	"mRNA_TE_gene",
]

# Values of the type column that contribute exon intervals to a transcript.
type_exon = [
	"exon",
	"CDS",
	"five_prime_UTR",
	"three_prime_UTR",
	"UTR",
	"noncoding_exon",
	"pseudogenic_exon",
]

# Types that may denote either a gene or a transcript; the main loop decides
# per record: no Parent attribute means gene, otherwise transcript.
type_gene_or_transcript = [
	"pseudogene",
	"V_gene_segment",
	"C_gene_segment",
	"J_gene_segment",
	"processed_transcript",
]


class HelpOnErrorParser(argparse.ArgumentParser):
	"""ArgumentParser that prints the full help text when argument parsing fails."""

	def error(self, msg):
		"""Write the error to stderr, print the help message, and exit with status -1."""
		script_name = os.path.basename(sys.argv[0])
		report = "{0}: error: {1}\n\n".format(script_name, msg)
		sys.stderr.write(report)
		self.print_help()
		sys.exit(-1)


def my_assert(bool, msg):
	"""Abort the conversion with an error message unless the condition holds.

	When the condition is false, msg is written to stderr, the output GTF
	file named on the command line is deleted (best effort), and the
	process exits with status -1. When it is true, nothing happens.
	"""
	if bool:
		return

	sys.stderr.write(msg + "\n")
	try:
		# Remove the partially written GTF; a missing file is not an error.
		os.remove(args.output_GTF_file)
	except OSError:
		pass
	sys.exit(-1)


class Feature:
	"""One parsed GFF3 record.

	parse() fills seqid, source, original_type, feature_type, start, end,
	strand and the raw attributes string from a single input line;
	parseAttributes() then builds attribute_dict from that string, and
	getAttribute() looks values up in it.
	"""

	def gen_type_dict(self):
		"""Return a dict mapping each known type string to its category.

		The categories are "gene", "transcript", "exon" and
		"gene_or_transcript", taken from the module-level type_* tables.
		"""
		my_dict = {}
		for my_type in type_gene:
			my_dict[my_type] = "gene"
		for my_type in type_transcript:
			my_dict[my_type] = "transcript"
		for my_type in type_exon:
			my_dict[my_type] = "exon"

		for my_type in type_gene_or_transcript:
			my_dict[my_type] = "gene_or_transcript"

		return my_dict

	def __init__(self):
		self.type_dict = self.gen_type_dict()

	def parse(self, line, line_no):
		"""Split one tab-delimited record into its fields.

		line should be free of leading and trailing spaces. The process
		exits through my_assert if the line does not have exactly 9 fields
		or if the start/end columns are not integers.
		"""
		self.line = line
		self.line_no = line_no

		fields = line.split('\t')
		my_assert(len(fields) == 9, "Line {0} does not have 9 fields:\n{1}".format(self.line_no, self.line))

		self.seqid = fields[0]
		self.source = fields[1]
		self.original_type = fields[2]
		# None when the type appears in none of the type_* tables.
		self.feature_type = self.type_dict.get(fields[2], None)
		try:
			self.start = int(fields[3])
			self.end = int(fields[4])
		except ValueError:
			# Report through my_assert so the partial output file is removed,
			# instead of dying with a traceback.
			my_assert(False, "Line {0} has a non-integer start or end coordinate:\n{1}".format(self.line_no, self.line))
		self.strand = fields[6]
		# Drop a single trailing ';' so it does not produce an empty attribute.
		self.attributes = fields[8][:-1] if fields[8].endswith(';') else fields[8]

	def parseAttributes(self):
		"""Build attribute_dict from the ';'-separated tag=value pairs.

		The Parent value is stored as a list (split on ','); every other
		value is stored as a string. Exits through my_assert if an attribute
		is not of the form tag=value.
		"""
		self.attribute_dict = {}
		for attribute in self.attributes.split(';'):
			fields = attribute.split('=')
			my_assert(len(fields) == 2, "Fail to parse attribute {0} of line {1}:\n{2}".format(attribute, self.line_no, self.line))
			tag, value = fields
			if tag == "Parent":
				self.attribute_dict[tag] = value.split(',')
			else:
				self.attribute_dict[tag] = value

	def getAttribute(self, tag, required = False):
		"""Return the value stored for tag, or None if the tag is absent.

		When required is true, a missing tag terminates the process through
		my_assert. parseAttributes() must have been called first.
		"""
		value = self.attribute_dict.get(tag, None)
		my_assert(not required or value is not None, "Line {0} does not have attribute {1}:\n{2}".format(self.line_no, tag, self.line))
		return value


class Transcript:
	"""Collects the transcript record and exon intervals for one transcript ID.

	Iterating over an instance yields the merged (start, end) intervals
	computed by the most recent call to merge().
	"""

	def __init__(self, tid, feature):
		"""Create a transcript from the first record that mentions tid.

		Only seqid and strand are copied from feature here; gid, tname,
		ttype and source are filled in by setTranscript().
		"""
		self.tid = tid
		self.tname = self.ttype = None
		self.gid = self.gname = None
		self.setT = False # if a transcript feature has been set

		self.seqid = feature.seqid
		self.source = None
		self.strand = feature.strand

		self.intervals = []
		# Initialised here so that iterating before merge() yields nothing
		# instead of raising AttributeError.
		self.results = []
		self.index = 0

	def setTranscript(self, feature):
		"""Record the transcript-level information from its own GFF3 record.

		Exits through my_assert if the transcript record was already seen or
		if it does not have exactly one parent.
		"""
		my_assert(not self.setT, 
			"Transcript {0} appears multiple times! Last occurrence is at line {1}:\n{2}".format(self.tid, feature.line_no, feature.line))
		self.setT = True
		parents = feature.getAttribute("Parent", True)
		my_assert(len(parents) == 1, "Transcript {0} at line {1} has more than one parents:\n{2}".format(self.tid, feature.line_no, feature.line))
		self.gid = parents[0]
		self.tname = feature.getAttribute("Name")
		self.ttype = feature.original_type
		self.source = feature.source

	def addExon(self, feature):
		"""Append the (start, end) interval of an exon-like feature."""
		self.intervals.append((feature.start, feature.end))

	def merge(self):
		"""Sort the intervals by start and merge overlapping or adjacent ones.

		The merged list is stored in self.results; an empty interval list
		produces an empty result.
		"""
		self.intervals.sort(key = itemgetter(0))
		self.results = []
		if not self.intervals:
			return

		cstart, cend = self.intervals[0]
		for start, end in self.intervals[1:]:
			# "+ 1" also joins intervals that touch without overlapping,
			# e.g. (1, 5) and (6, 8).
			if cend + 1 >= start:
				cend = max(cend, end)
			else:
				self.results.append((cstart, cend))
				cstart = start
				cend = end
		self.results.append((cstart, cend))

	def __iter__(self):
		self.index = 0
		return self

	def next(self):
		"""Return the next merged interval (Python 2 iterator protocol)."""
		if self.index == len(self.results):
			raise StopIteration
		interval = self.results[self.index]
		self.index += 1
		return interval

	# Python 3 looks up __next__; without this alias a for-loop over a
	# Transcript raises TypeError.
	__next__ = next


def getTranscript(tid, feature):
	"""Return the buffered Transcript for tid, creating it on first sight.

	A new Transcript is appended to the global transcripts list and its
	index is recorded in tid2pos. For an existing entry, the process exits
	through my_assert if the transcript was already flushed (negative
	index) or if feature's seqid/strand disagree with the stored ones.
	"""
	assert tid is not None

	index = tid2pos.get(tid)
	if index is None:
		# First record mentioning this ID: register it at the end of the buffer.
		transcript = Transcript(tid, feature)
		tid2pos[tid] = len(transcripts)
		transcripts.append(transcript)
		return transcript

	my_assert(index >= 0, 
		"Line {0} describes an already processed Transcript {1}:\n{2}".format(feature.line_no, tid, feature.line))
	transcript = transcripts[index]
	same_location = transcript.seqid == feature.seqid and transcript.strand == feature.strand
	my_assert(same_location, 
			"Line {0}'s seqid/strand is not consistent with other records of transcript {1}:\n{2}".format(
				feature.line_no, tid, feature.line))
	return transcript

def flush_out(fout):
	"""Write all buffered transcripts to fout as GTF exon lines and clear the buffer.

	Every buffered transcript ID is marked as processed in tid2pos. A
	transcript is skipped when its own transcript record was never seen,
	when it has no exon intervals, or when RNA patterns were given and its
	type is not among them. For each transcript written, num_trans is
	incremented. Exits through my_assert if a transcript's parent gene is
	unknown.
	"""
	global transcripts
	global num_trans

	for transcript in transcripts:
		# -1 marks the ID as flushed so getTranscript rejects later records for it.
		tid2pos[transcript.tid] = -1
		if not transcript.setT or len(transcript.intervals) == 0 or (len(patterns) > 0 and transcript.ttype not in patterns):
			continue

		my_assert(transcript.gid in gid2gname, 
			"Cannot recognize transcript {0}'s parent {1}, a gene feature might be missing.".format(transcript.tid, transcript.gid))

		transcript.gname = gid2gname[transcript.gid]

		transcript.merge()

		# Build the constant parts of the line once. They are concatenated
		# around the coordinates rather than used as a format template, so a
		# '{' or '}' inside an ID or name cannot break or alter the output.
		line_head = "{0}\t{1}\texon\t".format(transcript.seqid, transcript.source)
		line_tail = "\t.\t{0}\t.\tgene_id \"{1}\"; transcript_id \"{2}\";".format(
			transcript.strand, transcript.gid, transcript.tid)
		if transcript.gname is not None:
			line_tail += " gene_name \"{0}\";".format(transcript.gname)
		if transcript.tname is not None:
			line_tail += " transcript_name \"{0}\";".format(transcript.tname)
		line_tail += "\n"

		# merge() stored the merged intervals in transcript.results.
		for start, end in transcript.results:
			fout.write("{0}{1}\t{2}{3}".format(line_head, start, end, line_tail))

		num_trans += 1

	transcripts = []



# ---- Command line -----------------------------------------------------------
parser = HelpOnErrorParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter, description = "Convert GFF3 files to GTF files.")
parser.add_argument("input_GFF3_file", help = "Input GFF3 file.")
parser.add_argument("output_GTF_file", help = "Output GTF file.")
parser.add_argument("--RNA-patterns", help = "Types of RNAs to be extracted, e.g. mRNA,rRNA", metavar = "<patterns>")
parser.add_argument("--extract-sequences", help = "If GFF3 file contains reference sequences, extract them to the specified file", metavar = "<output.fa>")
args = parser.parse_args()

# Transcript types to keep; an empty set means no filtering.
patterns = set()
if args.RNA_patterns is not None:
	patterns = set(args.RNA_patterns.split(','))

# ---- Shared state used by getTranscript() and flush_out() ---------------------
line_no = 0
feature = Feature()

# gene ID -> value of the gene's Name attribute (None if it has none)
gid2gname = {}

# transcript ID -> index into transcripts, or -1 once it has been flushed
tid2pos = {}
transcripts = []

num_trans = 0

reachFASTA = False

# ---- Conversion -------------------------------------------------------------
with open(args.input_GFF3_file) as fin:
	# The with-block closes the GTF file even if the loop ends with an exception.
	with open(args.output_GTF_file, "w") as fout:
		for line in fin:
			line = line.strip()
			line_no += 1
			if line_no % 100000 == 0:
				print("Loaded {} lines".format(line_no))

			if line.startswith("##FASTA"):
				# Annotation part is over; the rest of fin is sequence data.
				reachFASTA = True
				break

			if line.startswith("###"):
				flush_out(fout)
				continue

			# Skip comments/directives, and blank lines, which would otherwise
			# fail the 9-field check in Feature.parse.
			if not line or line.startswith("#"):
				continue

			feature.parse(line, line_no)
			if feature.feature_type is None:
				continue
			feature.parseAttributes()

			if feature.feature_type == "gene_or_transcript":
				# Ambiguous types: a record with a Parent is a transcript,
				# one without is a gene.
				parent = feature.getAttribute("Parent")
				if parent is None:
					feature.feature_type = "gene"
				else:
					feature.feature_type = "transcript"

			if feature.feature_type == "gene":
				gid = feature.getAttribute("ID", True)
				my_assert(gid not in gid2gname, 
					"Gene {0} appears multiple times! Last occurrence is at line {1}:\n{2}".format(gid, feature.line_no, feature.line))
				gid2gname[gid] = feature.getAttribute("Name")
			elif feature.feature_type == "transcript":
				transcript = getTranscript(feature.getAttribute("ID", True), feature)
				transcript.setTranscript(feature)
			else:
				assert feature.feature_type == "exon"
				for parent in feature.getAttribute("Parent", True):
					transcript = getTranscript(parent, feature)
					transcript.addExon(feature)

		flush_out(fout)

	print("GTF file is successully generated.")
	print("There are {0} transcripts contained in the generated GTF file.".format(num_trans))

	if reachFASTA and args.extract_sequences is not None:
		# fin is positioned just after the ##FASTA line; copy the remainder verbatim.
		with open(args.extract_sequences, "w") as fasta_out:
			for line in fin:
				fasta_out.write(line)
		print("FASTA file is successfully generated.")