"""
Created on 23 Jul 2021
@author: hlasimpk
"""
from collections import OrderedDict
import gemmi
from itertools import groupby
import logging
from operator import itemgetter
import os
import numpy as np
import requests
from simbad.util.pdb_util import PdbStructure
class PdbModelException(Exception):
    """Raised when a model PDB file cannot be downloaded or prepared.

    The first positional argument is additionally stored as ``message``.
    Python 3 exceptions no longer provide that attribute, so handlers that
    read ``e.message`` would otherwise fail with an ``AttributeError``.
    """

    def __init__(self, *args):
        super(PdbModelException, self).__init__(*args)
        # Mirror the Python 2 behaviour: the first argument is the message,
        # and an exception raised without arguments has an empty message.
        self.message = args[0] if args else ""
# Prefix of the AlphaFold DB entry page; models_from_hits() appends the second
# "-"-separated field of the hit's pdb_id to build ModelData.model_url.
AF_BASE_URL = 'https://alphafold.ebi.ac.uk/entry/'
# Directory (relative to the current working directory) in which prepare_pdb()
# saves each downloaded model before truncation.
AF2_DIR = 'AF2_files'
# Directory (relative to the current working directory) in which prepare_pdb()
# saves the truncated models.
MODELS_DIR = 'models'
# Module-level logger.
logger = logging.getLogger(__name__)
class ModelData(object):
    """Data describing a single model found for a search hit.

    Plain attributes are filled in by the code that prepares the model;
    the read-only properties forward to the linked ``hit`` object and
    return ``None`` when the hit (or the requested attribute) is missing.
    """

    # Attributes holding linked objects; they are left out of ``static_dict``.
    OBJECT_ATTRIBUTES = ['hit', 'region']

    def __init__(self):
        # Per-model values; all start unset.
        self.avg_plddt = None
        self.sum_plddt = None
        self.date_made = None
        self.molecular_weight = None
        self.model_url = None
        self.pdb_file = None
        self.h_score = None
        self.rmsd = None
        # Linked objects (see OBJECT_ATTRIBUTES).
        self.hit = None
        self.region = None
        self.plddt_regions = None

    @property
    def length(self):
        """``hit.length`` or None."""
        return self._get_child_attr('hit', 'length')

    @property
    def name(self):
        """Short model name derived from the "-"-separated ``hit.name``.

        Concatenates the second field with the third field minus its first
        two characters. Returns None when the hit has no name.
        """
        hit_name = self._get_child_attr('hit', 'name')
        if hit_name is None:
            return None
        fields = hit_name.split("-")
        return fields[1] + fields[2][2:]

    @property
    def model_id(self):
        """Second "-"-separated field of ``hit.name``, or None without a name."""
        hit_name = self._get_child_attr('hit', 'name')
        if hit_name is None:
            return None
        return hit_name.split("-")[1]

    @property
    def range(self):
        """Tuple ``(query_start, query_stop)``."""
        return (self.query_start, self.query_stop)

    @property
    def region_id(self):
        """``hit.region_id`` or None."""
        return self._get_child_attr('hit', 'region_id')

    @property
    def region_index(self):
        """``hit.region_index`` or None."""
        return self._get_child_attr('hit', 'region_index')

    @property
    def score(self):
        """``hit.score`` or None."""
        return self._get_child_attr('hit', 'score')

    @property
    def seq_ident(self):
        """``hit.local_sequence_identity`` divided by 100, or None if unset.

        NOTE(review): presumably the hit stores a percentage - confirm.
        """
        seq_id = self._get_child_attr('hit', 'local_sequence_identity')
        # Compare against None explicitly: an identity of 0 is a valid value
        # and must not be reported as missing.
        if seq_id is not None:
            return seq_id / 100.0
        return None

    @property
    def query_start(self):
        """``hit.query_start`` or None."""
        return self._get_child_attr('hit', 'query_start')

    @property
    def query_stop(self):
        """``hit.query_stop`` or None."""
        return self._get_child_attr('hit', 'query_stop')

    # miscellaneous properties
    @property
    def static_dict(self):
        """Return a self representation with all properties resolved, suitable for JSON"""
        d = {k: v for k, v in self.__dict__.items() if k not in self.OBJECT_ATTRIBUTES}
        # Properties live on the class, not in __dict__, so resolve them here.
        cls = self.__class__
        for name in dir(cls):
            obj = getattr(cls, name)
            if name != 'static_dict' and isinstance(obj, property):
                d[name] = obj.__get__(self, cls)
        # FIX ONCE UPDATED JS TO HANDLE TWO INTS
        d['range'] = "{}-{}".format(*self.range)
        return d

    def _get_child_attr(self, child, attr):
        """Return ``self.<child>.<attr>``, or None if either is missing."""
        if hasattr(self, child):
            child_obj = getattr(self, child)
            if hasattr(child_obj, attr):
                return getattr(child_obj, attr)
        return None

    def __str__(self):
        """List the public instance attributes, sorted by name."""
        attrs = sorted(k for k in self.__dict__ if not k.startswith('_'))
        line_template = " {} : {}\n"
        parts = ["Class: {}\nData:\n".format(self.__class__)]
        parts.extend(line_template.format(a, self.__dict__[a]) for a in attrs)
        return "".join(parts)
def models_from_hits(hits):
    """Build a ModelData object for every hit and prepare its model file.

    Creates the AF2_DIR and MODELS_DIR directories in the current working
    directory when they do not exist yet.

    :param hits: mapping whose values are hit objects; each hit must provide
        ``pdb_id`` and ``name`` plus whatever prepare_pdb() reads from it.
    :returns: OrderedDict mapping ``ModelData.name`` to the ModelData object,
        in the iteration order of ``hits``. A hit whose model could not be
        prepared is logged and still included, with its file-derived
        attributes left as None.
    """
    for directory in (AF2_DIR, MODELS_DIR):
        if not os.path.isdir(directory):
            os.mkdir(directory)
    models = OrderedDict()
    for hit in hits.values():
        mlog = ModelData()
        mlog.hit = hit
        # Back-reference so the hit can reach its model data.
        hit._homolog = mlog
        mlog.model_url = AF_BASE_URL + hit.pdb_id.split('-')[1]
        try:
            (mlog.pdb_file, mlog.molecular_weight, mlog.avg_plddt,
             mlog.sum_plddt, mlog.h_score, mlog.date_made,
             mlog.plddt_regions) = prepare_pdb(hit)
        except PdbModelException as e:
            # Log the exception itself: Python 3 exceptions have no
            # ``message`` attribute, so ``e.message`` would raise here.
            logger.critical("Error processing hit pdb %s", e)
        models[mlog.name] = mlog
    return models
def download_model(pdb_name, timeout=60):
    """Download AlphaFold2 model

    :param pdb_name: file name appended to the AlphaFold DB files URL.
    :param timeout: seconds to wait for the server before giving up; without
        a timeout ``requests`` may block indefinitely.
    :returns: the response body as text.
    :raises requests.RequestException: on connection problems, a timeout, or
        an HTTP error status (4xx/5xx).
    """
    url = 'https://alphafold.ebi.ac.uk/files/' + pdb_name
    query = requests.get(url, timeout=timeout)
    # Fail on HTTP errors instead of returning the error page as if it were
    # the contents of a PDB file.
    query.raise_for_status()
    return query.text
def prepare_pdb(hit):
    """
    Download the model for a hit, truncate it to the hit residues and
    calculate its molecular weight and pLDDT statistics.

    The full model is saved in AF2_DIR and the truncated model, with its
    pLDDT values converted to B-factors, in MODELS_DIR.

    :param hit: hit object providing ``pdb_id``, ``chain_id``, ``hit_start``,
        ``hit_stop``, ``seq_ali`` and ``local_sequence_identity``.
    :returns: tuple ``(truncated_pdb_path, molecular_weight, avg_plddt,
        sum_plddt, h_score, date_made, plddt_regions)``.
    :raises PdbModelException: if the model cannot be downloaded or read.
    """
    pdb_name = "{0}_{1}.pdb".format(hit.pdb_id, hit.chain_id)
    pdb_struct = PdbStructure()
    try:
        pdb_string = download_model(pdb_name)
        pdb_struct.structure = gemmi.read_pdb_string(pdb_string)
    except (RuntimeError, requests.RequestException):
        # SIMBAD currently raises an empty RuntimeError for download problems;
        # failures of the HTTP request itself surface as RequestException.
        raise PdbModelException("Error downloading PDB file for: {}".format(hit.pdb_id))
    # The date is the last whitespace-separated field of the first line;
    # leave it unset when that line is blank.
    header_fields = pdb_string.split('\n')[0].split()
    date_made = header_fields[-1] if header_fields else None

    pdb_file = os.path.join(AF2_DIR, pdb_name)
    pdb_struct.save(pdb_file)

    # Keep only the residues covered by the hit (both ends inclusive).
    seqid_range = range(hit.hit_start, hit.hit_stop + 1)
    pdb_struct.select_residues(to_keep_idx=seqid_range)

    # The statistics must be taken before the pLDDT values are overwritten.
    avg_plddt = calculate_avg_plddt(pdb_struct.structure)
    sum_plddt = calculate_sum_plddt(pdb_struct.structure)
    h_score = calculate_quality_h_score(pdb_struct.structure)
    plddt_regions = get_plddt_regions(pdb_struct.structure, hit.seq_ali)

    # Convert plddt to bfactor score
    pdb_struct.structure = convert_plddt_to_bfactor(pdb_struct.structure)
    truncated_pdb_name = "{}_{}_{}-{}.pdb".format(hit.pdb_id, hit.chain_id, hit.hit_start, hit.hit_stop)
    truncated_pdb_path = os.path.join(MODELS_DIR, truncated_pdb_name)
    pdb_struct.save(truncated_pdb_path,
                    remarks=["PHASER ENSEMBLE MODEL 1 ID {}".format(hit.local_sequence_identity)])
    return truncated_pdb_path, int(pdb_struct.molecular_weight), avg_plddt, sum_plddt, h_score, date_made, plddt_regions
def calculate_quality_threshold(struct, plddt_threshold=70):
    """Return the percentage of residues at or above a pLDDT threshold.

    Only the first model of ``struct`` is examined, and each residue is
    represented by the ``b_iso`` value of its first atom.

    :param struct: structure indexable by model, whose models iterate over
        chains and whose chains iterate over residues.
    :param plddt_threshold: minimum value for a residue to be counted.
    :returns: percentage in the range 0-100 as a float; 0.0 when the model
        contains no residues.
    """
    total = above_threshold = 0
    for chain in struct[0]:
        for residue in chain:
            total += 1
            if residue[0].b_iso >= plddt_threshold:
                above_threshold += 1
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    return (100.0 / total) * above_threshold
def calculate_quality_h_score(struct):
    """Return the largest integer i (1-100) such that at least i percent of
    the residues have a pLDDT of at least i.

    Only the first model of ``struct`` is examined, and each residue is
    represented by the ``b_iso`` value of its first atom.

    :param struct: structure indexable by model, whose models iterate over
        chains and whose chains iterate over residues.
    :returns: the score as an int; 0 when no threshold is satisfied or the
        model contains no residues.
    """
    # Walk the structure once rather than once per threshold.
    plddt_values = [residue[0].b_iso for chain in struct[0] for residue in chain]
    total = len(plddt_values)
    if total == 0:
        return 0
    for threshold in range(100, 0, -1):
        above_threshold = sum(1 for value in plddt_values if value >= threshold)
        if (100.0 / total) * above_threshold >= threshold:
            return threshold
    return 0
def get_plddt(struct):
    """Collect one pLDDT value per residue from the first model.

    The value taken for a residue is the ``b_iso`` of its first atom.

    :param struct: structure indexable by model, whose models iterate over
        chains and whose chains iterate over residues.
    :returns: list of values in chain order, then residue order.
    """
    first_model = struct[0]
    return [res[0].b_iso for chain in first_model for res in chain]
def get_plddt_regions(struct, seqid_range):
    """Group residue ids into runs per pLDDT confidence band.

    Ids from ``seqid_range`` are paired positionally with the per-residue
    pLDDT values of ``struct``; pairing stops at the shorter of the two.

    Bands: ``v_low`` (< 50), ``low`` (50 to < 70), ``confident`` (70 to < 90)
    and ``v_high`` (>= 90).

    :param struct: structure passed on to get_plddt().
    :param seqid_range: iterable of residue ids, one per residue.
    :returns: dict mapping each band name to the list of (first, last) runs
        produced by _get_regions().
    """
    v_low, low, confident, v_high = [], [], [], []
    for seq_id, score in zip(seqid_range, get_plddt(struct)):
        if score >= 90:
            v_high.append(seq_id)
        elif score >= 70:
            confident.append(seq_id)
        elif score >= 50:
            low.append(seq_id)
        elif score < 50:
            v_low.append(seq_id)
    return {
        'v_low': _get_regions(v_low),
        'low': _get_regions(low),
        'confident': _get_regions(confident),
        'v_high': _get_regions(v_high),
    }
def _get_regions(residues):
regions = []
for k, g in groupby(enumerate(residues), lambda x: x[0] - x[1]):
group = (map(itemgetter(1), g))
group = list(map(int, group))
regions.append((group[0], group[-1]))
return regions
def calculate_avg_plddt(struct):
    """Return the mean of the per-residue pLDDT values from get_plddt().

    :param struct: structure passed on to get_plddt().
    :returns: the mean value, or 0.0 when the structure yields no residues.
    """
    plddt_values = get_plddt(struct)
    if not plddt_values:
        # No residues; avoid dividing by zero.
        return 0.0
    return sum(plddt_values) / len(plddt_values)
def calculate_sum_plddt(struct):
    """Return the sum of the per-residue pLDDT values from get_plddt().

    :param struct: structure passed on to get_plddt().
    """
    return sum(get_plddt(struct))
def convert_plddt_to_bfactor(struct):
    """Replace the ``b_iso`` of every atom in the first model, in place.

    Each value is passed through _convert_plddt_to_bfactor().

    :param struct: structure indexable by model, whose models iterate over
        chains, chains over residues and residues over atoms.
    :returns: the same ``struct`` object that was passed in.
    """
    first_model = struct[0]
    every_atom = (atom for chain in first_model for residue in chain for atom in residue)
    for atom in every_atom:
        atom.b_iso = _convert_plddt_to_bfactor(atom.b_iso)
    return struct
def _convert_plddt_to_bfactor(plddt):
lddt = plddt / 100
if lddt <= 0.5:
return 657.97 # Same as the b-factor value with an rmsd estimate of 5.0
rmsd_est = (0.6 / (lddt ** 3))
bfactor = ((8 * (np.pi ** 2)) / 3.0) * (rmsd_est ** 2)
if bfactor > 999.99:
return 999.99
return bfactor