Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions scripts/gtdb_training_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import argparse
import gzip
import random
from collections import defaultdict
from dataclasses import dataclass

Expand Down Expand Up @@ -65,25 +66,37 @@ def __init__(self):
# genomes to retain without needing to pass QC
self.QC_EXCEPTIONS = set(['G015697925']) # s__Pinguicoccus supinus representative genome

# maximum number of genomes to select per species
self.MAX_PER_SPECIES = 100

self.log = logging.getLogger('timestamp')

def parse_metadata_file(self, gtdb_metadata_file: str) -> tuple[dict, dict]:
    """Parse a GTDB metadata file.

    The file is read as tab-separated text with a header row; it is opened
    with gzip when the path ends in '.gz'. The header must contain the
    columns 'accession', 'gtdb_taxonomy', 'ncbi_taxonomy' and
    'gtdb_representative'.

    :param gtdb_metadata_file: Path to the (optionally gzipped) metadata file.
    :return: Tuple (genome_data, gtdb_sp_rid) where genome_data maps each
        canonical genome ID to a GenomeData built from its GTDB and NCBI
        taxonomy columns, and gtdb_sp_rid maps the last rank of the GTDB
        taxonomy string (the species) to the canonical ID of the genome
        flagged as the GTDB representative.
    :raises ValueError: If a required column is missing from the header.
    """

    genome_data = {}
    gtdb_sp_rid = {}
    open_file = gzip.open if gtdb_metadata_file.endswith('.gz') else open
    with open_file(gtdb_metadata_file, 'rt') as f:
        # remove only the line terminator; a bare strip() would also drop
        # empty trailing columns and shorten the token list
        header = f.readline().rstrip('\r\n').split('\t')

        gid_idx = header.index('accession')
        gtdb_taxonomy_idx = header.index('gtdb_taxonomy')
        ncbi_taxonomy_idx = header.index('ncbi_taxonomy')
        gtdb_rep_idx = header.index('gtdb_representative')

        for line in f:
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline at end of file)
                continue

            tokens = line.rstrip('\r\n').split('\t')
            gid = canonical_gid(tokens[gid_idx])
            gtdb_taxonomy = tokens[gtdb_taxonomy_idx]
            genome_data[gid] = GenomeData(gtdb_taxonomy, tokens[ncbi_taxonomy_idx])

            # the representative flag is treated as true when it starts
            # with 't' (case-insensitive)
            if tokens[gtdb_rep_idx].lower().startswith('t'):
                gtdb_sp = gtdb_taxonomy.split(';')[-1].strip()
                gtdb_sp_rid[gtdb_sp] = gid

    return genome_data, gtdb_sp_rid

def parse_taxonomy_file(self, taxonomy_file: str) -> dict:
"""Parse taxonomy file."""
Expand Down Expand Up @@ -175,12 +188,33 @@ def run(self, gtdb_bac_metadata_file: str,

# parse genomes from GTDB metadata files
self.log.info('Parsing GTDB metadata files:')
gtdb_bac_gids = self.parse_metadata_file(gtdb_bac_metadata_file)
gtdb_ar_gids = self.parse_metadata_file(gtdb_ar_metadata_file)
gtdb_bac_gids, gtdb_bac_sp_rid = self.parse_metadata_file(gtdb_bac_metadata_file)
gtdb_ar_gids, gtdb_ar_sp_rid = self.parse_metadata_file(gtdb_ar_metadata_file)
gtdb_gids = {**gtdb_bac_gids, **gtdb_ar_gids}
self.log.info(f' - identified {len(gtdb_bac_gids):,} bacterial genomes')
self.log.info(f' - identified {len(gtdb_ar_gids):,} archaeal genomes')
self.log.info(f' - identified {len(gtdb_gids):,} total genomes')
gtdb_sp_rid = {**gtdb_bac_sp_rid, **gtdb_ar_sp_rid}
self.log.info(f' - identified {len(gtdb_bac_gids):,} bacterial genomes from {len(gtdb_bac_sp_rid):,} species')
self.log.info(f' - identified {len(gtdb_ar_gids):,} archaeal genomes from {len(gtdb_ar_sp_rid):,} species')
self.log.info(f' - identified {len(gtdb_gids):,} total genomes from {len(gtdb_sp_rid):,} species')

# subsample to maximum number of genomes per species
self.log.info(f'Sampling to {self.MAX_PER_SPECIES} genomes per species:')

sp_gids = defaultdict(set)
for gid, genome_data in gtdb_gids.items():
gtdb_sp = [t.strip() for t in genome_data.gtdb_taxonomy.split(';')][-1]
sp_gids[gtdb_sp].add(gid)

gtdb_gids_sampled = {}
for sp, gids in sp_gids.items():
if len(gids) > self.MAX_PER_SPECIES:
rid = gtdb_sp_rid[sp]
gids.remove(rid)
gids = [rid] + random.sample(list(gids), self.MAX_PER_SPECIES - 1)

for gid in gids:
gtdb_gids_sampled[gid] = gtdb_gids[gid]

self.log.info(f' - retained {len(gtdb_gids_sampled):,} genomes for training')

# read NCBI taxonomy for genomes
self.log.info('Parsing NCBI taxonomy file:')
Expand All @@ -198,7 +232,7 @@ def run(self, gtdb_bac_metadata_file: str,

# combine all training genomes
self.log.info('Combining all genomes to use for training:')
training_gids = {**gtdb_gids, **gtdb_reassigned_gids}
training_gids = {**gtdb_gids_sampled, **gtdb_reassigned_gids}
self.log.info(f' - identified {len(training_gids):,} total genomes for training')

# sanity check that all genomes that are a QC exception are accounted for
Expand Down
Loading