include: "Common.snake"

import os
import os.path

from scripts.common import dump_dict

# Path of the saved graph pack for the chosen K (the "before repeat resolution"
# stage); read_binning joins it onto a sample's assembly directory.
SAVES = "K{}/saves/01_before_repeat_resolution/graph_pack".format(K)

# Runs once before the workflow starts: prepares the scratch directory and
# prints a short summary of the detected dataset.
onstart:
    # Creating "tmp" stays best-effort: an already existing directory is fine,
    # and any other failure is reported rather than silently swallowed (rules
    # that need the directory will then fail with their own error).
    try:
        os.makedirs("tmp", exist_ok=True)
    except OSError as err:
        print("Warning: could not create the tmp directory:", err)
    print("Detected", SAMPLE_COUNT, "samples in", IN)
    print("They form: ", GROUPS)

# ---- Main pipeline -----------------------------------------------------------

# Default target: one reassembled FASTA per CAG (the {cag} values are resolved
# through dynamic()).
rule all:
    message: "Dataset of {SAMPLE_COUNT} samples from {IN} has been processed."
    input:   dynamic("reassembly/{cag}.fasta")

# Assemble the reads of one sample with SPAdes (--meta) and copy the resulting
# scaffolds to assembly/<sample>.fasta.
rule assemble:
    input:
        left=left_reads,
        right=right_reads
    output:
        "assembly/{sample}.fasta"
    #TODO: remove this boilerplate
    # params.left / params.right prefix every read file with -1 / -2.
    params:
        left=lambda wc: " ".join(expand("-1 {r}", r=left_reads(wc))),
        right=lambda wc: " ".join(expand("-2 {r}", r=right_reads(wc))),
        dir="assembly/{sample}"
    log:
        "assembly/{sample}.log"
    threads:
        THREADS
    message:
        "Assembling {wildcards.sample} with SPAdes"
    shell:
        "{SPADES}/spades.py --meta -m 400 -t {threads}"
        " {params.left} {params.right} -o {params.dir} >{log} 2>&1"
        " && cp {params.dir}/scaffolds.fasta {output}"

# Aggregate target: requests the assembly of every entry in GROUPS.
rule assemble_all:
    message: "Assembled all samples"
    input:   expand("assembly/{sample}.fasta", sample=GROUPS)

# Write profile/<sample>.desc for every sample: the value returned by
# left_reads on the first line and by right_reads on the second.
rule descriptions:
    output:  expand("profile/{sample}.desc", sample=SAMPLES)
    message: "Generating sample descriptions"
    run:
        for name in SAMPLES:
            desc_path = "profile/{}.desc".format(name)
            with open(desc_path, "w") as desc:
                # The read getters take a wildcards object, so the "sample"
                # wildcard is set by hand for the sample being described.
                wildcards.sample = name
                for getter in (left_reads, right_reads):
                    print(getter(wildcards), file=desc)

# Run kmc on one sample's description file, producing the temporary
# tmp/<sample>.kmc_pre and tmp/<sample>.kmc_suf files.
rule kmc:
    input:   "profile/{sample}.desc"
    output:  temp("tmp/{sample}.kmc_pre"), temp("tmp/{sample}.kmc_suf")
    params:  min_mult=2, tmp="tmp/{sample}_kmc", out="tmp/{sample}"
    log:     "profile/kmc_{sample}.log"
    threads: THREADS
    message: "Running kmc for {wildcards.sample}"
    # mkdir -p: the scratch directory is only removed after a successful run,
    # so a rerun after a failure must tolerate a leftover directory instead of
    # aborting on "File exists".
    shell:   "mkdir -p {params.tmp} && "
             "{SOFT}/kmc -k{SMALL_K} -t{threads} -ci{params.min_mult} -cs65535"
             " @{input} {params.out} {params.tmp} >{log} 2>&1 && "
             "rm -rf {params.tmp}"

# Merge the per-sample kmc databases found in tmp/ into profile/kmers.kmm with
# kmer_multiplicity_counter.
rule multiplicities:
    input:   expand("tmp/{sample}.kmc_pre", sample=SAMPLES), expand("tmp/{sample}.kmc_suf", sample=SAMPLES)
    output:  "profile/kmers.kmm"
    params:  kmc_files=" ".join(expand("tmp/{sample}", sample=SAMPLES)), out="profile/kmers"
    log:     "profile/kmers.log"
    # The command passes {threads} to the tool; without this directive the
    # value would always be Snakemake's default of 1.
    threads: THREADS
    message: "Gathering {SMALL_K}-mer multiplicities from all samples"
    # rm -f: do not fail an otherwise successful job when no *.sorted
    # intermediates are left to clean up.
    shell:   "{BIN}/kmer_multiplicity_counter -n {SAMPLE_COUNT} -k {SMALL_K} -s 3"
             " -f tmp -t {threads} -o {params.out} >{log} 2>&1 && "
             "rm -f tmp/*.sorted"

# Run contig_abundance_counter on one assembly, producing profile/<sample>.id,
# profile/<sample>.mpl and the split contigs in assembly/<sample>_splits.fasta.
rule profile:
    input:   contigs="assembly/{sample}.fasta", mpl="profile/kmers.kmm"
    output:  id="profile/{sample}.id", mpl="profile/{sample}.mpl", splits= "assembly/{sample}_splits.fasta"
    # Snakemake ignores wildcard constraints written inside input patterns, so
    # the constraint is declared at rule level where it applies to all outputs.
    # It keeps this rule from claiming files such as
    # assembly/samples_splits.fasta, which belongs to combine_splits.
    wildcard_constraints: sample="\w+\d+"
    log:     "profile/{sample}.log"
    message: "Counting contig abundancies for {wildcards.sample}"
    shell:   "{BIN}/contig_abundance_counter -k {SMALL_K} -w tmp -c {input.contigs}"
             " -n {SAMPLE_COUNT} -m profile/kmers -o profile/{wildcards.sample}"
             " -f {output.splits} -l {MIN_CONTIG_LENGTH} >{log} 2>&1"

# Convert the per-group profiles into the input file expected by the binner
# named by the {binner} wildcard.
rule binning_pre:
    input:   expand("profile/{sample}.id", sample=GROUPS)
    output:  "binning/{binner}/profiles.in"
    # Space-separated group names, passed as positional arguments.
    params:  " ".join(GROUPS.keys())
    message: "Preparing input for {wildcards.binner}"
    shell:   "{SCRIPTS}/make_input.py -t {wildcards.binner} -d profile"
             " -o {output} {params}"

# Run the canopy clustering binary (cc.bin) on the prepared profiles, producing
# the binning output and the bin profiles.
rule canopy:
    input:   "binning/canopy/profiles.in"
    output:  out="binning/canopy/binning.out", prof="binning/canopy/bins.prof"
    # Same path as before, but declared through log: like the other rules
    # instead of being hard-coded in the command.
    log:     "binning/canopy/canopy.log"
    threads: THREADS
    message: "Running canopy clustering"
    shell:   "{SOFT}/cc.bin -n {threads} -i {input} -o {output.out} -c {output.prof} >{log} 2>&1"

# Concatenate the split contigs of all groups into a single FASTA by
# redirecting the output of combine_contigs.py.
rule combine_splits:
    message: "Combine splitted contigs"
    input:   expand("assembly/{sample}_splits.fasta", sample=GROUPS)
    output:  "assembly/samples_splits.fasta"
    shell:   "{SCRIPTS}/combine_contigs.py -r {input}"
             " > {output}"

# Cluster the combined split contigs with CONCOCT inside the concoct_env
# environment.
# NOTE(review): the "gt1000" part of the output name is not chosen here; it is
# presumably CONCOCT's own naming for its default contig length threshold of
# 1000 -- confirm against the CONCOCT documentation.
rule concoct:
    input:
        contigs=rules.combine_splits.output[0],
        profiles="binning/concoct/profiles.in"
    output:
        out="binning/concoct/clustering_gt1000.csv"
    params:
        "binning/concoct"
    message:
        "Running CONCOCT clustering"
    # "set +u ... set -u" relaxes the unset-variable check only while the
    # activate script is sourced.
    shell:
        "mkdir -p {params}"
        " && set +u; source activate concoct_env; set -u"
        " && concoct --composition_file {input.contigs}"
        " --coverage_file {input.profiles} -b {params}"

# Clustering result file of each supported binner; binning_post indexes this
# mapping with BINNER.
binning_inputs = dict(canopy=rules.canopy.output.out,
                      concoct=rules.concoct.output.out)

# Turn the selected binner's output into one annotation file per group under
# annotation/.
rule binning_post:
    message: "Preparing raw annotations"
    input:   binning_inputs[BINNER]
    output:  expand("annotation/{sample}.ann", sample=GROUPS)
    shell:   "{SCRIPTS}/parse_output.py -t {BINNER}"
             " -o annotation {input}"

#Post-clustering pipeline
# Run prop_binning for one sample, using its contigs, its annotation, its reads
# and the saved assembly graph.
rule read_binning:
    input:
        contigs="assembly/{sample}.fasta",
        ann="annotation/{sample}.ann",
        left=left_reads,
        right=right_reads
    output:
        "propagation/{sample}_edges.ann"
    params:
        saves=os.path.join("assembly/{sample}/", SAVES),
        splits="assembly/{sample}_splits.fasta",
        out="propagation/{sample}_edges",
        group=lambda wc: GROUPS[wc.sample]
        #left=" ".join(input.left), right=" ".join(input.right)
    log:
        "binning/{sample}.log"
    message:
        "Propagating annotation & binning reads for {wildcards.sample}"
    shell:
        "{BIN}/prop_binning -k {K} -s {params.saves}"
        " -c {input.contigs} -n {params.group}"
        " -l {input.left} -r {input.right}"
        " -a {input.ann} -f {params.splits}"
        " -o binning -d {params.out} >{log} 2>&1"

#TODO: bin profiles for CONCOCT
# Run choose_samples.py on the canopy bin profiles; declares the per-CAG
# left/right FASTQ files under binning/<cag>/ as dynamic outputs.
rule choose_samples:
    input:
        binned=expand("propagation/{sample}_edges.ann", sample=GROUPS),
        prof=rules.canopy.output.prof
    output:
        dynamic("binning/{cag}/left.fastq"),
        dynamic("binning/{cag}/right.fastq")
    log:
        "binning/choose_samples.log"
    message:
        "Choosing samples for all CAGs"
    shell:
        "{SCRIPTS}/choose_samples.py {input.prof} binning/"
        " >{log} 2>&1"

# Write the config file (reassembly/<cag>.yaml) that the reassemble rule passes
# to SPAdes via --series-analysis.
rule reassembly_config:
    input:   "binning/{cag}/left.fastq"
    output:  "reassembly/{cag}.yaml"
    message: "Generated config file for reassembly of {wildcards.cag}"
    run:
        with open(output[0], "w") as outfile:
            cag = wildcards.cag
            # Keys are inserted in a fixed order so the dumped file keeps a
            # stable layout.
            conf = {}
            conf["k"] = SMALL_K
            conf["sample_cnt"] = SAMPLE_COUNT
            conf["kmer_mult"] = str(rules.multiplicities.params.out)
            conf["bin"] = cag
            conf["bin_prof"] = str(rules.canopy.output.prof)
            conf["edges_sqn"] = "profile/{}_edges.fasta".format(cag)
            conf["edges_mpl"] = "profile/{}_edges.mpl".format(cag)
            conf["edge_fragments_mpl"] = "profile/{}_edges_frag.mpl".format(cag)
            conf["frag_size"] = 10000
            conf["min_len"] = 100
            dump_dict(conf, outfile)

# Reassemble the reads selected for one CAG with the reassembly build of SPAdes
# and copy the resulting scaffolds to reassembly/<cag>.fasta.
rule reassemble:
    input:
        left="binning/{cag}/left.fastq",
        right="binning/{cag}/right.fastq",
        config="reassembly/{cag}.yaml"
    output:
        "reassembly/{cag}.fasta"
    params:
        "reassembly/reassembly_{cag}"
    log:
        "reassembly/reassembly_{cag}.log"
    threads:
        THREADS
    message:
        "Reassembling reads for {wildcards.cag}"
    shell:
        "{SPADES_REASSEMBLY}/spades.py --meta -t {threads}"
        " --pe1-1 {input.left} --pe1-2 {input.right} --pe1-ff"
        " -o {params} --series-analysis {input.config} >{log} 2>&1"
        " && cp {params}/scaffolds.fasta {output}"
