diff --git a/gtranslate/training/ground_truth_by_taxonomy.py b/gtranslate/training/ground_truth_by_taxonomy.py index b914378..6f6d278 100755 --- a/gtranslate/training/ground_truth_by_taxonomy.py +++ b/gtranslate/training/ground_truth_by_taxonomy.py @@ -9,7 +9,9 @@ # # This function requires a taxonomy file that indicates the taxonomic assignment of each genome. This # can either be a 2 column TSV file with the headers "Genome ID" and "Taxonomy", or a 3 column TSV -# file with the headers "Genome ID", "GTDB taxonomy", and "NCBI taxonomy". +# file with the headers "Genome ID", "GTDB taxonomy", and "NCBI taxonomy". Taxonomy strings must be +# in Greengenes-style and indicate all ranks, e.g. +# d__Bacteria;p__Bacillota;c__Bacilli;o__Bacillales;f__Bacillaceae;g__Bacillus;s__Bacillus subtilis __prog_name__ = 'ground_truth_by_taxonomy.py' __prog_desc__ = 'Determine the ground truth for genomes based on their taxonomic classification.' @@ -18,7 +20,7 @@ __copyright__ = 'Copyright 2026' __credits__ = ['Donovan Parks'] __license__ = 'GPL3' -__version__ = '0.1.0' +__version__ = '0.1.1' __maintainer__ = 'Donovan Parks' __email__ = 'donovan.parks@gmail.com' __status__ = 'Development' @@ -26,6 +28,7 @@ import logging import argparse +import gzip from collections import defaultdict from gtranslate.biolib_lite.logger import logger_setup @@ -41,10 +44,15 @@ def __init__(self): self.GTDB_TT25 = set(['c__JAEDAM01']) self.GTDB_TT4 = set(['o__Mycoplasmatales', 's__Zinderia insecticola']) - # Eggerthellacea genera using table 4; will need to be updated to names ub Parks et al., 2026 + # Eggerthellacea genera using table 4; will need to be updated to names in Parks et al., 2026 # once these appear in GTDB self.GTDB_TT4.update(set(['g__CAVGFB01', 'g__JAUNQF01'])) + # Minisyncoccia family identified in gTranslate manuscript that uses table 4. The majority, but not + # all genomes in g__GCA-2747955 were also identified as using table 4. Currently, this is handled by + # explicitly indicating the species in this genus identified as using table 4. + self.GTDB_TT4.update(set(['f__JAKLIH01', 's__GCA-2747955 sp027024305', 's__GCA-2747955 sp027039745', 's__GCA-2747955 sp947311625'])) + # Must include the Fastidiosibacteraceae XS4 species cluster once (if) this genome appears in GTDB: # - https://www.ncbi.nlm.nih.gov/nuccore/AP038919.1 # - https://pmc.ncbi.nlm.nih.gov/articles/PMC12213064 @@ -58,13 +66,40 @@ def __init__(self): self.log = logging.getLogger('timestamp') - def run(self, taxonomy_file: str, out_file: str) -> None: + def parse_manual_ground_truth_file(self, manual_gt_file: str) -> dict: + """Parse manual ground truth file.""" + + manual_ground_truth = {} + open_file = gzip.open if manual_gt_file.endswith('.gz') else open + with open_file(manual_gt_file, 'rt') as f: + header = f.readline().strip().split('\t') + gid_idx = header.index("Genome ID") + gt_idx = header.index("Translation table") + + for line in f: + tokens = line.strip().split('\t') + manual_ground_truth[tokens[gid_idx]] = tokens[gt_idx] + + return manual_ground_truth + + def run(self, taxonomy_file: str, manual_gt_file: str, out_file: str) -> None: """Determine the ground truth for genomes based on their taxonomic classification.""" + # read files with manually specific ground truth + manual_ground_truth = {} + if manual_gt_file: + self.log.info('Parsing manual ground truth file:') + manual_ground_truth = self.parse_manual_ground_truth_file(manual_gt_file) + self.log.info(f' - identified manual ground truth for {len(manual_ground_truth):,} genomes') + + # determine ground truth for genomes based on their taxonomic classification + self.log.info('Determining ground truth for genomes:') total_genomes = 0 gt_table_count = defaultdict(int) + num_by_manual_gt = 0 - with open(taxonomy_file) as f: + open_file = gzip.open if taxonomy_file.endswith('.gz') else open + with open_file(taxonomy_file, 'rt') as f: header = f.readline().strip().split('\t') if 'Genome ID' in header: @@ -100,35 +135,40 @@ def run(self, taxonomy_file: str, out_file: str) -> None: total_genomes += 1 - # determine ground truth translation table based on GTDB - # or NCBI taxonomic classification of genome - taxa = set() - if taxonomy_idx: - taxa.update(set(tokens[taxonomy_idx].split(';'))) - - gtdb_taxa = set() - if gtdb_taxonomy_idx: - gtdb_taxa.update(set(tokens[gtdb_taxonomy_idx].split(';'))) + gid = tokens[gid_idx] - ncbi_taxa = set() - if ncbi_taxonomy_idx: - ncbi_taxa.update(set(tokens[ncbi_taxonomy_idx].split(';'))) - - if taxa.intersection(self.GTDB_TT25) or gtdb_taxa.intersection(self.GTDB_TT25): - ground_truth_tt = '25' - elif taxa.intersection(self.GTDB_TT4) or gtdb_taxa.intersection(self.GTDB_TT4): - ground_truth_tt = '4' - elif taxa.intersection(self.NCBI_TT4) or ncbi_taxa.intersection(self.NCBI_TT4): - ground_truth_tt = '4' - elif taxa.intersection(self.GTDB_UNRESOLVED) or gtdb_taxa.intersection(self.GTDB_UNRESOLVED): - ground_truth_tt = 'UNRESOLVED' + if gid in manual_ground_truth: + ground_truth_tt = manual_ground_truth[gid] + num_by_manual_gt += 1 else: - ground_truth_tt = '11' + # determine ground truth translation table based on GTDB + # or NCBI taxonomic classification of genome + taxa = set() + if taxonomy_idx: + taxa.update(set(tokens[taxonomy_idx].split(';'))) + + gtdb_taxa = set() + if gtdb_taxonomy_idx: + gtdb_taxa.update(set(tokens[gtdb_taxonomy_idx].split(';'))) + + ncbi_taxa = set() + if ncbi_taxonomy_idx: + ncbi_taxa.update(set(tokens[ncbi_taxonomy_idx].split(';'))) + + if taxa.intersection(self.GTDB_TT25) or gtdb_taxa.intersection(self.GTDB_TT25): + ground_truth_tt = '25' + elif taxa.intersection(self.GTDB_TT4) or gtdb_taxa.intersection(self.GTDB_TT4): + ground_truth_tt = '4' + elif taxa.intersection(self.NCBI_TT4) or ncbi_taxa.intersection(self.NCBI_TT4): + ground_truth_tt = '4' + elif taxa.intersection(self.GTDB_UNRESOLVED) or gtdb_taxa.intersection(self.GTDB_UNRESOLVED): + ground_truth_tt = 'UNRESOLVED' + else: + ground_truth_tt = '11' gt_table_count[ground_truth_tt] += 1 # write out ground truth results - gid = tokens[gid_idx] fout.write(f'{gid}\t{ground_truth_tt}') if taxonomy_idx: @@ -145,7 +185,9 @@ def run(self, taxonomy_file: str, out_file: str) -> None: fout.close() # write out number of genomes assigned to each translation table - self.log.info(f'Total genomes: {total_genomes:,}') + self.log.info(f' - determined ground truth for {total_genomes:,} genomes') + if manual_gt_file: + self.log.info(f' - ground truth set manually for {num_by_manual_gt:,} genomes') for tran_table, genome_count in sorted(gt_table_count.items()): self.log.info(f'Table {tran_table}: {genome_count:,} ({100*genome_count/total_genomes:.2f}%)') @@ -158,6 +200,7 @@ def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--taxonomy_file', required=True, help='File indicating taxonomic classification of each genome.') parser.add_argument('--out_file', required=True, help='Output file to write ground truth translation table.') + parser.add_argument('--manual_gt_file', help='File indicating manually specific ground truth for select genomes.') args = parser.parse_args() @@ -166,7 +209,7 @@ def main(): # run program p = GroundTruthByTaxonomy() - p.run(args.taxonomy_file, args.out_file) + p.run(args.taxonomy_file, args.manual_gt_file, args.out_file) if __name__ == '__main__':