Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions gtranslate/training/ground_truth_by_taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#
# This function requires a taxonomy file that indicates the taxonomic assignment of each genome. This
# can either be a 2 column TSV file with the headers "Genome ID" and "Taxonomy", or a 3 column TSV
# file with the headers "Genome ID", "GTDB taxonomy", and "NCBI taxonomy".
# file with the headers "Genome ID", "GTDB taxonomy", and "NCBI taxonomy". Taxonomy strings must be
# in Greengenes-style and indicate all ranks, e.g.
# d__Bacteria;p__Bacillota;c__Bacilli;o__Bacillales;f__Bacillaceae;g__Bacillus;s__Bacillus subtilis

__prog_name__ = 'ground_truth_by_taxonomy.py'
__prog_desc__ = 'Determine the ground truth for genomes based on their taxonomic classification.'
Expand All @@ -18,14 +20,15 @@
__copyright__ = 'Copyright 2026'
__credits__ = ['Donovan Parks']
__license__ = 'GPL3'
__version__ = '0.1.0'
__version__ = '0.1.1'
__maintainer__ = 'Donovan Parks'
__email__ = 'donovan.parks@gmail.com'
__status__ = 'Development'


import logging
import argparse
import gzip
from collections import defaultdict

from gtranslate.biolib_lite.logger import logger_setup
Expand All @@ -41,10 +44,15 @@ def __init__(self):
self.GTDB_TT25 = set(['c__JAEDAM01'])
self.GTDB_TT4 = set(['o__Mycoplasmatales', 's__Zinderia insecticola'])

# Eggerthellacea genera using table 4; will need to be updated to names ub Parks et al., 2026
# Eggerthellaceae genera using table 4; will need to be updated to names in Parks et al., 2026
# once these appear in GTDB
self.GTDB_TT4.update(set(['g__CAVGFB01', 'g__JAUNQF01']))

# Minisyncoccia family identified in gTranslate manuscript that uses table 4. The majority, but not
# all genomes in g__GCA-2747955 were also identified as using table 4. Currently, this is handled by
# explicitly indicating the species in this genus identified as using table 4.
self.GTDB_TT4.update(set(['f__JAKLIH01', 's__GCA-2747955 sp027024305', 's__GCA-2747955 sp027039745', 's__GCA-2747955 sp947311625']))

# Must include the Fastidiosibacteraceae XS4 species cluster once (if) this genome appears in GTDB:
# - https://www.ncbi.nlm.nih.gov/nuccore/AP038919.1
# - https://pmc.ncbi.nlm.nih.gov/articles/PMC12213064
Expand All @@ -58,13 +66,40 @@ def __init__(self):

self.log = logging.getLogger('timestamp')

def run(self, taxonomy_file: str, out_file: str) -> None:
def parse_manual_ground_truth_file(self, manual_gt_file: str) -> dict:
    """Parse manual ground truth file.

    The file is a tab-separated table with a header row that must contain
    the columns "Genome ID" and "Translation table"; any other columns are
    ignored. Files whose name ends in '.gz' are read as gzip-compressed text.

    Parameters
    ----------
    manual_gt_file : str
        Path to the (optionally gzipped) TSV file.

    Returns
    -------
    dict
        Maps each genome ID to its translation table, both as strings.

    Raises
    ------
    ValueError
        If a required column is missing from the header.
    """

    manual_ground_truth = {}
    open_file = gzip.open if manual_gt_file.endswith('.gz') else open
    with open_file(manual_gt_file, 'rt') as f:
        # Only the line terminator is removed before splitting so that empty
        # leading/trailing fields do not shift the column positions.
        header = [h.strip() for h in f.readline().rstrip('\r\n').split('\t')]
        gid_idx = header.index("Genome ID")
        gt_idx = header.index("Translation table")

        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no data and would otherwise cause an IndexError.
            if not line.strip():
                continue

            tokens = [t.strip() for t in line.rstrip('\r\n').split('\t')]
            manual_ground_truth[tokens[gid_idx]] = tokens[gt_idx]

    return manual_ground_truth

def run(self, taxonomy_file: str, manual_gt_file: str, out_file: str) -> None:
"""Determine the ground truth for genomes based on their taxonomic classification."""

# read file with manually specified ground truth
manual_ground_truth = {}
if manual_gt_file:
self.log.info('Parsing manual ground truth file:')
manual_ground_truth = self.parse_manual_ground_truth_file(manual_gt_file)
self.log.info(f' - identified manual ground truth for {len(manual_ground_truth):,} genomes')

# determine ground truth for genomes based on their taxonomic classification
self.log.info('Determining ground truth for genomes:')
total_genomes = 0
gt_table_count = defaultdict(int)
num_by_manual_gt = 0

with open(taxonomy_file) as f:
open_file = gzip.open if taxonomy_file.endswith('.gz') else open
with open_file(taxonomy_file, 'rt') as f:
header = f.readline().strip().split('\t')

if 'Genome ID' in header:
Expand Down Expand Up @@ -100,35 +135,40 @@ def run(self, taxonomy_file: str, out_file: str) -> None:

total_genomes += 1

# determine ground truth translation table based on GTDB
# or NCBI taxonomic classification of genome
taxa = set()
if taxonomy_idx:
taxa.update(set(tokens[taxonomy_idx].split(';')))

gtdb_taxa = set()
if gtdb_taxonomy_idx:
gtdb_taxa.update(set(tokens[gtdb_taxonomy_idx].split(';')))
gid = tokens[gid_idx]

ncbi_taxa = set()
if ncbi_taxonomy_idx:
ncbi_taxa.update(set(tokens[ncbi_taxonomy_idx].split(';')))

if taxa.intersection(self.GTDB_TT25) or gtdb_taxa.intersection(self.GTDB_TT25):
ground_truth_tt = '25'
elif taxa.intersection(self.GTDB_TT4) or gtdb_taxa.intersection(self.GTDB_TT4):
ground_truth_tt = '4'
elif taxa.intersection(self.NCBI_TT4) or ncbi_taxa.intersection(self.NCBI_TT4):
ground_truth_tt = '4'
elif taxa.intersection(self.GTDB_UNRESOLVED) or gtdb_taxa.intersection(self.GTDB_UNRESOLVED):
ground_truth_tt = 'UNRESOLVED'
if gid in manual_ground_truth:
ground_truth_tt = manual_ground_truth[gid]
num_by_manual_gt += 1
else:
ground_truth_tt = '11'
# determine ground truth translation table based on GTDB
# or NCBI taxonomic classification of genome
taxa = set()
if taxonomy_idx:
taxa.update(set(tokens[taxonomy_idx].split(';')))

gtdb_taxa = set()
if gtdb_taxonomy_idx:
gtdb_taxa.update(set(tokens[gtdb_taxonomy_idx].split(';')))

ncbi_taxa = set()
if ncbi_taxonomy_idx:
ncbi_taxa.update(set(tokens[ncbi_taxonomy_idx].split(';')))

if taxa.intersection(self.GTDB_TT25) or gtdb_taxa.intersection(self.GTDB_TT25):
ground_truth_tt = '25'
elif taxa.intersection(self.GTDB_TT4) or gtdb_taxa.intersection(self.GTDB_TT4):
ground_truth_tt = '4'
elif taxa.intersection(self.NCBI_TT4) or ncbi_taxa.intersection(self.NCBI_TT4):
ground_truth_tt = '4'
elif taxa.intersection(self.GTDB_UNRESOLVED) or gtdb_taxa.intersection(self.GTDB_UNRESOLVED):
ground_truth_tt = 'UNRESOLVED'
else:
ground_truth_tt = '11'

gt_table_count[ground_truth_tt] += 1

# write out ground truth results
gid = tokens[gid_idx]
fout.write(f'{gid}\t{ground_truth_tt}')

if taxonomy_idx:
Expand All @@ -145,7 +185,9 @@ def run(self, taxonomy_file: str, out_file: str) -> None:
fout.close()

# write out number of genomes assigned to each translation table
self.log.info(f'Total genomes: {total_genomes:,}')
self.log.info(f' - determined ground truth for {total_genomes:,} genomes')
if manual_gt_file:
self.log.info(f' - ground truth set manually for {num_by_manual_gt:,} genomes')

for tran_table, genome_count in sorted(gt_table_count.items()):
self.log.info(f'Table {tran_table}: {genome_count:,} ({100*genome_count/total_genomes:.2f}%)')
Expand All @@ -158,6 +200,7 @@ def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--taxonomy_file', required=True, help='File indicating taxonomic classification of each genome.')
parser.add_argument('--out_file', required=True, help='Output file to write ground truth translation table.')
parser.add_argument('--manual_gt_file', help='File indicating manually specific ground truth for select genomes.')

args = parser.parse_args()

Expand All @@ -166,7 +209,7 @@ def main():

# run program
p = GroundTruthByTaxonomy()
p.run(args.taxonomy_file, args.out_file)
p.run(args.taxonomy_file, args.manual_gt_file, args.out_file)


if __name__ == '__main__':
Expand Down
Loading