diff --git a/scripts/gtdb_training_set.py b/scripts/gtdb_training_set.py index e1b3aa2..dd68f67 100755 --- a/scripts/gtdb_training_set.py +++ b/scripts/gtdb_training_set.py @@ -23,6 +23,7 @@ import logging import argparse import gzip +import random from collections import defaultdict from dataclasses import dataclass @@ -65,25 +66,37 @@ def __init__(self): # genomes to retain without needing to pass QC self.QC_EXCEPTIONS = set(['G015697925']) # s__Pinguicoccus supinus representative genome + # maximum number of genomes to select per species + self.MAX_PER_SPECIES = 100 + self.log = logging.getLogger('timestamp') - def parse_metadata_file(self, gtdb_metadata_file: str) -> dict: + def parse_metadata_file(self, gtdb_metadata_file: str) -> tuple[dict, dict]: """Parse manual ground truth file.""" genome_data = {} + gtdb_sp_rid = {} open_file = gzip.open if gtdb_metadata_file.endswith('.gz') else open with open_file(gtdb_metadata_file, 'rt') as f: header = f.readline().strip().split('\t') + gid_idx = header.index("accession") gtdb_taxonomy_idx = header.index('gtdb_taxonomy') ncbi_taxonomy_idx = header.index('ncbi_taxonomy') + gtdb_rep_idx = header.index('gtdb_representative') + for line in f: tokens = line.strip().split('\t') gid = canonical_gid(tokens[gid_idx]) genome_data[gid] = GenomeData(tokens[gtdb_taxonomy_idx], tokens[ncbi_taxonomy_idx]) - return genome_data + is_rep = tokens[gtdb_rep_idx].lower().startswith('t') + if is_rep: + gtdb_sp = [t.strip() for t in tokens[gtdb_taxonomy_idx].split(';')][-1] + gtdb_sp_rid[gtdb_sp] = gid + + return genome_data, gtdb_sp_rid def parse_taxonomy_file(self, taxonomy_file: str) -> dict: """Parse taxonomy file.""" @@ -175,12 +188,33 @@ def run(self, gtdb_bac_metadata_file: str, # parse genomes from GTDB metadata files self.log.info('Parsing GTDB metadata files:') - gtdb_bac_gids = self.parse_metadata_file(gtdb_bac_metadata_file) - gtdb_ar_gids = self.parse_metadata_file(gtdb_ar_metadata_file) + gtdb_bac_gids, gtdb_bac_sp_rid = 
self.parse_metadata_file(gtdb_bac_metadata_file) + gtdb_ar_gids, gtdb_ar_sp_rid = self.parse_metadata_file(gtdb_ar_metadata_file) gtdb_gids = {**gtdb_bac_gids, **gtdb_ar_gids} - self.log.info(f' - identified {len(gtdb_bac_gids):,} bacterial genomes') - self.log.info(f' - identified {len(gtdb_ar_gids):,} archaeal genomes') - self.log.info(f' - identified {len(gtdb_gids):,} total genomes') + gtdb_sp_rid = {**gtdb_bac_sp_rid, **gtdb_ar_sp_rid} + self.log.info(f' - identified {len(gtdb_bac_gids):,} bacterial genomes from {len(gtdb_bac_sp_rid):,} species') + self.log.info(f' - identified {len(gtdb_ar_gids):,} archaeal genomes from {len(gtdb_ar_sp_rid):,} species') + self.log.info(f' - identified {len(gtdb_gids):,} total genomes from {len(gtdb_sp_rid):,} species') + + # subsample to maximum number of genomes per species + self.log.info(f'Sampling to {self.MAX_PER_SPECIES} genomes per species:') + + sp_gids = defaultdict(set) + for gid, genome_data in gtdb_gids.items(): + gtdb_sp = [t.strip() for t in genome_data.gtdb_taxonomy.split(';')][-1] + sp_gids[gtdb_sp].add(gid) + + gtdb_gids_sampled = {} + for sp, gids in sp_gids.items(): + if len(gids) > self.MAX_PER_SPECIES: + rid = gtdb_sp_rid[sp] + gids.remove(rid) + gids = [rid] + random.sample(list(gids), self.MAX_PER_SPECIES - 1) + + for gid in gids: + gtdb_gids_sampled[gid] = gtdb_gids[gid] + + self.log.info(f' - retained {len(gtdb_gids_sampled):,} genomes for training') # read NCBI taxonomy for genomes self.log.info('Parsing NCBI taxonomy file:') @@ -198,7 +232,7 @@ def run(self, gtdb_bac_metadata_file: str, # combine all training genomes self.log.info('Combining all genomes to use for training:') - training_gids = {**gtdb_gids, **gtdb_reassigned_gids} + training_gids = {**gtdb_gids_sampled, **gtdb_reassigned_gids} self.log.info(f' - identified {len(training_gids):,} total genomes for training') # sanity check that all genomes that are a QC exception are accounted for