Source code for mrparse.mr_classify
"""
Created on 18 Oct 2018
@author: jmht
"""
import logging
import threading
import sys
from mrparse.mr_deepcoil import CCPred
# from mrparse.mr_topcons import TMPred
from mrparse.mr_tmhmm import TMPred
from mrparse.mr_jpred import JPred
from mrparse.mr_pfam import pfam_dict_from_annotation
logger = logging.getLogger(__name__)
class PredictorThread(threading.Thread):
    """Run a predictor's ``get_prediction()`` in a background thread.

    Exceptions raised inside a thread are normally lost at the thread
    boundary; this class captures both the exception instance and its
    ``sys.exc_info()`` triple so the parent thread can report them after
    ``join()``.
    """

    def __init__(self, classifier):
        """
        Args:
            classifier: object exposing a ``get_prediction()`` method.
        """
        super().__init__()
        self.classifier = classifier
        self.exc_info = None  # sys.exc_info() triple captured on failure
        self.exception = None  # the exception instance raised, if any

    def run(self):
        """Execute the prediction, recording rather than propagating failures."""
        try:
            self.classifier.get_prediction()
        except Exception as e:
            # Stash the failure for the parent thread to inspect after join().
            self.exc_info = sys.exc_info()
            self.exception = e
class MrClassifier(object):
    """Run sequence classification predictors and collect their results.

    Up to three predictors are run concurrently in threads: JPred
    (secondary structure), DeepCoil (coiled-coil) and TMHMM
    (transmembrane). The coiled-coil and transmembrane predictions are
    merged into ``classification_prediction``; the secondary-structure
    prediction is stored in ``ss_prediction``.
    """

    def __init__(self, seq_info, do_ss_predictor=True, do_cc_predictor=True, do_tm_predictor=True, tmhmm_exe=None,
                 deepcoil_exe=None):
        """
        Args:
            seq_info: sequence information object handed to each predictor.
            do_ss_predictor: run the JPred secondary-structure predictor.
            do_cc_predictor: run the DeepCoil coiled-coil predictor.
            do_tm_predictor: run the TMHMM transmembrane predictor.
            tmhmm_exe: path to the TMHMM executable, or None.
            deepcoil_exe: path to the DeepCoil executable, or None.
        """
        self.seq_info = seq_info
        self.do_ss_predictor = do_ss_predictor
        self.do_cc_predictor = do_cc_predictor
        self.do_tm_predictor = do_tm_predictor
        self.tmhmm_exe = tmhmm_exe
        self.deepcoil_exe = deepcoil_exe
        # Populated by get_prediction()
        self.ss_prediction = None
        self.classification_prediction = None

    def __call__(self):
        """Run the predictions and return self.

        Required so that we can use multiprocessing pool. We need to be able
        to pickle the object passed to the pool and instance methods don't
        work, so we add the object to the pool and define __call__
        https://stackoverflow.com/questions/1816958/cant-pickle-type-instancemethod-when-using-multiprocessing-pool-map/6975654#6975654
        """
        self.get_prediction()
        return self

    def get_prediction(self):
        """Run the enabled predictors in parallel threads and store results.

        A predictor that raises is logged and disabled (its ``do_*`` flag is
        cleared) rather than aborting the whole classification; the remaining
        results are still combined.
        """
        if self.do_cc_predictor:
            cc_predictor = CCPred(self.seq_info, self.deepcoil_exe)
            cc_thread = PredictorThread(cc_predictor)
            cc_thread.start()
        if self.do_tm_predictor:
            tm_predictor = TMPred(self.seq_info, self.tmhmm_exe)
            tm_thread = PredictorThread(tm_predictor)
            tm_thread.start()
        if self.do_ss_predictor:
            ss_predictor = JPred(seq_info=self.seq_info)
            ss_thread = PredictorThread(ss_predictor)
            ss_thread.start()
        # Wait for all started predictors to finish
        if self.do_cc_predictor:
            cc_thread.join()
            logger.info('Coiled-Coil predictor finished')
        if self.do_tm_predictor:
            tm_thread.join()
            logger.info('TM predictor finished')
        if self.do_ss_predictor:
            ss_thread.join()
            logger.info('SS predictor finished')
        # Handle errors: a failed predictor is disabled so its (missing)
        # result is excluded from the classification below.
        if self.do_cc_predictor and cc_thread.exception:
            logger.critical("Coiled-Coil predictor raised an exception: %s", cc_thread.exception)
            logger.debug("Traceback is:", exc_info=cc_thread.exc_info)
            self.do_cc_predictor = False
        if self.do_tm_predictor and tm_thread.exception:
            logger.critical("Transmembrane predictor raised an exception: %s", tm_thread.exception)
            logger.debug("Traceback is:", exc_info=tm_thread.exc_info)
            self.do_tm_predictor = False
        # Determine prediction from whichever predictors succeeded
        if self.do_cc_predictor and self.do_tm_predictor:
            self.classification_prediction = self.generate_consensus_classification([cc_predictor.prediction, tm_predictor.prediction])
        elif self.do_cc_predictor:
            self.classification_prediction = cc_predictor.prediction
        elif self.do_tm_predictor:
            self.classification_prediction = tm_predictor.prediction
        if self.do_ss_predictor:
            if ss_thread.exception:
                logger.critical("JPred predictor raised error: %s", ss_thread.exception)
                logger.debug("Traceback is:", exc_info=ss_thread.exc_info)
            else:
                self.ss_prediction = ss_predictor.prediction

    @staticmethod
    def generate_consensus_classification(annotations):
        """Combine equal-length annotations into a single consensus.

        Annotations are merged pairwise with their ``+`` operator.

        Args:
            annotations: non-empty list of equal-length annotation objects.

        Returns:
            The combination of all annotations.

        Raises:
            AssertionError: if the annotations differ in length.
        """
        lengths = [len(a) for a in annotations]
        assert lengths.count(lengths[0]) == len(lengths), "Annotations have different lengths: %s" % lengths
        consensus = annotations[0]
        for annotation in annotations[1:]:
            consensus = consensus + annotation
        return consensus

    def pfam_dict(self):
        """Return a dict of pfam-style annotation data for display.

        Keys 'classification' and 'ss_pred' are only present when the
        corresponding prediction was produced.
        """
        d = {}
        if self.classification_prediction:
            d['classification'] = pfam_dict_from_annotation(self.classification_prediction)
        if self.ss_prediction:
            d['ss_pred'] = pfam_dict_from_annotation(self.ss_prediction)
        return d