"""Truncation utility module"""

__author__ = "Jens Thomas & Felix Simkovic"
__date__ = "02 Mar 2016"
__version__ = "1.0"

import collections
import copy
import logging
import os

from ample.ensembler._ensembler import model_core_from_fasta 
from ample.util import ample_util
from ample.util import pdb_edit
from ample.util import theseus

logger = logging.getLogger(__name__)

# Data structure to store residue information
ScoreVariances = collections.namedtuple("ScoreVariances", ["idx", "resSeq", "variance"])
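
# An illustrative example (not part of the original module): each entry pairs a
# 0-based list index with the 1-based PDB residue number and its variance.
#
#   >>> sv = ScoreVariances(idx=0, resSeq=1, variance=0.25)
#   >>> sv.resSeq, sv.variance
#   (1, 0.25)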


def calculate_residues_focussed(var_by_res):
    """
    The sweet spot for success seems to occur in the interval 5-40 residues,
    and up until now we have always worked in 5% intervals, giving 20
    truncation levels.

    The new strategy is to ensure that at least half of the truncations always
    fall in the interval < 40 residues: 10 truncations in 40 residues means
    4-residue chunks in this interval.

    The strategy is therefore: for <= 80 residues, just split evenly into 20
    chunks; for > 80 residues, split the 40 least variable residues into ten
    4-residue chunks, and split the interval from residue 40 to the end into
    10 even chunks.
    """
    length = len(var_by_res)
    if length <= 80:
        # Just split evenly into 20 chunks
        return calculate_residues_percent(var_by_res, 5)

    # Sort the residues by variance - from least variable to most
    var_by_res.sort(key=lambda x: x.variance, reverse=False)

    # Split the 40 least variable residues into 10 even chunks
    llen = 40
    lower_start = _split_sequence(llen, 10)

    # Split the remaining interval into 10 even chunks. We need to offset the
    # start indices by llen as the first llen residues are handled above.
    ulen = length - llen
    upper_start = [i + llen for i in _split_sequence(ulen, 10)]
    start_indexes = upper_start + lower_start

    # Calculate the percentage level for each of these start points
    truncation_levels = [int(round(float(start + 1) / float(length) * 100)) for start in start_indexes]

    idxs_all = [x.idx for x in var_by_res]
    resseq_all = [x.resSeq for x in var_by_res]
    variances = [x.variance for x in var_by_res]

    truncation_residue_idxs = [sorted(idxs_all[:i + 1]) for i in start_indexes]
    truncation_residues = [sorted(resseq_all[:i + 1]) for i in start_indexes]
    # For the threshold we take the variance of the most variable residue kept
    truncation_variances = [variances[i] for i in start_indexes]

    return truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs
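
# A worked example (illustrative, using fabricated variances for a
# hypothetical 120-residue model): the 40 least variable residues are split
# into ten 4-residue chunks and the remaining 80 residues into ten 8-residue
# chunks, giving 20 truncation levels of which half keep fewer than 40 residues.
#
#   >>> vbr = [ScoreVariances(idx=i, resSeq=i + 1, variance=float(i)) for i in range(120)]
#   >>> levels, variances, residues, idxs = calculate_residues_focussed(vbr)
#   >>> len(levels), levels[-1], len(residues[-1])
#   (20, 3, 4)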


def calculate_residues_percent(var_by_res, percent_interval):
    """Calculate the lists of residues to keep if we keep percent_interval more
    residues at each truncation level. The variance threshold for a level is
    the variance of the most variable residue kept."""
    MIN_CHUNK = 3  # We need at least 3 residues for theseus to work
    length = len(var_by_res)
    start_idxs = _split_sequence(length, percent_interval, min_chunk=MIN_CHUNK)

    # Sort the residues by variance - from least variable to most
    var_by_res.sort(key=lambda x: x.variance, reverse=False)
    idxs_all = [x.idx for x in var_by_res]
    resseq_all = [x.resSeq for x in var_by_res]
    variances = [x.variance for x in var_by_res]

    # Get the lists of residues to keep under the different intervals
    truncation_levels = []
    truncation_variances = []
    truncation_residues = []
    truncation_residue_idxs = []
    for start in start_idxs:
        percent = int(round(float(start + 1) / float(length) * 100))
        residues = resseq_all[:start + 1]
        idxs = idxs_all[:start + 1]
        # For the threshold we take the variance of the most variable residue kept
        thresh = variances[start]
        truncation_variances.append(thresh)
        truncation_levels.append(percent)
        truncation_residues.append(sorted(residues))
        truncation_residue_idxs.append(sorted(idxs))

    return truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs
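
# Illustrative example: with 10 fabricated residues and a 20% interval,
# _split_sequence steps back in 2-residue chunks from keeping everything,
# stopping before a chunk would fall below MIN_CHUNK residues.
#
#   >>> vbr = [ScoreVariances(idx=i, resSeq=i + 1, variance=float(i)) for i in range(10)]
#   >>> levels, variances, residues, idxs = calculate_residues_percent(vbr, 20)
#   >>> levels
#   [100, 80, 60, 40]
#   >>> residues[-1]
#   [1, 2, 3, 4]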


def calculate_residues_thresh(var_by_res, percent_interval):
    """Calculate the residues to keep at each truncation level using the
    variance thresholds generated by generate_thresholds"""
    # Calculate the thresholds
    truncation_variances = generate_thresholds(var_by_res, percent_interval)

    # We run in reverse as that's how the original code worked
    truncation_residues = []
    truncation_residue_idxs = []
    truncation_levels = []
    lt = len(truncation_variances)
    for i, truncation_threshold in enumerate(truncation_variances):
        truncation_level = lt - i  # as going backwards
        truncation_levels.append(truncation_level)

        # Get the lists of the residues and their indices to keep
        to_keep = [x.resSeq for x in var_by_res if x.variance <= truncation_threshold]
        to_keep_idxs = [x.idx for x in var_by_res if x.variance <= truncation_threshold]
        truncation_residues.append(to_keep)
        truncation_residue_idxs.append(to_keep_idxs)

    # We went through in reverse so put things the right way around
    truncation_levels.reverse()
    truncation_variances.reverse()
    truncation_residues.reverse()
    truncation_residue_idxs.reverse()
    return truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs
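
# Illustrative example: note that for this method the truncation levels are
# ordinal bin numbers rather than percentages.
#
#   >>> vbr = [ScoreVariances(idx=i, resSeq=i + 1, variance=v)
#   ...        for i, v in enumerate([0.1, 0.2, 0.3, 0.4])]
#   >>> calculate_residues_thresh(vbr, 50)
#   ([1, 2], [0.4, 0.2], [[1, 2, 3, 4], [1, 2]], [[0, 1, 2, 3], [0, 1]])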


def generate_thresholds(var_by_res, percent_interval):
    """Calculate the residue variance thresholds that will keep percent_interval
    more residues at each truncation level.

    This is the original method developed by Jaclyn and used in all work until
    November 2014 (including the coiled-coil paper).
    """
    # --------------------------------
    # choose threshold type
    # --------------------------------
    FIXED_INTERVALS = False
    if FIXED_INTERVALS:
        thresholds = [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 7, 8]
        logger.debug("Got {0} thresholds: {1}".format(len(thresholds), thresholds))
        return thresholds

    # List of variances ordered by residue index
    var_list = [x.variance for x in var_by_res]
    length = len(var_list)
    if length == 0:
        msg = "Error generating thresholds, got len: {0}".format(length)
        logger.critical(msg)
        raise RuntimeError(msg)

    # How many residues should fit in each bin
    # NB - Should round up not down with int!
    chunk_size = int((float(length) / 100) * float(percent_interval))
    if chunk_size < 1:
        msg = "Error generating thresholds, got < 1 AA in chunk_size"
        logger.critical(msg)
        raise RuntimeError(msg)

    # Try to find intervals for truncation
    truncation_thresholds = _generate_thresholds(var_list, chunk_size)
    # Jens' new untested method:
    # truncation_thresholds = _generate_thresholds2(var_list, chunk_size)
    logger.debug("Got {0} thresholds: {1}".format(len(truncation_thresholds), truncation_thresholds))
    return truncation_thresholds


def _generate_thresholds(values, chunk_size):
    """Jaclyn's threshold method"""
    try_list = copy.deepcopy(values)
    try_list.sort()

    # For chunking the list
    def chunks(a_list, chunk_size):
        for i in xrange(0, len(a_list), chunk_size):
            yield a_list[i:i + chunk_size]

    thresholds = []
    for x in chunks(try_list, chunk_size):
        # In some cases multiple residues share the same variance, so we don't
        # create a separate threshold for each of them
        if x[-1] not in thresholds:
            thresholds.append(x[-1])

    return thresholds


def _generate_thresholds2(values, chunk_size):
    """
    This is Jens's update to Jaclyn's method. It groups the residues by
    variance so that we split them by variance value, trying to fit chunk_size
    residues in each bin. Previously we split by variance but didn't group the
    residues, so the same variance bin could cover multiple residue groups.
    """
    # Create tuples mapping each variance value to its count
    data = [(i, values.count(i)) for i in sorted(set(values), reverse=True)]
    thresholds = []
    counts = []
    first = True
    for variance, count in data:
        if first or counts[-1] + count > chunk_size:
            thresholds.append(variance)
            counts.append(count)
            if first:
                first = False
        else:
            counts[-1] += count

    thresholds.sort()
    return thresholds
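
# Illustrative examples (not part of the original module): with chunk_size 2
# the sorted variances are scanned in pairs and each pair's maximum becomes a
# threshold (duplicates collapsed), whereas the grouped variant bins by
# variance value first.
#
#   >>> _generate_thresholds([0.5, 0.2, 0.2, 0.1, 0.5, 0.3], 2)
#   [0.2, 0.3, 0.5]
#   >>> _generate_thresholds2([0.5, 0.2, 0.2, 0.1, 0.5, 0.3], 2)
#   [0.1, 0.2, 0.3, 0.5]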


def prune_residues(residues, chunk_size=1, allowed_gap=2):
    """Remove any chunk of residues that is <= chunk_size long where the gaps
    before and after it are both >= allowed_gap"""
    assert chunk_size > 0 and allowed_gap > 0, \
        "chunk_size and allowed_gap must be > 0!: {0} {1}".format(chunk_size, allowed_gap)

    if not len(residues):
        return residues, None

    lenr = len(residues)
    if lenr <= chunk_size:
        return [], residues

    # Build up a list of residues to remove
    to_remove = []
    start = residues[0]
    last = residues[0]
    this_residue = None
    last_chunk_end = residues[0] - (allowed_gap + 1)  # make sure starting gap is bigger than allowed
    idxLast = lenr - 1
    for i in xrange(1, idxLast + 1):
        this_residue = residues[i]
        if i == idxLast or this_residue != last + 1:
            if i == idxLast and this_residue != last + 1:
                # A singleton sitting at the very end of the list
                start = this_residue
                last_chunk_end = last
                last = this_residue
                postgap = allowed_gap + 1
            elif i == idxLast and this_residue == last + 1:
                # The last chunk runs to the end of the list
                last = this_residue
                postgap = allowed_gap + 1
            elif i != idxLast and this_residue != last + 1:
                # A chunk has just ended mid-list
                postgap = (this_residue - last) - 1

            pregap = (start - last_chunk_end) - 1
            this_chunk_size = (last - start) + 1

            # remove if it satisfies the requirements
            if this_chunk_size <= chunk_size and pregap >= allowed_gap and postgap >= allowed_gap:
                chunk = [x for x in range(start, last + 1)]
                to_remove += chunk

            # reset start and last_chunk_end
            start = this_residue
            last_chunk_end = last

        last = this_residue

    # Remove the chunks and return
    if len(to_remove):
        return [r for r in residues if r not in to_remove], to_remove
    else:
        return residues, None
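
# Illustrative example: the singleton residue 8 sits between gaps of 4 on
# either side, so it is pruned; the flanking chunks are too large to remove.
#
#   >>> prune_residues([1, 2, 3, 8, 13, 14, 15], chunk_size=1, allowed_gap=2)
#   ([1, 2, 3, 13, 14, 15], [8])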


def _split_sequence(length, percent_interval, min_chunk=3):
    """Split a sequence of the given length into chunks, each spanning
    percent_interval of the length and each at least min_chunk residues long.
    Returns the list of (0-based) start indices, walking back from the end."""
    if length <= min_chunk:
        return [length - 1]

    # How many residues should fit in each bin
    chunk_size = int(round(float(length) * float(percent_interval) / 100.0))
    if chunk_size <= 0:
        return [length - 1]
    idxs = [length - 1]
    while True:
        start = idxs[-1] - chunk_size
        if start <= 0:
            break
        remainder = start + 1
        if remainder >= min_chunk:
            idxs.append(start)
        else:
            break
    return idxs
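
# Illustrative example: a 100-residue sequence at 5% intervals yields 20 start
# indices, stepping back 5 residues at a time from the final residue.
#
#   >>> _split_sequence(100, 5)
#   [99, 94, 89, 84, 79, 74, 69, 64, 59, 54, 49, 44, 39, 34, 29, 24, 19, 14, 9, 4]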


class Truncation(object):
    """Holds the information relating to a single truncation of a cluster of models"""

    def __init__(self):
        self.cluster = None  # The cluster object this truncation was created from
        self.directory = None
        self.level = None
        self.method = None
        self.models = None
        self.percent = None
        self.residues = None
        self.residues_idxs = None
        self.variances = None

    @property
    def num_residues(self):
        return 0 if self.residues is None else len(self.residues)

    def __str__(self):
        """Return a string representation of this object."""
        _str = super(Truncation, self).__str__() + "\n"
        # Iterate through all attributes in order
        for k in sorted(self.__dict__.keys()):
            _str += "{0} : {1}\n".format(k, self.__dict__[k])
        return _str
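
# Illustrative example: attributes are set directly by the Truncator.
#
#   >>> t = Truncation()
#   >>> t.level, t.residues = 50, [1, 2, 3]
#   >>> t.num_residues
#   3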


class Truncator(object):
    """Class to take one or more models and truncate them based on a supplied
    or generated per-residue metric"""

    def __init__(self, work_dir):
        self.work_dir = work_dir
        self.models = None
        self.aligned_models = None
        self.truncations = None
        self.theseus_exe = None

        # We keep these for bookkeeping as they go in the ample dictionary
        self.truncation_levels = None
        self.truncation_variances = None
        self.truncation_nresidues = None

    def calculate_truncations(self, models=None, truncation_method=None, percent_truncation=None,
                              truncation_pruning=None, residue_scores=None, alignment_file=None,
                              homologs=False):
        """Return a list of Truncation objects, one for each truncation level.

        This method doesn't do any truncating - it just calculates the data for
        each truncation level.
        """
        assert len(models) > 1 or residue_scores, "Cannot truncate as < 2 models!"
        assert truncation_method and percent_truncation, \
            "Missing arguments: {0} : {1}".format(truncation_method, percent_truncation)
        assert ample_util.is_exe(self.theseus_exe), "Cannot find theseus_exe: {0}".format(self.theseus_exe)

        # Make sure we are in the directory we'll be working in
        assert self.work_dir and os.path.isdir(self.work_dir), "truncate_models needs a self.work_dir"
        os.chdir(self.work_dir)
        self.models = models

        # Calculate variances between the models and align them (we currently only
        # require the aligned models for homologs)
        if truncation_method != "scores":
            run_theseus = theseus.Theseus(work_dir=self.work_dir, theseus_exe=self.theseus_exe)
            try:
                run_theseus.superpose_models(self.models, homologs=homologs, alignment_file=alignment_file)
                self.aligned_models = run_theseus.aligned_models
            except RuntimeError as e:
                logger.critical(e)
                return []

        if homologs:
            # If using homologs, now trim down to the core. We only do this here so that
            # we are using the aligned models from theseus, which makes it easier to see
            # what the truncation is doing.
            models = model_core_from_fasta(self.aligned_models,
                                           alignment_file=alignment_file,
                                           work_dir=os.path.join(self.work_dir, 'core_models'))
            # Unfortunately theseus doesn't print all residues in its output format, so we
            # can't use the variances we calculated before and need to calculate the
            # variances of the core models
            try:
                run_theseus.superpose_models(models, homologs=homologs, basename='homologs_core')
                self.models = run_theseus.aligned_models
                self.aligned_models = run_theseus.aligned_models
            except RuntimeError as e:
                logger.critical(e)
                return []

        # No THESEUS variances required if scores for each residue are provided
        var_by_res = run_theseus.var_by_res if truncation_method != "scores" \
            else self._convert_residue_scores(residue_scores)

        if not len(var_by_res) > 0:
            msg = "Error reading residue variances!"
            logger.critical(msg)
            raise RuntimeError(msg)

        logger.info('Using truncation method: {0}'.format(truncation_method))

        # Calculate which residues to keep under the different methods
        if truncation_method in ('percent', 'scores'):
            truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs = \
                calculate_residues_percent(var_by_res, percent_truncation)
        elif truncation_method == 'thresh':
            truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs = \
                calculate_residues_thresh(var_by_res, percent_truncation)
        elif truncation_method == 'focussed':
            truncation_levels, truncation_variances, truncation_residues, truncation_residue_idxs = \
                calculate_residues_focussed(var_by_res)
        else:
            raise RuntimeError("Unrecognised truncation method: {0}".format(truncation_method))

        # Somewhat of a hack to save the data so we can put it in the amoptd
        self.truncation_levels = truncation_levels
        self.truncation_variances = truncation_variances
        self.truncation_nresidues = [len(r) for r in truncation_residues]

        truncations = []
        for tlevel, tvar, tresidues, tresidue_idxs in zip(truncation_levels, truncation_variances,
                                                          truncation_residues, truncation_residue_idxs):
            # Prune singleton/doubleton etc. residues if required
            logger.debug("truncation_pruning: {0}".format(truncation_pruning))
            if truncation_pruning == 'single':
                tresidue_idxs, pruned_residues = prune_residues(tresidue_idxs, chunk_size=1, allowed_gap=2)
                if pruned_residues:
                    logger.debug("prune_residues removing: {0}".format(pruned_residues))
            elif truncation_pruning is None:
                pass
            else:
                raise RuntimeError("Unrecognised truncation_pruning: {0}".format(truncation_pruning))

            # Skip if there are no residues left
            if not tresidue_idxs:
                logger.debug("Skipping truncation level {0} with variance {1} as no residues".format(tlevel, tvar))
                continue

            truncation = Truncation()
            truncation.method = truncation_method
            truncation.percent = percent_truncation
            truncation.level = tlevel
            truncation.variances = tvar
            truncation.residues = tresidues
            truncation.residues_idxs = tresidue_idxs
            truncations.append(truncation)

        return truncations

    def truncate_models(self, models, max_cluster_size=200, truncation_method=None,
                        percent_truncation=None, truncation_pruning=None, residue_scores=None,
                        homologs=False, alignment_file=None, work_dir=None):
        """Generate a set of Truncation objects, referencing a set of truncated
        models generated from the supplied models"""
        truncations = self.calculate_truncations(models=models,
                                                 truncation_method=truncation_method,
                                                 percent_truncation=percent_truncation,
                                                 truncation_pruning=truncation_pruning,
                                                 residue_scores=residue_scores,
                                                 alignment_file=alignment_file,
                                                 homologs=homologs)

        if truncations is None or len(truncations) < 1:
            msg = "Unable to truncate the ensembles - no viable truncations"
            logger.critical(msg)
            return []

        # Loop through the Truncation objects, truncating the models based on the truncation
        # data and adding the truncated models to the Truncation.models attribute
        for truncation in truncations:
            truncation.directory = os.path.join(self.work_dir, 'tlevel_{0}'.format(truncation.level))
            os.mkdir(truncation.directory)
            logger.info('Truncating at: {0} in directory {1}'.format(truncation.level, truncation.directory))
            truncation.models = []
            for infile in self.models:
                pdbout = ample_util.filename_append(infile, str(truncation.level), directory=truncation.directory)
                # Create a new PDB file that only contains the residues left after truncation
                pdb_edit.select_residues(pdbin=infile, pdbout=pdbout, tokeep_idx=truncation.residues_idxs)
                truncation.models.append(pdbout)

        self.truncations = truncations
        return truncations

    @staticmethod
    def _convert_residue_scores(residue_scores):
        """Convert (resSeq, score) pairs into the ScoreVariances named tuples used internally"""
        scores = [ScoreVariances(idx=int(res) - 1,  # Required to match Theseus
                                 resSeq=int(res),
                                 variance=float(sco))
                  for (res, sco) in residue_scores]
        return scores
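

# Illustrative example of the score conversion (residue numbers are 1-based;
# idx is the 0-based equivalent required to match theseus):
#
#   >>> Truncator._convert_residue_scores([('1', '0.5'), ('2', '1.2')])
#   [ScoreVariances(idx=0, resSeq=1, variance=0.5), ScoreVariances(idx=1, resSeq=2, variance=1.2)]
#
# A minimal usage sketch for the class as a whole (the paths and model files
# are hypothetical, and a working theseus executable is required):
#
#   truncator = Truncator(work_dir='/path/to/work_dir')
#   truncator.theseus_exe = '/path/to/theseus'
#   truncations = truncator.truncate_models(models=['model_1.pdb', 'model_2.pdb'],
#                                           truncation_method='percent',
#                                           percent_truncation=5)
#   # Each Truncation object now holds its level, residues and truncated models.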