diff --git a/src/metacoag/metacoag_runner.py b/src/metacoag/metacoag_runner.py index e9599ba..a7aad2d 100644 --- a/src/metacoag/metacoag_runner.py +++ b/src/metacoag/metacoag_runner.py @@ -322,9 +322,14 @@ def run(args): # Map original contig identifiers to contig identifiers of MEGAHIT assembly graph graph_to_contig_map = BidirectionalMap() - for (n, m), (n2, m2) in zip(graph_contigs.items(), original_contigs.items()): - if m == m2: - graph_to_contig_map[n] = n2 + # Build reverse lookup: sequence -> original contig name + original_seq_to_name = {} + for name, seq in original_contigs.items(): + original_seq_to_name[seq] = name + + for graph_name, graph_seq in graph_contigs.items(): + if graph_seq in original_seq_to_name: + graph_to_contig_map[graph_name] = original_seq_to_name[graph_seq] graph_to_contig_map_rev = graph_to_contig_map.inverse @@ -352,6 +357,12 @@ def run(args): abundance_file=abundance_file, ) + # Assign length 0 to graph-only contigs (in GFA but not in FASTA) + # so they are excluded by all downstream min_length checks + for i in range(node_count): + if i not in contig_lengths: + contig_lengths[i] = 0 + else: sequences, coverages, contig_lengths, n_samples = feature_utils.get_cov_len( contigs_file=contigs_file, @@ -652,6 +663,7 @@ def run(args): normalized_tetramer_profiles=normalized_tetramer_profiles, coverages=coverages, w_intra=w_intra, + nthreads=nthreads, ) # Get remaining contigs with single-copy marker genes which are not assigned to bins @@ -737,6 +749,7 @@ def run(args): coverages=coverages, depth=1, weight=w_intra, + nthreads=nthreads, ) logger.debug(f"Total number of binned contigs: {len(bin_of_contig)}") @@ -765,6 +778,7 @@ def run(args): coverages=coverages, depth=depth, weight=w_inter, + nthreads=nthreads, ) logger.debug(f"Total number of binned contigs: {len(bin_of_contig)}") @@ -891,6 +905,7 @@ def thread_function( coverages=coverages, depth=depth, weight=MAX_WEIGHT, + nthreads=nthreads, ) logger.debug(f"Total number of binned contigs: {len(bin_of_contig)}") diff --git a/src/metacoag/metacoag_utils/feature_utils.py b/src/metacoag/metacoag_utils/feature_utils.py index 482150f..1951ef7 100755 --- a/src/metacoag/metacoag_utils/feature_utils.py +++ b/src/metacoag/metacoag_utils/feature_utils.py @@ -98,19 +98,24 @@ def get_tetramer_profiles( else: kmer_inds_4, kmer_count_len_4 = compute_kmer_inds(4) + # Handle both list and dict sequences + if isinstance(sequences, dict): + seq_keys = list(sequences.keys()) + seq_values = [sequences[k] for k in seq_keys] + else: + seq_keys = list(range(len(sequences))) + seq_values = sequences + pool = Pool(nthreads) record_tetramers = pool.map( - count_kmers, [(seq, 4, kmer_inds_4, kmer_count_len_4) for seq in sequences] + count_kmers, [(seq, 4, kmer_inds_4, kmer_count_len_4) for seq in seq_values] ) pool.close() normalized = [x[1] for x in record_tetramers] - i = 0 - - for l in range(len(normalized)): - normalized_tetramer_profiles[i] = normalized[l] - i += 1 + for idx, key in enumerate(seq_keys): + normalized_tetramer_profiles[key] = normalized[idx] with open( f"{output_path}{contigs_file}.normalized_contig_tetramers.pickle", "wb" @@ -121,8 +126,8 @@ def get_tetramer_profiles( tetramer_profiles = {} - for i in range(len(normalized_tetramer_profiles)): - if contig_lengths[i] >= min_length: + for i in normalized_tetramer_profiles: + if i in contig_lengths and contig_lengths[i] >= min_length: tetramer_profiles[i] = normalized_tetramer_profiles[i] return tetramer_profiles @@ -186,19 +191,23 @@ def get_cov_len_megahit( i = 0 - sequences = [] + sequences = {} for index, record in enumerate(SeqIO.parse(contigs_file, "fasta")): + if record.id not in graph_to_contig_map_rev: + continue contig_num = contig_names_rev[graph_to_contig_map_rev[record.id]] length = len(record.seq) contig_lengths[contig_num] = length - sequences.append(str(record.seq)) + sequences[contig_num] = str(record.seq) i += 1 with open(abundance_file, "r") as my_abundance: for line in my_abundance: strings = line.strip().split("\t") + if strings[0] not in graph_to_contig_map_rev: + continue contig_num = contig_names_rev[graph_to_contig_map_rev[strings[0]]] if contig_lengths[contig_num] >= min_length: diff --git a/src/metacoag/metacoag_utils/label_prop_utils.py b/src/metacoag/metacoag_utils/label_prop_utils.py index 6d3cbd1..defecf9 100755 --- a/src/metacoag/metacoag_utils/label_prop_utils.py +++ b/src/metacoag/metacoag_utils/label_prop_utils.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 +import concurrent.futures import heapq import logging import math import sys +import numpy as np + from metacoag.metacoag_utils import matching_utils MAX_WEIGHT = sys.float_info.max @@ -40,6 +43,8 @@ def run_bfs_long( assembly_graph, normalized_tetramer_profiles, coverages, + bin_tetra_mat=None, + bin_cov_mat=None, ): # Search labelled long contigs using BFS @@ -60,39 +65,50 @@ def run_bfs_long( # Get the bin of the current contig contig_bin = bin_of_contig[active_node] - bin_log_prob = 0 + if bin_tetra_mat is not None and contig_bin in bin_tetra_mat: + # Vectorised path: one cdist call for all N seed members at once. + # Mathematically identical to the scalar loop below — + # same formula, same overflow-to-MAX_WEIGHT behaviour. + bin_log_prob = matching_utils._compute_edge_weight_exact( + normalized_tetramer_profiles[node], + coverages[node], + bin_tetra_mat[contig_bin], + bin_cov_mat[contig_bin], + ) + else: + bin_log_prob = 0 - log_prob_sum = 0 + log_prob_sum = 0 - n_contigs = smg_bin_counts[contig_bin] - bin_n_contigs = 0 + n_contigs = smg_bin_counts[contig_bin] + bin_n_contigs = 0 - for j in range(n_contigs): - tetramer_dist = matching_utils.get_tetramer_distance( - normalized_tetramer_profiles[node], - normalized_tetramer_profiles[bins[contig_bin][j]], - ) - prob_comp = matching_utils.get_comp_probability(tetramer_dist) - prob_cov = matching_utils.get_cov_probability( - coverages[node], coverages[bins[contig_bin][j]] - ) + for j in range(n_contigs): + tetramer_dist = matching_utils.get_tetramer_distance( + normalized_tetramer_profiles[node], + normalized_tetramer_profiles[bins[contig_bin][j]], + ) + prob_comp = matching_utils.get_comp_probability(tetramer_dist) + prob_cov = matching_utils.get_cov_probability( + coverages[node], coverages[bins[contig_bin][j]] + ) - prob_product = prob_comp * prob_cov + prob_product = prob_comp * prob_cov - log_prob = 0 + log_prob = 0 - if prob_product > 0.0: - log_prob = -(math.log(prob_comp, 10) + math.log(prob_cov, 10)) - bin_n_contigs += 1 - else: - log_prob = MAX_WEIGHT + if prob_product > 0.0: + log_prob = -(math.log(prob_comp, 10) + math.log(prob_cov, 10)) + bin_n_contigs += 1 + else: + log_prob = MAX_WEIGHT - log_prob_sum += log_prob + log_prob_sum += log_prob - if log_prob_sum != float("inf") and bin_n_contigs != 0: - bin_log_prob = log_prob_sum / bin_n_contigs - else: - bin_log_prob = MAX_WEIGHT + if log_prob_sum != float("inf") and bin_n_contigs != 0: + bin_log_prob = log_prob_sum / bin_n_contigs + else: + bin_log_prob = MAX_WEIGHT labelled_nodes.add( (node, active_node, contig_bin, depth[active_node], bin_log_prob) @@ -151,14 +167,15 @@ def run_bfs_short( def getClosestLongVertices(graph, node, binned_contigs, contig_lengths, min_length): + # binned_contigs must support O(1) membership tests (set or dict) queu_l = [graph.neighbors(node, mode="ALL")] - visited_l = [node] + visited_l = {node} unlabelled = [] while len(queu_l) > 0: active_level = queu_l.pop(0) is_finish = False - visited_l += active_level + visited_l.update(active_level) for n in active_level: if contig_lengths[n] >= min_length and n not in binned_contigs: @@ -167,15 +184,10 @@ def getClosestLongVertices(graph, node, binned_contigs, contig_lengths, min_leng if is_finish: return unlabelled else: - temp = [] + temp = set() for n in active_level: - temp += graph.neighbors(n, mode="ALL") - temp = list(set(temp)) - temp2 = [] - - for n in temp: - if n not in visited_l: - temp2.append(n) + temp.update(graph.neighbors(n, mode="ALL")) + temp2 = [n for n in temp if n not in visited_l] if len(temp2) > 0: queu_l.append(temp2) return unlabelled @@ -196,9 +208,11 @@ def label_prop( coverages, depth, weight, + nthreads=1, ): contigs_to_bin = set() + # Use bin_of_contig directly (dict) for O(1) membership in getClosestLongVertices for contig in bin_of_contig: if contig in non_isolated and contig_lengths[contig] >= min_length: closest_neighbours = filter( @@ -206,7 +220,7 @@ def label_prop( getClosestLongVertices( assembly_graph, contig, - list(bin_of_contig.keys()), + bin_of_contig, contig_lengths, min_length, ), @@ -214,32 +228,45 @@ def label_prop( contigs_to_bin.update(closest_neighbours) sorted_node_list = [] - sorted_node_list_ = [ - list( - run_bfs_long( - x, - depth, - bin_of_contig.keys(), - bin_of_contig, - bins, - smg_bin_counts, - assembly_graph, - normalized_tetramer_profiles, - coverages, - ) - ) - for x in contigs_to_bin - ] + # Build seed-member matrices once. smg_bin_counts is frozen before label + # propagation starts (computed from the initial seed bins) and never updated + # as new contigs are added — so bins[b][:smg_bin_counts[b]] is stable + # throughout the entire function, including the per-neighbour BFS calls + # inside the assignment loop. + _seed_tetra_mat = {} + _seed_cov_mat = {} + for _b in range(len(smg_bin_counts)): + _n = smg_bin_counts[_b] + _members = bins[_b][:_n] + _seed_tetra_mat[_b] = np.array([normalized_tetramer_profiles[c] for c in _members]) + _seed_cov_mat[_b] = np.array([coverages[c] for c in _members], dtype=float) + + # All BFS calls are independent (read-only data); run them in parallel. + _binned_view = bin_of_contig.keys() + def _bfs_long_worker_lp(x): + return list(run_bfs_long( + x, depth, _binned_view, bin_of_contig, bins, smg_bin_counts, + assembly_graph, normalized_tetramer_profiles, coverages, + bin_tetra_mat=_seed_tetra_mat, bin_cov_mat=_seed_cov_mat, + )) + with concurrent.futures.ThreadPoolExecutor(max_workers=nthreads) as pool: + sorted_node_list_ = list(pool.map(_bfs_long_worker_lp, contigs_to_bin)) sorted_node_list_ = [item for sublist in sorted_node_list_ for item in sublist] for data in sorted_node_list_: - heapObj = DataWrap(data) - heapq.heappush(sorted_node_list, heapObj) + heapq.heappush(sorted_node_list, DataWrap(data)) + + # Lazy-deletion set: contigs that already have a fresh BFS entry queued + stale = set() while sorted_node_list: best_choice = heapq.heappop(sorted_node_list) to_bin, binned, bin_, dist, cov_comp_diff = best_choice.data + # Skip stale entries whose neighbourhood has already been re-queued + if to_bin in stale: + continue + can_bin = False has_mg = False @@ -276,40 +303,42 @@ def label_prop( set(bin_markers[bin_] + contig_markers[to_bin]) ) - # Discover to_bin's neighbours + # Discover to_bin's neighbours; mark old entries stale instead of + # rebuilding the heap, then push fresh BFS results. unbinned_neighbours = set( filter( lambda x: contig_lengths[x] >= min_length, getClosestLongVertices( assembly_graph, to_bin, - list(bin_of_contig.keys()), + bin_of_contig, contig_lengths, min_length, ), ) ) - sorted_node_list = list( - filter(lambda x: x.data[0] not in unbinned_neighbours, sorted_node_list) - ) - heapq.heapify(sorted_node_list) + stale.update(unbinned_neighbours) for un in unbinned_neighbours: candidates = list( run_bfs_long( un, depth, - list(bin_of_contig.keys()), + bin_of_contig.keys(), bin_of_contig, bins, smg_bin_counts, assembly_graph, normalized_tetramer_profiles, coverages, + bin_tetra_mat=_seed_tetra_mat, + bin_cov_mat=_seed_cov_mat, ) ) for c in candidates: heapq.heappush(sorted_node_list, DataWrap(c)) + # Fresh entry is now queued; remove from stale so it can be processed + stale.discard(un) return bins, bin_of_contig, bin_markers, binned_contigs_with_markers @@ -414,9 +443,11 @@ def final_label_prop( coverages, depth, weight, + nthreads=1, ): contigs_to_bin = set() + # Use bin_of_contig directly (dict) for O(1) membership in getClosestLongVertices for contig in bin_of_contig: if contig_lengths[contig] >= min_length: closest_neighbours = filter( @@ -424,7 +455,7 @@ def final_label_prop( getClosestLongVertices( assembly_graph, contig, - list(bin_of_contig.keys()), + bin_of_contig, contig_lengths, min_length, ), @@ -432,32 +463,42 @@ def final_label_prop( contigs_to_bin.update(closest_neighbours) sorted_node_list = [] - sorted_node_list_ = [ - list( - run_bfs_long( - x, - depth, - bin_of_contig.keys(), - bin_of_contig, - bins, - smg_bin_counts, - assembly_graph, - normalized_tetramer_profiles, - coverages, - ) - ) - for x in contigs_to_bin - ] + # Build seed-member matrices once. Same rationale as in label_prop: smg_bin_counts + # is frozen so bins[b][:smg_bin_counts[b]] is stable throughout this function. + _seed_tetra_mat_flp = {} + _seed_cov_mat_flp = {} + for _b in range(len(smg_bin_counts)): + _n = smg_bin_counts[_b] + _members = bins[_b][:_n] + _seed_tetra_mat_flp[_b] = np.array([normalized_tetramer_profiles[c] for c in _members]) + _seed_cov_mat_flp[_b] = np.array([coverages[c] for c in _members], dtype=float) + + # All BFS calls are independent (read-only data); run them in parallel. + _binned_view_flp = bin_of_contig.keys() + def _bfs_long_worker_flp(x): + return list(run_bfs_long( + x, depth, _binned_view_flp, bin_of_contig, bins, smg_bin_counts, + assembly_graph, normalized_tetramer_profiles, coverages, + bin_tetra_mat=_seed_tetra_mat_flp, bin_cov_mat=_seed_cov_mat_flp, + )) + with concurrent.futures.ThreadPoolExecutor(max_workers=nthreads) as pool: + sorted_node_list_ = list(pool.map(_bfs_long_worker_flp, contigs_to_bin)) sorted_node_list_ = [item for sublist in sorted_node_list_ for item in sublist] for data in sorted_node_list_: - heapObj = DataWrap(data) - heapq.heappush(sorted_node_list, heapObj) + heapq.heappush(sorted_node_list, DataWrap(data)) + + # Lazy-deletion set: contigs that already have a fresh BFS entry queued + stale = set() while sorted_node_list: best_choice = heapq.heappop(sorted_node_list) to_bin, binned, bin_, dist, cov_comp_diff = best_choice.data + # Skip stale entries whose neighbourhood has already been re-queued + if to_bin in stale: + continue + has_mg = False if to_bin in contig_markers: @@ -484,39 +525,41 @@ def final_label_prop( set(bin_markers[bin_] + contig_markers[to_bin]) ) - # Discover to_bin's neighbours + # Discover to_bin's neighbours; mark old entries stale instead of + # rebuilding the heap, then push fresh BFS results. unbinned_neighbours = set( filter( lambda x: contig_lengths[x] >= min_length, getClosestLongVertices( assembly_graph, to_bin, - list(bin_of_contig.keys()), + bin_of_contig, contig_lengths, min_length, ), ) ) - sorted_node_list = list( - filter(lambda x: x.data[0] not in unbinned_neighbours, sorted_node_list) - ) - heapq.heapify(sorted_node_list) + stale.update(unbinned_neighbours) for un in unbinned_neighbours: candidates = list( run_bfs_long( un, depth, - list(bin_of_contig.keys()), + bin_of_contig.keys(), bin_of_contig, bins, smg_bin_counts, assembly_graph, normalized_tetramer_profiles, coverages, + bin_tetra_mat=_seed_tetra_mat_flp, + bin_cov_mat=_seed_cov_mat_flp, ) ) for c in candidates: heapq.heappush(sorted_node_list, DataWrap(c)) + # Fresh entry is now queued; remove from stale so it can be processed + stale.discard(un) return bins, bin_of_contig, bin_markers, binned_contigs_with_markers diff --git a/src/metacoag/metacoag_utils/marker_gene_utils.py b/src/metacoag/metacoag_utils/marker_gene_utils.py index 72097ba..105dee6 100755 --- a/src/metacoag/metacoag_utils/marker_gene_utils.py +++ b/src/metacoag/metacoag_utils/marker_gene_utils.py @@ -223,6 +223,8 @@ def get_contigs_with_marker_genes_megahit( # Contig name contig_name = "_".join(name_strings) + if contig_name not in graph_to_contig_map_rev: + continue contig_num = contig_names_rev[graph_to_contig_map_rev[contig_name]] contig_length = contig_lengths[contig_num] diff --git a/src/metacoag/metacoag_utils/matching_utils.py b/src/metacoag/metacoag_utils/matching_utils.py index ec19fdc..557fae8 100755 --- a/src/metacoag/metacoag_utils/matching_utils.py +++ b/src/metacoag/metacoag_utils/matching_utils.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 +import concurrent.futures import logging import math import operator import sys import networkx as nx +import numpy as np from scipy.spatial import distance +from scipy.special import gammaln __author__ = "Vijini Mallawaarachchi and Yu Lin" __copyright__ = "Copyright 2020, MetaCoAG Project" @@ -28,10 +31,9 @@ def normpdf(x, mean, sd): - var = float(sd) ** 2 - denom = sd * (2 * math.pi) ** 0.5 - num = math.exp(-((float(x) - float(mean)) ** 2) / (2 * var)) - return num / denom + # Vectorised via numpy; scalar inputs also work. + x = np.asarray(x, dtype=float) + return np.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * np.sqrt(2.0 * np.pi)) def get_tetramer_distance(seq1, seq2): @@ -45,34 +47,77 @@ def get_coverage_distance(cov1, cov2): def get_comp_probability(tetramer_dist): gaus_intra = normpdf(tetramer_dist, MU_INTRA, SIGMA_INTRA) gaus_inter = normpdf(tetramer_dist, MU_INTER, SIGMA_INTER) - return gaus_intra / (gaus_intra + gaus_inter) + return float(gaus_intra / (gaus_intra + gaus_inter)) def get_cov_probability(cov1, cov2): - poisson_prod_1 = 1 - poisson_prod_2 = 1 - - for i in range(len(cov1)): - # Adapted from http://www.masaers.com/2013/10/08/Implementing-Poisson-pmf.html - poisson_pmf_1 = math.exp( - (cov1[i] * math.log(cov2[i])) - math.lgamma(cov1[i] + 1.0) - cov2[i] - ) - - poisson_pmf_2 = math.exp( - (cov2[i] * math.log(cov1[i])) - math.lgamma(cov2[i] + 1.0) - cov1[i] - ) - - if poisson_pmf_1 < VERY_SMALL_DOUBLE: - poisson_pmf_1 = VERY_SMALL_DOUBLE - - if poisson_pmf_2 < VERY_SMALL_DOUBLE: - poisson_pmf_2 = VERY_SMALL_DOUBLE - - poisson_prod_1 = poisson_prod_1 * poisson_pmf_1 - - poisson_prod_2 = poisson_prod_2 * poisson_pmf_2 - - return min(poisson_prod_1, poisson_prod_2) + # Vectorised Poisson PMF computation. + # Adapted from http://www.masaers.com/2013/10/08/Implementing-Poisson-pmf.html + c1 = np.asarray(cov1, dtype=float) + c2 = np.asarray(cov2, dtype=float) + # Guard against log(0): replace zeros with a tiny positive value + safe_c1 = np.where(c1 > 0, c1, VERY_SMALL_DOUBLE) + safe_c2 = np.where(c2 > 0, c2, VERY_SMALL_DOUBLE) + log_pmf_1 = c1 * np.log(safe_c2) - gammaln(c1 + 1.0) - c2 + log_pmf_2 = c2 * np.log(safe_c1) - gammaln(c2 + 1.0) - c1 + pmf_1 = np.maximum(np.exp(log_pmf_1), VERY_SMALL_DOUBLE) + pmf_2 = np.maximum(np.exp(log_pmf_2), VERY_SMALL_DOUBLE) + return float(min(np.prod(pmf_1), np.prod(pmf_2))) + + +def _build_bin_matrices(bins, n_bins, normalized_tetramer_profiles, coverages): + """Stack bin members into numpy arrays for vectorised per-iteration scoring.""" + bin_tetra_mat = {} + bin_cov_mat = {} + for b in range(n_bins): + members = bins[b] + bin_tetra_mat[b] = np.array([normalized_tetramer_profiles[c] for c in members]) + bin_cov_mat[b] = np.array([coverages[c] for c in members], dtype=float) + return bin_tetra_mat, bin_cov_mat + + +def _compute_edge_weight_exact(tetra_contig, cov_contig, bin_tetra_mat, bin_cov_mat): + """Vectorised, exact equivalent of the original per-member scoring loop. + + Computes mean(-log10(p_comp_j) - log10(p_cov_j)) across all N bin members + using numpy/scipy in C — identical numerical result to the Python for-j loop, + including the float-overflow -> MAX_WEIGHT behaviour. + """ + # All N tetramer distances in one cdist sweep + dists = distance.cdist([tetra_contig], bin_tetra_mat, "euclidean")[0] # (N,) + + # Composition probabilities for all N members simultaneously + gi = np.exp(-0.5 * (dists / SIGMA_INTRA) ** 2) / (SIGMA_INTRA * np.sqrt(2.0 * np.pi)) + ge = np.exp(-0.5 * ((dists - MU_INTER) / SIGMA_INTER) ** 2) / (SIGMA_INTER * np.sqrt(2.0 * np.pi)) + prob_comp_vec = gi / (gi + ge) # (N,) + + # Coverage probabilities: vectorised Poisson PMF over all N members + c1 = np.asarray(cov_contig, dtype=float) # (S,) + mat_c = bin_cov_mat # (N, S) + s1 = np.where(c1 > 0, c1, VERY_SMALL_DOUBLE) + s2 = np.where(mat_c > 0, mat_c, VERY_SMALL_DOUBLE) + lp1 = c1 * np.log(s2) - gammaln(c1 + 1.0) - mat_c # (N, S) + lp2 = mat_c * np.log(s1) - gammaln(mat_c + 1.0) - c1 # (N, S) + prod1 = np.prod(np.maximum(np.exp(lp1), VERY_SMALL_DOUBLE), axis=1) # (N,) + prod2 = np.prod(np.maximum(np.exp(lp2), VERY_SMALL_DOUBLE), axis=1) # (N,) + prob_cov_vec = np.minimum(prod1, prod2) # (N,) + + # Per-member log probabilities — same formula as the original scalar path + prob_product_vec = prob_comp_vec * prob_cov_vec + mask = prob_product_vec > 0.0 + log_probs = np.where( + mask, + -(np.log10(np.where(mask, prob_comp_vec, 1.0)) + + np.log10(np.where(mask, prob_cov_vec, 1.0))), + MAX_WEIGHT, + ) # (N,) + + # Reproduce original overflow check: any MAX_WEIGHT entry pushes the sum + # to inf, causing the same MAX_WEIGHT result as the original loop. + log_prob_sum = float(np.sum(log_probs)) + if math.isinf(log_prob_sum): + return MAX_WEIGHT + return log_prob_sum / len(bin_tetra_mat) def match_contigs( @@ -122,6 +167,13 @@ def match_contigs( binned_count = 0 + # Build per-bin member matrices once per iteration so all members + # assigned in previous iterations are included. Rebuilt next + # iteration automatically. + bin_tetra_mat, bin_cov_mat = _build_bin_matrices( + bins, n_bins, normalized_tetramer_profiles, coverages + ) + if len(to_bin) != 0: for contig in to_bin: contigid = contig @@ -129,39 +181,17 @@ def match_contigs( if contigid not in top_nodes: top_nodes.append(contigid) - for b in range(n_bins): - log_prob_sum = 0 - n_contigs = len(bins[b]) - - for j in range(n_contigs): - tetramer_dist = get_tetramer_distance( - normalized_tetramer_profiles[contigid], - normalized_tetramer_profiles[bins[b][j]], - ) - prob_comp = get_comp_probability(tetramer_dist) - prob_cov = get_cov_probability( - coverages[contigid], coverages[bins[b][j]] - ) - - prob_product = prob_comp * prob_cov - - log_prob = 0 - - if prob_product > 0.0: - log_prob = -( - math.log(prob_comp, 10) + math.log(prob_cov, 10) - ) - else: - log_prob = MAX_WEIGHT - - log_prob_sum += log_prob + tetra_contig = normalized_tetramer_profiles[contigid] + cov_contig = coverages[contigid] - if log_prob_sum != float("inf"): - edges.append( - (bins[b][0], contigid, log_prob_sum / n_contigs) - ) - else: - edges.append((bins[b][0], contigid, MAX_WEIGHT)) + for b in range(n_bins): + edge_weight = _compute_edge_weight_exact( + tetra_contig, + cov_contig, + bin_tetra_mat[b], + bin_cov_mat[b], + ) + edges.append((bins[b][0], contigid, edge_weight)) B.add_nodes_from(top_nodes, bipartite=0) B.add_nodes_from(bottom_nodes, bipartite=1) @@ -195,15 +225,15 @@ def match_contigs( my_matching[l] not in bins[b] and (l, my_matching[l]) in edge_weights ): - path_len_sum = 0 - - for contig_in_bin in bins[b]: - shortest_paths = assembly_graph.get_shortest_paths( - my_matching[l], to=contig_in_bin - ) - - if len(shortest_paths) != 0: - path_len_sum += len(shortest_paths[0]) + # Batch all targets in the bin into a single + # igraph distances() call (one BFS sweep). + all_paths = assembly_graph.distances( + my_matching[l], target=bins[b] + ) + # distances() returns a 2-D list; row 0 for our source + path_len_sum = sum( + d for d in all_paths[0] if d != float("inf") + ) avg_path_len = math.floor(path_len_sum / len(bins[b])) @@ -289,19 +319,15 @@ def match_contigs( ] if longest_nb_contig != -1: - path_len_sum = 0 - - for contig_in_bin in bins[not_binned[longest_nb_contig][1]]: - shortest_paths = assembly_graph.get_shortest_paths( - longest_nb_contig, to=contig_in_bin - ) - - if len(shortest_paths) != 0: - path_len_sum += len(shortest_paths[0]) - - avg_path_len = path_len_sum / len( - bins[not_binned[longest_nb_contig][1]] + target_bin = not_binned[longest_nb_contig][1] + all_paths = assembly_graph.distances( + longest_nb_contig, target=bins[target_bin] ) + path_len_sum = sum( + d for d in all_paths[0] if d != float("inf") + ) + + avg_path_len = path_len_sum / len(bins[target_bin]) if math.floor(avg_path_len) >= d_limit or path_len_sum == 0: logger.debug("Creating new bin...") @@ -345,6 +371,35 @@ def match_contigs( return bins, bin_of_contig, n_bins, bin_markers, binned_contigs_with_markers +def _score_contig_against_bins_exact(args): + """Worker: exact vectorised scoring of one contig against its candidate bins. + + Uses _compute_edge_weight_exact (numpy/cdist) — identical result to the + original Python for-j loop, no approximation. + Returns (contigid, best_bin_index, best_weight) or None. + """ + ( + contigid, + possible_bins, + tetra_contig, + cov_contig, + bin_tetra_mat, + bin_cov_mat, + w_intra, + ) = args + + bin_weights = [ + _compute_edge_weight_exact(tetra_contig, cov_contig, bin_tetra_mat[b], bin_cov_mat[b]) + for b in possible_bins + ] + + min_b_index, min_b_value = min(enumerate(bin_weights), key=operator.itemgetter(1)) + + if min_b_value <= w_intra: + return (contigid, possible_bins[min_b_index], min_b_value) + return None + + def further_match_contigs( unbinned_mg_contigs, min_length, @@ -357,69 +412,53 @@ def further_match_contigs( normalized_tetramer_profiles, coverages, w_intra, + nthreads=1, ): + # Build per-bin member matrices once. The scoring phase is read-only so all + # workers can share these safely. Assignments happen after scoring completes, + # so no incremental updates are needed. + bin_tetra_mat, bin_cov_mat = _build_bin_matrices( + bins, len(bins), normalized_tetramer_profiles, coverages + ) + + # Build work items for the parallel scoring phase. + work_items = [] for contig in unbinned_mg_contigs: - if contig[1] >= min_length: - possible_bins = [] - - for b in bin_markers: - common_mgs = list( - set(bin_markers[b]).intersection(set(contig_markers[contig[0]])) - ) - if len(common_mgs) == 0: - possible_bins.append(b) - - if len(possible_bins) != 0: - contigid = contig[0] - - bin_weights = [] - - for b in possible_bins: - log_prob_sum = 0 - n_contigs = len(bins[b]) - - for j in range(n_contigs): - tetramer_dist = get_tetramer_distance( - normalized_tetramer_profiles[contigid], - normalized_tetramer_profiles[bins[b][j]], - ) - prob_comp = get_comp_probability(tetramer_dist) - prob_cov = get_cov_probability( - coverages[contigid], coverages[bins[b][j]] - ) - - prob_product = prob_comp * prob_cov - - log_prob = 0 - - if prob_product > 0.0: - log_prob = -( - math.log(prob_comp, 10) + math.log(prob_cov, 10) - ) - else: - log_prob = MAX_WEIGHT - - log_prob_sum += log_prob - - if log_prob_sum != float("inf"): - bin_weights.append(log_prob_sum / n_contigs) - else: - bin_weights.append(MAX_WEIGHT) - - min_b_index, min_b_value = min( - enumerate(bin_weights), key=operator.itemgetter(1) - ) - - if min_b_value <= w_intra: - bins[possible_bins[min_b_index]].append(contigid) - bin_of_contig[contigid] = possible_bins[min_b_index] - binned_contigs_with_markers.append(contigid) - - bin_markers[possible_bins[min_b_index]] = list( - set( - bin_markers[possible_bins[min_b_index]] - + contig_markers[contigid] - ) - ) + if contig[1] < min_length: + continue + contigid = contig[0] + contig_mg_set = set(contig_markers[contigid]) + possible_bins = [ + b for b in bin_markers if not set(bin_markers[b]).intersection(contig_mg_set) + ] + if not possible_bins: + continue + work_items.append(( + contigid, + possible_bins, + normalized_tetramer_profiles[contigid], + coverages[contigid], + bin_tetra_mat, + bin_cov_mat, + w_intra, + )) + + # Score all contigs in parallel; exact per-member computation in each worker. + with concurrent.futures.ProcessPoolExecutor(max_workers=nthreads) as executor: + results = list(executor.map(_score_contig_against_bins_exact, work_items)) + + for result in results: + if result is None: + continue + contigid, best_bin, _ = result + # Guard: contig may appear in multiple work items + if contigid in bin_of_contig: + continue + bins[best_bin].append(contigid) + bin_of_contig[contigid] = best_bin + binned_contigs_with_markers.append(contigid) + bin_markers[best_bin] = list( + set(bin_markers[best_bin] + contig_markers[contigid]) + ) return bins, bin_of_contig, n_bins, bin_markers, binned_contigs_with_markers