Merge branch 'master' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into release/1.0
@@ -99,6 +99,10 @@ class Models(object):
     team = 'team-multi-modal-similarity'
     video_clip = 'video-clip-multi-modal-embedding'
 
+    # science models
+    unifold = 'unifold'
+    unifold_symmetry = 'unifold-symmetry'
+
 
 class TaskModels(object):
     # nlp task
@@ -266,6 +270,9 @@ class Pipelines(object):
     image_text_retrieval = 'image-text-retrieval'
     ofa_ocr_recognition = 'ofa-ocr-recognition'
 
+    # science tasks
+    protein_structure = 'unifold-protein-structure'
+
 
 class Trainers(object):
     """ Names for different trainer.
@@ -368,6 +375,9 @@ class Preprocessors(object):
     ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
     mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'
 
+    # science preprocessor
+    unifold_preprocessor = 'unifold-preprocessor'
+
 
 class Metrics(object):
     """ Names for different metrics.
@@ -10,6 +10,9 @@ class Metric(ABC):
     complex metrics for a specific task with or without other Metric subclasses.
     """
 
+    def __init__(self, *args, **kwargs):
+        pass
+
     @abstractmethod
     def add(self, outputs: Dict, inputs: Dict):
        """ Append logits and labels within an eval loop.
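The new no-op constructor lets stateless subclasses skip defining __init__ entirely while stateful ones still call super(). A minimal sketch of the pattern (the subclass and its fields are hypothetical, not part of this commit; the base class here is a simplified stand-in):

from abc import ABC, abstractmethod
from typing import Dict


class Metric(ABC):
    # Simplified stand-in for modelscope's Metric base, mirroring the hunk above.
    def __init__(self, *args, **kwargs):
        pass

    @abstractmethod
    def add(self, outputs: Dict, inputs: Dict):
        """Append logits and labels within an eval loop."""


class AccuracyMetric(Metric):
    # Hypothetical subclass: accumulates correctness counts across eval batches.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.correct = 0
        self.total = 0

    def add(self, outputs: Dict, inputs: Dict):
        for pred, label in zip(outputs['preds'], inputs['labels']):
            self.correct += int(pred == label)
            self.total += 1


m = AccuracyMetric()
m.add({'preds': [1, 0, 1]}, {'labels': [1, 1, 1]})
print(m.correct / m.total)  # ~0.667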
@@ -2,13 +2,19 @@
 
 from modelscope.utils.config import ConfigDict
+from modelscope.utils.constant import Tasks
+from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
 from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
 
 MODELS = Registry('models')
-BACKBONES = Registry('backbones')
-BACKBONES._modules = MODELS._modules
+BACKBONES = MODELS
 HEADS = Registry('heads')
 
+modules = LazyImportModule.AST_INDEX[INDEX_KEY]
+for module_index in list(modules.keys()):
+    if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
+        modules[(MODELS.name.upper(), module_index[1],
+                 module_index[2])] = modules[module_index]
 
 
 def build_model(cfg: ConfigDict,
                 task_name: str = None,
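The `BACKBONES = MODELS` rebinding replaces the earlier copy of the module table: both names now refer to one registry object, and the loop above rewrites AST-index entries recorded under 'BACKBONES' so lazy imports resolve them through MODELS. A toy sketch of the aliasing effect (a stand-in Registry, not modelscope's actual class):

# Toy stand-in: after `BACKBONES = MODELS`, registering through either name
# writes into the same module table, so the two registries cannot diverge.
class Registry:
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register(self, key):
        def deco(cls):
            self._modules[key] = cls
            return cls
        return deco


MODELS = Registry('models')
BACKBONES = MODELS


@BACKBONES.register('resnet')
class ResNet: ...


print('resnet' in MODELS._modules)  # True: one shared table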
@@ -4,6 +4,7 @@ import os
 import os.path as osp
 import shutil
 import subprocess
+import uuid
 
 import cv2
 import numpy as np
@@ -84,7 +85,9 @@ class ActionDetONNX(Model):
     def forward_video(self, video_name, scale):
         min_size, max_size = self._get_sizes(scale)
 
-        tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4])
+        tmp_dir = osp.join(
+            self.tmp_dir,
+            str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4])
         if osp.exists(tmp_dir):
             shutil.rmtree(tmp_dir)
         os.makedirs(tmp_dir)
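With the uuid1() prefix, two pipelines decoding the same video concurrently no longer compute an identical tmp_dir, so one run cannot rmtree the frames another run is still reading. A small sketch of the naming scheme (the helper name is illustrative, not from the commit):

import os.path as osp
import uuid


def unique_tmp_dir(base_dir: str, video_name: str) -> str:
    # uuid1() gives each call a distinct prefix, so concurrent runs on the
    # same video get disjoint frame directories.
    return osp.join(base_dir,
                    str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4])


print(unique_tmp_dir('/tmp/actiondet', 'clip.mp4'))  # /tmp/actiondet/<uuid>_clip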
@@ -110,6 +113,7 @@ class ActionDetONNX(Model):
                 len(frame_names) * self.temporal_stride,
                 self.temporal_stride))
         batch_imgs = [self.parse_frames(names) for names in frame_names]
+        shutil.rmtree(tmp_dir)
 
         N, _, T, H, W = batch_imgs[0].shape
         scale_min = min_size / min(H, W)
@@ -128,7 +132,6 @@ class ActionDetONNX(Model):
             'timestamp': t,
             'actions': res
         } for t, res in zip(timestamp, results)]
-        shutil.rmtree(tmp_dir)
         return results
 
     def forward(self, video_name):
@@ -25,9 +25,9 @@ emotion_list = [
 ]
 
 
-def inference(image_path, model, face_model, score_thre=0.5, GPU=0):
-    image = Image.open(image_path).convert('RGB')
+def inference(image, model, face_model, score_thre=0.5, GPU=0):
+    image = image.cpu().numpy()
+    image = Image.fromarray(image)
     face, bbox = face_detection_PIL_v2(image, face_model)
     if bbox is None:
         logger.warn('no face detected!')
@@ -115,9 +115,9 @@ std = [57.375, 57.12, 58.395]
 class_names = ['person', 'face', 'hand']
 
 
-def inference(model, device, img_path):
+def inference(model, device, img):
+    img = img.cpu().numpy()
     img_info = {'id': 0}
-    img = cv2.imread(img_path)
     height, width = img.shape[:2]
     img_info['height'] = height
     img_info['width'] = width
@@ -130,4 +130,9 @@ def inference(model, device, img_path):
     with torch.no_grad():
         res = model(meta)
     result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35)
-    return result
+    cls_list, bbox_list, score_list = [], [], []
+    for pred in result:
+        cls_list.append(pred[0])
+        bbox_list.append([pred[1], pred[2], pred[3], pred[4]])
+        score_list.append(pred[5])
+    return cls_list, bbox_list, score_list
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from PIL import Image
 from torch import nn
-from torchvision.transforms import transforms
+from torchvision import transforms
 
 from modelscope.metainfo import Models
 from modelscope.models.base import TorchModel
@@ -80,9 +80,9 @@ class HandStatic(TorchModel):
         return pred_result
 
 
-def infer(img_path, model, device):
-
-    img = Image.open(img_path)
+def infer(img, model, device):
+    img = img.cpu().numpy()
+    img = Image.fromarray(img)
     clip = spatial_transform(img)
     clip = clip.unsqueeze(0).to(device).float()
     outputs = model(clip)
@@ -59,9 +59,8 @@ mean, std = np.array([[[124.55, 118.90,
                        102.94]]]), np.array([[[56.77, 55.97, 57.50]]])
 
 
-def inference(model, device, input_path):
-    img = Image.open(input_path)
-    img = np.array(img.convert('RGB')).astype(np.float32)
+def inference(model, device, img):
+    img = img.cpu().numpy()
     img = (img - mean) / std
     img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR)
     img = torch.from_numpy(img)
@@ -236,8 +236,10 @@ class VideoCLIPForMultiModalEmbedding(TorchModel):
         logger.info('text feature: {}'.format(sequence_output[0][0][0]))
         logger.info('video feature: {}'.format(visual_output[0][0][0]))
 
-        output[OutputKeys.VIDEO_EMBEDDING] = visual_output
-        output[OutputKeys.TEXT_EMBEDDING] = sequence_output
+        output[
+            OutputKeys.VIDEO_EMBEDDING] = visual_output.cpu().detach().numpy()
+        output[OutputKeys.TEXT_EMBEDDING] = sequence_output.cpu().detach(
+        ).numpy()
         return output
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
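Returning .cpu().detach().numpy() arrays (rather than live, possibly-CUDA tensors) hands callers embeddings that are free of the autograd graph and safe to serialize. A minimal sketch of the conversion (the tensor here is a placeholder):

import torch

video_emb = torch.randn(1, 512, requires_grad=True)  # placeholder embedding
# .cpu() leaves the accelerator, .detach() drops the autograd graph, and
# .numpy() yields a plain array downstream consumers can pickle or JSON-ify.
np_emb = video_emb.cpu().detach().numpy()
print(type(np_emb), np_emb.shape)  # <class 'numpy.ndarray'> (1, 512)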
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
         SbertForSequenceClassification,
         SbertForTokenClassification,
         SbertTokenizer,
         SbertModel,
+        SbertTokenizerFast,
     )
     from .bert import (
@@ -61,6 +62,7 @@ else:
             'SbertForTokenClassification',
             'SbertTokenizer',
+            'SbertTokenizerFast',
             'SbertModel',
         ],
         'veco': [
             'VecoModel', 'VecoConfig', 'VecoForTokenClassification',
@@ -18,14 +18,12 @@ logger = logging.get_logger(__name__)
 @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert)
 class BertForTextRanking(BertForSequenceClassification):
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(config)
-        self.train_batch_size = kwargs.get('train_batch_size', 4)
         neg_sample = kwargs.get('neg_sample', 8)
         self.neg_sample = neg_sample
         setattr(self, self.base_model_prefix,
                 BertModel(self.config, add_pooling_layer=True))
-        self.register_buffer(
-            'target_label',
-            torch.zeros(self.train_batch_size, dtype=torch.long))
 
     def forward(self,
                 input_ids=None,
@@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         if self.base_model.training:
-            scores = logits.view(self.train_batch_size, -1)
+            scores = logits.view(-1, self.neg_sample + 1)
+            batch_size = scores.size(0)
             loss_fct = torch.nn.CrossEntropyLoss()
-            loss = loss_fct(scores, self.target_label)
+            target_label = torch.zeros(
+                batch_size, dtype=torch.long, device=scores.device)
+            loss = loss_fct(scores, target_label)
             return AttentionTextClassificationModelOutput(
                 loss=loss,
                 logits=logits,
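The rewritten branch derives the group shape from neg_sample instead of a fixed train_batch_size, and builds the target labels on the fly on scores.device, so a smaller final batch or a multi-GPU split no longer breaks the old fixed-size registered buffer. Since the target class is always 0, the positive document is the first entry of each (1 + neg_sample) group. A standalone sketch of the same computation (shapes are illustrative):

import torch

neg_sample = 8
logits = torch.randn(4 * (neg_sample + 1), 1)  # 4 queries, 1 positive + 8 negatives each

scores = logits.view(-1, neg_sample + 1)       # one row of candidate scores per query
batch_size = scores.size(0)
# Column 0 holds the positive document, so the correct class is always 0;
# creating the labels per batch (on scores.device) works for any batch size.
target_label = torch.zeros(batch_size, dtype=torch.long, device=scores.device)
loss = torch.nn.CrossEntropyLoss()(scores, target_label)
print(loss.item())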
@@ -78,9 +79,11 @@ class BertForTextRanking(BertForSequenceClassification):
         Returns:
             The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
         """
 
         num_labels = kwargs.get('num_labels', 1)
+        neg_sample = kwargs.get('neg_sample', 4)
         model_args = {} if num_labels is None else {'num_labels': num_labels}
+        if neg_sample is not None:
+            model_args['neg_sample'] = neg_sample
 
         model_dir = kwargs.get('model_dir')
         model = super(Model, cls).from_pretrained(
@@ -98,7 +98,7 @@ def compute_adv_loss(embedding,
     if is_nan:
         logger.warning('Nan occured when calculating adv loss.')
         return ori_loss
-    emb_grad = emb_grad / emb_grad_norm
+    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
     embedding_2 = embedding_1 + adv_grad_factor * emb_grad
     embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
     embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
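Adding 1e-6 to the gradient norm guards the division when the norm underflows to zero, one source of the NaNs the function already checks for. A tiny reproduction of the degenerate case (shapes and the norm reduction are illustrative assumptions):

import torch

emb_grad = torch.zeros(2, 8)                      # a degenerate, all-zero gradient
emb_grad_norm = emb_grad.norm(dim=-1, keepdim=True)
safe = emb_grad / (emb_grad_norm + 1e-6)          # epsilon keeps 0/0 from producing NaNs
print(torch.isnan(safe).any())                    # tensor(False)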
modelscope/models/science/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .unifold import UnifoldForProteinStructrue

else:
    _import_structure = {'unifold': ['UnifoldForProteinStructrue']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
modelscope/models/science/unifold/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .model import UnifoldForProteinStructrue
modelscope/models/science/unifold/config.py (new file, 636 lines)
@@ -0,0 +1,636 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy
from typing import Any

import ml_collections as mlc

N_RES = 'number of residues'
N_MSA = 'number of MSA sequences'
N_EXTRA_MSA = 'number of extra MSA sequences'
N_TPL = 'number of templates'

d_pair = mlc.FieldReference(128, field_type=int)
d_msa = mlc.FieldReference(256, field_type=int)
d_template = mlc.FieldReference(64, field_type=int)
d_extra_msa = mlc.FieldReference(64, field_type=int)
d_single = mlc.FieldReference(384, field_type=int)
max_recycling_iters = mlc.FieldReference(3, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(3e4, field_type=float)
use_templates = mlc.FieldReference(True, field_type=bool)
is_multimer = mlc.FieldReference(False, field_type=bool)


def base_config():
    return mlc.ConfigDict({
        'data': {
            'common': {
                'features': {
                    'aatype': [N_RES],
                    'all_atom_mask': [N_RES, None],
                    'all_atom_positions': [N_RES, None, None],
                    'alt_chi_angles': [N_RES, None],
                    'atom14_alt_gt_exists': [N_RES, None],
                    'atom14_alt_gt_positions': [N_RES, None, None],
                    'atom14_atom_exists': [N_RES, None],
                    'atom14_atom_is_ambiguous': [N_RES, None],
                    'atom14_gt_exists': [N_RES, None],
                    'atom14_gt_positions': [N_RES, None, None],
                    'atom37_atom_exists': [N_RES, None],
                    'frame_mask': [N_RES],
                    'true_frame_tensor': [N_RES, None, None],
                    'bert_mask': [N_MSA, N_RES],
                    'chi_angles_sin_cos': [N_RES, None, None],
                    'chi_mask': [N_RES, None],
                    'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES],
                    'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES],
                    'extra_msa': [N_EXTRA_MSA, N_RES],
                    'extra_msa_mask': [N_EXTRA_MSA, N_RES],
                    'extra_msa_row_mask': [N_EXTRA_MSA],
                    'is_distillation': [],
                    'msa_feat': [N_MSA, N_RES, None],
                    'msa_mask': [N_MSA, N_RES],
                    'msa_chains': [N_MSA, None],
                    'msa_row_mask': [N_MSA],
                    'num_recycling_iters': [],
                    'pseudo_beta': [N_RES, None],
                    'pseudo_beta_mask': [N_RES],
                    'residue_index': [N_RES],
                    'residx_atom14_to_atom37': [N_RES, None],
                    'residx_atom37_to_atom14': [N_RES, None],
                    'resolution': [],
                    'rigidgroups_alt_gt_frames': [N_RES, None, None, None],
                    'rigidgroups_group_exists': [N_RES, None],
                    'rigidgroups_group_is_ambiguous': [N_RES, None],
                    'rigidgroups_gt_exists': [N_RES, None],
                    'rigidgroups_gt_frames': [N_RES, None, None, None],
                    'seq_length': [],
                    'seq_mask': [N_RES],
                    'target_feat': [N_RES, None],
                    'template_aatype': [N_TPL, N_RES],
                    'template_all_atom_mask': [N_TPL, N_RES, None],
                    'template_all_atom_positions': [N_TPL, N_RES, None, None],
                    'template_alt_torsion_angles_sin_cos': [
                        N_TPL,
                        N_RES,
                        None,
                        None,
                    ],
                    'template_frame_mask': [N_TPL, N_RES],
                    'template_frame_tensor': [N_TPL, N_RES, None, None],
                    'template_mask': [N_TPL],
                    'template_pseudo_beta': [N_TPL, N_RES, None],
                    'template_pseudo_beta_mask': [N_TPL, N_RES],
                    'template_sum_probs': [N_TPL, None],
                    'template_torsion_angles_mask': [N_TPL, N_RES, None],
                    'template_torsion_angles_sin_cos':
                    [N_TPL, N_RES, None, None],
                    'true_msa': [N_MSA, N_RES],
                    'use_clamped_fape': [],
                    'assembly_num_chains': [1],
                    'asym_id': [N_RES],
                    'sym_id': [N_RES],
                    'entity_id': [N_RES],
                    'num_sym': [N_RES],
                    'asym_len': [None],
                    'cluster_bias_mask': [N_MSA],
                },
                'masked_msa': {
                    'profile_prob': 0.1,
                    'same_prob': 0.1,
                    'uniform_prob': 0.1,
                },
                'block_delete_msa': {
                    'msa_fraction_per_block': 0.3,
                    'randomize_num_blocks': False,
                    'num_blocks': 5,
                    'min_num_msa': 16,
                },
                'random_delete_msa': {
                    'max_msa_entry': 1 << 25,  # := 33554432
                },
                'v2_feature':
                False,
                'gumbel_sample':
                False,
                'max_extra_msa':
                1024,
                'msa_cluster_features':
                True,
                'reduce_msa_clusters_by_max_templates':
                True,
                'resample_msa_in_recycling':
                True,
                'template_features': [
                    'template_all_atom_positions',
                    'template_sum_probs',
                    'template_aatype',
                    'template_all_atom_mask',
                ],
                'unsupervised_features': [
                    'aatype',
                    'residue_index',
                    'msa',
                    'msa_chains',
                    'num_alignments',
                    'seq_length',
                    'between_segment_residues',
                    'deletion_matrix',
                    'num_recycling_iters',
                    'crop_and_fix_size_seed',
                ],
                'recycling_features': [
                    'msa_chains',
                    'msa_mask',
                    'msa_row_mask',
                    'bert_mask',
                    'true_msa',
                    'msa_feat',
                    'extra_msa_deletion_value',
                    'extra_msa_has_deletion',
                    'extra_msa',
                    'extra_msa_mask',
                    'extra_msa_row_mask',
                    'is_distillation',
                ],
                'multimer_features': [
                    'assembly_num_chains',
                    'asym_id',
                    'sym_id',
                    'num_sym',
                    'entity_id',
                    'asym_len',
                    'cluster_bias_mask',
                ],
                'use_templates':
                use_templates,
                'is_multimer':
                is_multimer,
                'use_template_torsion_angles':
                use_templates,
                'max_recycling_iters':
                max_recycling_iters,
            },
            'supervised': {
                'use_clamped_fape_prob':
                1.0,
                'supervised_features': [
                    'all_atom_mask',
                    'all_atom_positions',
                    'resolution',
                    'use_clamped_fape',
                    'is_distillation',
                ],
            },
            'predict': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 2,
                'crop': False,
                'crop_size': None,
                'supervised': False,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'eval': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': False,
                'crop_size': None,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'train': {
                'fixed_size': True,
                'subsample_templates': True,
                'block_delete_msa': True,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': True,
                'crop_size': 256,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'use_clamped_fape_prob': 1.0,
                'max_distillation_msa_clusters': 1000,
                'biased_msa_by_chain': True,
                'share_mask': True,
            },
        },
        'globals': {
            'chunk_size': chunk_size,
            'block_size': None,
            'd_pair': d_pair,
            'd_msa': d_msa,
            'd_template': d_template,
            'd_extra_msa': d_extra_msa,
            'd_single': d_single,
            'eps': eps,
            'inf': inf,
            'max_recycling_iters': max_recycling_iters,
            'alphafold_original_mode': False,
        },
        'model': {
            'is_multimer': is_multimer,
            'input_embedder': {
                'tf_dim': 22,
                'msa_dim': 49,
                'd_pair': d_pair,
                'd_msa': d_msa,
                'relpos_k': 32,
                'max_relative_chain': 2,
            },
            'recycling_embedder': {
                'd_pair': d_pair,
                'd_msa': d_msa,
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15,
                'inf': 1e8,
            },
            'template': {
                'distogram': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39,
                },
                'template_angle_embedder': {
                    'd_in': 57,
                    'd_out': d_msa,
                },
                'template_pair_embedder': {
                    'd_in': 88,
                    'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1],
                    'd_pair': d_pair,
                    'd_out': d_template,
                    'v2_feature': False,
                },
                'template_pair_stack': {
                    'd_template': d_template,
                    'd_hid_tri_att': 16,
                    'd_hid_tri_mul': 64,
                    'num_blocks': 2,
                    'num_heads': 4,
                    'pair_transition_n': 2,
                    'dropout_rate': 0.25,
                    'inf': 1e9,
                    'tri_attn_first': True,
                },
                'template_pointwise_attention': {
                    'enabled': True,
                    'd_template': d_template,
                    'd_pair': d_pair,
                    'd_hid': 16,
                    'num_heads': 4,
                    'inf': 1e5,
                },
                'inf': 1e5,
                'eps': 1e-6,
                'enabled': use_templates,
                'embed_angles': use_templates,
            },
            'extra_msa': {
                'extra_msa_embedder': {
                    'd_in': 25,
                    'd_out': d_extra_msa,
                },
                'extra_msa_stack': {
                    'd_msa': d_extra_msa,
                    'd_pair': d_pair,
                    'd_hid_msa_att': 8,
                    'd_hid_opm': 32,
                    'd_hid_mul': 128,
                    'd_hid_pair_att': 32,
                    'num_heads_msa': 8,
                    'num_heads_pair': 4,
                    'num_blocks': 4,
                    'transition_n': 4,
                    'msa_dropout': 0.15,
                    'pair_dropout': 0.25,
                    'inf': 1e9,
                    'eps': 1e-10,
                    'outer_product_mean_first': False,
                },
                'enabled': True,
            },
            'evoformer_stack': {
                'd_msa': d_msa,
                'd_pair': d_pair,
                'd_hid_msa_att': 32,
                'd_hid_opm': 32,
                'd_hid_mul': 128,
                'd_hid_pair_att': 32,
                'd_single': d_single,
                'num_heads_msa': 8,
                'num_heads_pair': 4,
                'num_blocks': 48,
                'transition_n': 4,
                'msa_dropout': 0.15,
                'pair_dropout': 0.25,
                'inf': 1e9,
                'eps': 1e-10,
                'outer_product_mean_first': False,
            },
            'structure_module': {
                'd_single': d_single,
                'd_pair': d_pair,
                'd_ipa': 16,
                'd_angle': 128,
                'num_heads_ipa': 12,
                'num_qk_points': 4,
                'num_v_points': 8,
                'dropout_rate': 0.1,
                'num_blocks': 8,
                'no_transition_layers': 1,
                'num_resnet_blocks': 2,
                'num_angles': 7,
                'trans_scale_factor': 10,
                'epsilon': 1e-12,
                'inf': 1e5,
                'separate_kv': False,
                'ipa_bias': True,
            },
            'heads': {
                'plddt': {
                    'num_bins': 50,
                    'd_in': d_single,
                    'd_hid': 128,
                },
                'distogram': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'disable_enhance_head': False,
                },
                'pae': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'enabled': False,
                    'iptm_weight': 0.8,
                    'disable_enhance_head': False,
                },
                'masked_msa': {
                    'd_msa': d_msa,
                    'd_out': 23,
                    'disable_enhance_head': False,
                },
                'experimentally_resolved': {
                    'd_single': d_single,
                    'd_out': 37,
                    'enabled': False,
                    'disable_enhance_head': False,
                },
            },
        },
        'loss': {
            'distogram': {
                'min_bin': 2.3125,
                'max_bin': 21.6875,
                'num_bins': 64,
                'eps': 1e-6,
                'weight': 0.3,
            },
            'experimentally_resolved': {
                'eps': 1e-8,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'fape': {
                'backbone': {
                    'clamp_distance': 10.0,
                    'clamp_distance_between_chains': 30.0,
                    'loss_unit_distance': 10.0,
                    'loss_unit_distance_between_chains': 20.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'sidechain': {
                    'clamp_distance': 10.0,
                    'length_scale': 10.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'weight': 1.0,
            },
            'plddt': {
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'cutoff': 15.0,
                'num_bins': 50,
                'eps': 1e-10,
                'weight': 0.01,
            },
            'masked_msa': {
                'eps': 1e-8,
                'weight': 2.0,
            },
            'supervised_chi': {
                'chi_weight': 0.5,
                'angle_norm_weight': 0.01,
                'eps': 1e-6,
                'weight': 1.0,
            },
            'violation': {
                'violation_tolerance_factor': 12.0,
                'clash_overlap_tolerance': 1.5,
                'bond_angle_loss_weight': 0.3,
                'eps': 1e-6,
                'weight': 0.0,
            },
            'pae': {
                'max_bin': 31,
                'num_bins': 64,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'eps': 1e-8,
                'weight': 0.0,
            },
            'repr_norm': {
                'weight': 0.01,
                'tolerance': 1.0,
            },
            'chain_centre_mass': {
                'weight': 0.0,
                'eps': 1e-8,
            },
        },
    })


def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None):
    with c.unlocked():
        for k, v in c.items():
            if ignore is not None and k == ignore:
                continue
            if isinstance(v, mlc.ConfigDict):
                recursive_set(v, key, value)
            elif k == key:
                c[k] = value

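recursive_set walks every nested ConfigDict and overwrites all matching keys, with ignore pruning one top-level subtree (note the recursive call does not forward ignore). A small usage sketch, assuming the function above is in scope:

import ml_collections as mlc

c = mlc.ConfigDict({'a': {'eps': 1e-8}, 'loss': {'eps': 1e-8}})
recursive_set(c, 'eps', 1e-5, 'loss')  # every 'eps' outside the 'loss' subtree
print(c.a.eps, c.loss.eps)             # 1e-05 1e-08
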
def model_config(name, train=False):
    c = copy.deepcopy(base_config())

    def model_2_v2(c):
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.template.template_angle_embedder.d_in = 34
        return c

    def multimer(c):
        recursive_set(c, 'is_multimer', True)
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 128)
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.template.template_angle_embedder.d_in = 34
        c.model.template.template_pair_stack.tri_attn_first = False
        c.model.template.template_pointwise_attention.enabled = False
        c.model.heads.pae.enabled = True
        # we forget to enable it in our training, so disable it here
        c.model.heads.pae.disable_enhance_head = True
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.structure_module.trans_scale_factor = 20
        c.loss.pae.weight = 0.1
        c.model.input_embedder.tf_dim = 21
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.chain_centre_mass.weight = 1.0
        return c

    if name == 'model_1':
        pass
    elif name == 'model_1_ft':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_1_af2':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
    elif name == 'model_2':
        pass
    elif name == 'model_init':
        pass
    elif name == 'model_init_af2':
        c.globals.alphafold_original_mode = True
        pass
    elif name == 'model_2_ft':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_2_af2':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
    elif name == 'model_2_v2':
        c = model_2_v2(c)
    elif name == 'model_2_v2_ft':
        c = model_2_v2(c)
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_3_af2' or name == 'model_4_af2':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
        c.model.template.enabled = False
        c.model.template.embed_angles = False
        recursive_set(c, 'use_templates', False)
        recursive_set(c, 'use_template_torsion_angles', False)
    elif name == 'model_5_af2':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
        c.model.template.enabled = False
        c.model.template.embed_angles = False
        recursive_set(c, 'use_templates', False)
        recursive_set(c, 'use_template_torsion_angles', False)
    elif name == 'multimer':
        c = multimer(c)
    elif name == 'multimer_ft':
        c = multimer(c)
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 256)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.5
    elif name == 'multimer_af2':
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 256)
        recursive_set(c, 'is_multimer', True)
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.template.template_angle_embedder.d_in = 34
        c.model.template.template_pair_stack.tri_attn_first = False
        c.model.template.template_pointwise_attention.enabled = False
        c.model.heads.pae.enabled = True
        c.model.heads.experimentally_resolved.enabled = True
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.structure_module.trans_scale_factor = 20
        c.loss.pae.weight = 0.1
        c.loss.violation.weight = 0.5
        c.loss.experimentally_resolved.weight = 0.01
        c.model.input_embedder.tf_dim = 21
        c.globals.alphafold_original_mode = True
        c.data.train.crop_size = 384
        c.loss.repr_norm.weight = 0
        c.loss.chain_centre_mass.weight = 1.0
        recursive_set(c, 'outer_product_mean_first', True)
    else:
        raise ValueError(f'invalid --model-name: {name}.')
    if train:
        c.globals.chunk_size = None
        recursive_set(c, 'inf', 3e4)
        recursive_set(c, 'eps', 1e-5, 'loss')
    return c
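A usage sketch grounded in the presets above (the printed values follow directly from the branches; this assumes config.py is importable):

cfg = model_config('multimer_ft')       # multimer preset plus fine-tuning overrides
print(cfg.data.train.crop_size)         # 384
print(cfg.loss.violation.weight)        # 0.5
train_cfg = model_config('model_2', train=True)
print(train_cfg.globals.chunk_size)     # None: chunking is disabled for training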
modelscope/models/science/unifold/data/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for model features."""
modelscope/models/science/unifold/data/data_ops.py (new file, 1397 lines; diff suppressed because it is too large)
modelscope/models/science/unifold/data/msa_pairing.py (new file, 526 lines)
@@ -0,0 +1,526 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data """

import collections
from typing import Dict, Iterable, List, Sequence

import numpy as np
import pandas as pd
import scipy.linalg

from .data_ops import NumpyDict
from .residue_constants import restypes_with_x_and_gap

MSA_GAP_IDX = restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MSA_PAD_VALUES = {
    'msa_all_seq': MSA_GAP_IDX,
    'msa_mask_all_seq': 1,
    'deletion_matrix_all_seq': 0,
    'deletion_matrix_int_all_seq': 0,
    'msa': MSA_GAP_IDX,
    'msa_mask': 1,
    'deletion_matrix': 0,
    'deletion_matrix_int': 0,
}

MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = (
    'residue_index',
    'aatype',
    'all_atom_positions',
    'all_atom_mask',
    'seq_mask',
    'between_segment_residues',
    'has_alt_locations',
    'has_hetatoms',
    'asym_id',
    'entity_id',
    'sym_id',
    'entity_mask',
    'deletion_mean',
    'prediction_atom_mask',
    'literature_positions',
    'atom_indices_to_group_indices',
    'rigid_group_default_frame',
    # zy
    'num_sym',
)
TEMPLATE_FEATURES = (
    'template_aatype',
    'template_all_atom_positions',
    'template_all_atom_mask',
)
CHAIN_FEATURES = ('num_alignments', 'seq_length')


def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]:
    """Returns the original chains with paired NUM_SEQ features.

    Args:
        chains: A list of feature dictionaries for each chain.

    Returns:
        A list of feature dictionaries with sequence features including only
        rows to be paired.
    """
    chains = list(chains)
    chain_keys = chains[0].keys()

    if len(chains) < 2:
        return chains
    else:
        updated_chains = []
        paired_chains_to_paired_row_indices = pair_sequences(chains)
        paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)

        for chain_num, chain in enumerate(chains):
            new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
            for feature_name in chain_keys:
                if feature_name.endswith('_all_seq'):
                    feats_padded = pad_features(chain[feature_name],
                                                feature_name)
                    new_chain[feature_name] = feats_padded[
                        paired_rows[:, chain_num]]
            new_chain['num_alignments_all_seq'] = np.asarray(
                len(paired_rows[:, chain_num]))
            updated_chains.append(new_chain)
        return updated_chains


def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
    """Add a 'padding' row at the end of the features list.

    The padding row will be selected as a 'paired' row in the case of partial
    alignment - for the chain that doesn't have paired alignment.

    Args:
        feature: The feature to be padded.
        feature_name: The name of the feature to be padded.

    Returns:
        The feature with an additional padding row.
    """
    assert feature.dtype != np.dtype(np.string_)
    if feature_name in (
            'msa_all_seq',
            'msa_mask_all_seq',
            'deletion_matrix_all_seq',
            'deletion_matrix_int_all_seq',
    ):
        num_res = feature.shape[1]
        padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                         feature.dtype)
    elif feature_name == 'msa_species_identifiers_all_seq':
        padding = [b'']
    else:
        return feature
    feats_padded = np.concatenate([feature, padding], axis=0)
    return feats_padded


def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame:
    """Makes dataframe with msa features needed for msa pairing."""
    chain_msa = chain_features['msa_all_seq']
    query_seq = chain_msa[0]
    per_seq_similarity = np.sum(
        query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
    per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
    msa_df = pd.DataFrame({
        'msa_species_identifiers':
        chain_features['msa_species_identifiers_all_seq'],
        'msa_row':
        np.arange(len(chain_features['msa_species_identifiers_all_seq'])),
        'msa_similarity':
        per_seq_similarity,
        'gap':
        per_seq_gap,
    })
    return msa_df


def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
    """Creates mapping from species to msa dataframe of that species."""
    species_lookup = {}
    for species, species_df in msa_df.groupby('msa_species_identifiers'):
        species_lookup[species] = species_df
    return species_lookup


def _match_rows_by_sequence_similarity(
        this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]:  # noqa
    """Finds MSA sequence pairings across chains based on sequence similarity.

    Each chain's MSA sequences are first sorted by their sequence similarity to
    their respective target sequence. The sequences are then paired, starting
    from the sequences most similar to their target sequence.

    Args:
        this_species_msa_dfs: a list of dataframes containing MSA features for
            sequences for a specific species.

    Returns:
        A list of lists, each containing M indices corresponding to paired MSA rows,
        where M is the number of chains.
    """
    all_paired_msa_rows = []

    num_seqs = [
        len(species_df) for species_df in this_species_msa_dfs
        if species_df is not None
    ]
    take_num_seqs = np.min(num_seqs)

    # sort_by_similarity = lambda x: x.sort_values(
    #     'msa_similarity', axis=0, ascending=False)

    def sort_by_similarity(x):
        return x.sort_values('msa_similarity', axis=0, ascending=False)

    for species_df in this_species_msa_dfs:
        if species_df is not None:
            species_df_sorted = sort_by_similarity(species_df)
            msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
        else:
            msa_rows = [-1] * take_num_seqs  # take the last 'padding' row
        all_paired_msa_rows.append(msa_rows)
    all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
    return all_paired_msa_rows


def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains."""

    num_examples = len(examples)

    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    common_species = sorted(common_species)
    common_species.remove(b'')  # Remove target sequence species.

    all_paired_msa_rows = [np.zeros(len(examples), int)]
    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        if np.any(
                np.array([
                    len(species_df) for species_df in this_species_msa_dfs
                    if isinstance(species_df, pd.DataFrame)
                ]) > 600):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(
            this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
    all_paired_msa_rows_dict = {
        num_examples: np.array(paired_msa_rows)
        for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict


def reorder_paired_rows(
        all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray:
    """Creates a list of indices of paired MSA rows across chains.

    Args:
        all_paired_msa_rows_dict: a mapping from the number of paired chains to the
            paired indices.

    Returns:
        a list of lists, each containing indices of paired MSA rows across chains.
        The paired-index lists are ordered by:
            1) the number of chains in the paired alignment, i.e, all-chain pairings
                will come first.
            2) e-values
    """
    all_paired_msa_rows = []

    for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
        paired_rows = all_paired_msa_rows_dict[num_pairings]
        paired_rows_product = np.abs(
            np.array(
                [np.prod(rows.astype(np.float64)) for rows in paired_rows]))
        paired_rows_sort_index = np.argsort(paired_rows_product)
        all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

    return np.array(all_paired_msa_rows)

def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    """Like scipy.linalg.block_diag but with an optional padding value."""
    ones_arrs = [np.ones_like(x) for x in arrs]
    off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs)
    diag = scipy.linalg.block_diag(*arrs)
    diag += (off_diag_mask * pad_value).astype(diag.dtype)
    return diag

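A quick numeric check of the pad_value behaviour, which is how per-chain MSA blocks get a custom filler (e.g. the gap index) on the off-diagonal:

import numpy as np
import scipy.linalg

a = np.ones((2, 2), dtype=np.int64)
b = 2 * np.ones((1, 3), dtype=np.int64)
print(block_diag(a, b, pad_value=-1))
# [[ 1  1 -1 -1 -1]
#  [ 1  1 -1 -1 -1]
#  [-1 -1  2  2  2]]
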
def _correct_post_merged_feats(np_example: NumpyDict,
                               np_chains_list: Sequence[NumpyDict],
                               pair_msa_sequences: bool) -> NumpyDict:
    """Adds features that need to be computed/recomputed post merging."""

    np_example['seq_length'] = np.asarray(
        np_example['aatype'].shape[0], dtype=np.int32)
    np_example['num_alignments'] = np.asarray(
        np_example['msa'].shape[0], dtype=np.int32)

    if not pair_msa_sequences:
        # Generate a bias that is 1 for the first row of every block in the
        # block diagonal MSA - i.e. make sure the cluster stack always includes
        # the query sequences for each chain (since the first row is the query
        # sequence).
        cluster_bias_masks = []
        for chain in np_chains_list:
            mask = np.zeros(chain['msa'].shape[0])
            mask[0] = 1
            cluster_bias_masks.append(mask)
        np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
        ]

        np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
    else:
        np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
        np_example['cluster_bias_mask'][0] = 1

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
        ]
        msa_masks_all_seq = [
            np.ones(x['msa_all_seq'].shape, dtype=np.int8)
            for x in np_chains_list
        ]

        msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
        msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
        np_example['bert_mask'] = np.concatenate(
            [msa_mask_all_seq, msa_mask_block_diag], axis=0)
    return np_example


def _pad_templates(chains: Sequence[NumpyDict],
                   max_templates: int) -> Sequence[NumpyDict]:
    """For each chain pad the number of templates to a fixed size.

    Args:
        chains: A list of protein chains.
        max_templates: Each chain will be padded to have this many templates.

    Returns:
        The list of chains, updated to have template features padded to
        max_templates.
    """
    for chain in chains:
        for k, v in chain.items():
            if k in TEMPLATE_FEATURES:
                padding = np.zeros_like(v.shape)
                padding[0] = max_templates - v.shape[0]
                padding = [(0, p) for p in padding]
                chain[k] = np.pad(v, padding, mode='constant')
    return chains


def _merge_features_from_multiple_chains(
        chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict:
    """Merge features from multiple chains.

    Args:
        chains: A list of feature dictionaries that we want to merge.
        pair_msa_sequences: Whether to concatenate MSA features along the
            num_res dimension (if True), or to block diagonalize them (if False).

    Returns:
        A feature dictionary for the merged example.
    """
    merged_example = {}
    for feature_name in chains[0]:
        feats = [x[feature_name] for x in chains]
        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            if pair_msa_sequences or '_all_seq' in feature_name:
                merged_example[feature_name] = np.concatenate(feats, axis=1)
                if feature_name_split == 'msa':
                    merged_example['msa_chains_all_seq'] = np.ones(
                        merged_example[feature_name].shape[0]).reshape(-1, 1)
            else:
                merged_example[feature_name] = block_diag(
                    *feats, pad_value=MSA_PAD_VALUES[feature_name])
                if feature_name_split == 'msa':
                    msa_chains = []
                    for i, feat in enumerate(feats):
                        cur_shape = feat.shape[0]
                        vals = np.ones(cur_shape) * (i + 2)
                        msa_chains.append(vals)
                    merged_example['msa_chains'] = np.concatenate(
                        msa_chains).reshape(-1, 1)
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            merged_example[feature_name] = np.sum(feats).astype(np.int32)
        else:
            merged_example[feature_name] = feats[0]
    return merged_example


def _merge_homomers_dense_msa(
        chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]:
    """Merge all identical chains, making the resulting MSA dense.

    Args:
        chains: An iterable of features for each chain.

    Returns:
        A list of feature dictionaries.  All features with the same entity_id
        will be merged - MSA features will be concatenated along the num_res
        dimension - making them dense.
    """
    entity_chains = collections.defaultdict(list)
    for chain in chains:
        entity_id = chain['entity_id'][0]
        entity_chains[entity_id].append(chain)

    grouped_chains = []
    for entity_id in sorted(entity_chains):
        chains = entity_chains[entity_id]
        grouped_chains.append(chains)
    chains = [
        _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
        for chains in grouped_chains
    ]
    return chains


def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict:
    """Merges paired and block-diagonalised features."""
    features = MSA_FEATURES + ('msa_chains', )
    for feature_name in features:
        if feature_name in example:
            feat = example[feature_name]
            feat_all_seq = example[feature_name + '_all_seq']
            try:
                merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
            except Exception as ex:
                raise Exception(
                    'concat failed.',
                    feature_name,
                    feat_all_seq.shape,
                    feat.shape,
                    ex.__class__,
                    ex,
                )
            example[feature_name] = merged_feat
    example['num_alignments'] = np.array(
        example['msa'].shape[0], dtype=np.int32)
    return example


def merge_chain_features(np_chains_list: List[NumpyDict],
                         pair_msa_sequences: bool,
                         max_templates: int) -> NumpyDict:
    """Merges features for multiple chains to single FeatureDict.

    Args:
        np_chains_list: List of FeatureDicts for each chain.
        pair_msa_sequences: Whether to merge paired MSAs.
        max_templates: The maximum number of templates to include.

    Returns:
        Single FeatureDict for entire complex.
    """
    np_chains_list = _pad_templates(
        np_chains_list, max_templates=max_templates)
    np_chains_list = _merge_homomers_dense_msa(np_chains_list)
    # Unpaired MSA features will be always block-diagonalised; paired MSA
    # features will be concatenated.
    np_example = _merge_features_from_multiple_chains(
        np_chains_list, pair_msa_sequences=False)
    if pair_msa_sequences:
        np_example = _concatenate_paired_and_unpaired_features(np_example)
    np_example = _correct_post_merged_feats(
        np_example=np_example,
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
    )

    return np_example


def deduplicate_unpaired_sequences(
        np_chains: List[NumpyDict]) -> List[NumpyDict]:
    """Removes unpaired sequences which duplicate a paired sequence."""

    feature_names = np_chains[0].keys()
    msa_features = MSA_FEATURES
    cache_msa_features = {}
    for chain in np_chains:
        entity_id = int(chain['entity_id'][0])
        if entity_id not in cache_msa_features:
            sequence_set = set(s.tobytes() for s in chain['msa_all_seq'])
            keep_rows = []
            # Go through unpaired MSA seqs and remove any rows that correspond to the
            # sequences that are already present in the paired MSA.
            for row_num, seq in enumerate(chain['msa']):
                if seq.tobytes() not in sequence_set:
                    keep_rows.append(row_num)
            new_msa_features = {}
            for feature_name in feature_names:
                if feature_name in msa_features:
                    if keep_rows:
                        new_msa_features[feature_name] = chain[feature_name][
                            keep_rows]
                    else:
                        new_shape = list(chain[feature_name].shape)
                        new_shape[0] = 0
                        new_msa_features[feature_name] = np.zeros(
                            new_shape, dtype=chain[feature_name].dtype)
            cache_msa_features[entity_id] = new_msa_features
        for feature_name in cache_msa_features[entity_id]:
            chain[feature_name] = cache_msa_features[entity_id][feature_name]
        chain['num_alignments'] = np.array(
            chain['msa'].shape[0], dtype=np.int32)
    return np_chains
modelscope/models/science/unifold/data/process.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
|
||||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.models.science.unifold.data import data_ops
|
||||
|
||||
|
||||
def nonensembled_fns(common_cfg, mode_cfg):
|
||||
"""Input pipeline data transformers that are not ensembled."""
|
||||
v2_feature = common_cfg.v2_feature
|
||||
operators = []
|
||||
if mode_cfg.random_delete_msa:
|
||||
operators.append(
|
||||
data_ops.random_delete_msa(common_cfg.random_delete_msa))
|
||||
operators.extend([
|
||||
data_ops.cast_to_64bit_ints,
|
||||
data_ops.correct_msa_restypes,
|
||||
data_ops.squeeze_features,
|
||||
data_ops.randomly_replace_msa_with_unknown(0.0),
|
||||
data_ops.make_seq_mask,
|
||||
data_ops.make_msa_mask,
|
||||
])
|
||||
operators.append(data_ops.make_hhblits_profile_v2
|
||||
if v2_feature else data_ops.make_hhblits_profile)
|
||||
if common_cfg.use_templates:
|
||||
operators.extend([
|
||||
data_ops.make_template_mask,
|
||||
data_ops.make_pseudo_beta('template_'),
|
||||
])
|
||||
operators.append(
|
||||
data_ops.crop_templates(
|
||||
max_templates=mode_cfg.max_templates,
|
||||
subsample_templates=mode_cfg.subsample_templates,
|
||||
))
|
||||
|
||||
if common_cfg.use_template_torsion_angles:
|
||||
operators.extend([
|
||||
data_ops.atom37_to_torsion_angles('template_'),
|
||||
])
|
||||
|
||||
operators.append(data_ops.make_atom14_masks)
|
||||
operators.append(data_ops.make_target_feat)
|
||||
|
||||
return operators
|
||||
|
||||
|
||||
def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
|
||||
operators = []
|
||||
if common_cfg.reduce_msa_clusters_by_max_templates:
|
||||
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
|
||||
else:
|
||||
pad_msa_clusters = mode_cfg.max_msa_clusters
|
||||
crop_feats = dict(common_cfg.features)
|
||||
if mode_cfg.fixed_size:
|
||||
if mode_cfg.crop:
|
||||
if common_cfg.is_multimer:
|
||||
crop_fn = data_ops.crop_to_size_multimer(
|
||||
crop_size=mode_cfg.crop_size,
|
||||
shape_schema=crop_feats,
|
||||
seed=crop_and_fix_size_seed,
|
||||
spatial_crop_prob=mode_cfg.spatial_crop_prob,
|
||||
ca_ca_threshold=mode_cfg.ca_ca_threshold,
|
||||
)
|
||||
else:
|
||||
crop_fn = data_ops.crop_to_size_single(
|
||||
crop_size=mode_cfg.crop_size,
|
||||
shape_schema=crop_feats,
|
||||
seed=crop_and_fix_size_seed,
|
||||
)
|
||||
operators.append(crop_fn)
|
||||
|
||||
operators.append(data_ops.select_feat(crop_feats))
|
||||
|
||||
operators.append(
|
||||
data_ops.make_fixed_size(
|
||||
crop_feats,
|
||||
pad_msa_clusters,
|
||||
common_cfg.max_extra_msa,
|
||||
mode_cfg.crop_size,
|
||||
mode_cfg.max_templates,
|
||||
))
|
||||
return operators


def ensembled_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that can be ensembled and averaged."""
    operators = []
    multimer_mode = common_cfg.is_multimer
    v2_feature = common_cfg.v2_feature
    # multimer mode does not use block-delete MSA
    if mode_cfg.block_delete_msa and not multimer_mode:
        operators.append(
            data_ops.block_delete_msa(common_cfg.block_delete_msa))
    if 'max_distillation_msa_clusters' in mode_cfg:
        operators.append(
            data_ops.sample_msa_distillation(
                mode_cfg.max_distillation_msa_clusters))

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    assert common_cfg.resample_msa_in_recycling
    gumbel_sample = common_cfg.gumbel_sample
    operators.append(
        data_ops.sample_msa(
            max_msa_clusters,
            keep_extra=True,
            gumbel_sample=gumbel_sample,
            biased_msa_by_chain=mode_cfg.biased_msa_by_chain,
        ))

    if 'masked_msa' in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        operators.append(
            data_ops.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                gumbel_sample=gumbel_sample,
                share_mask=mode_cfg.share_mask,
            ))

    if common_cfg.msa_cluster_features:
        if v2_feature:
            operators.append(data_ops.nearest_neighbor_clusters_v2())
        else:
            operators.append(data_ops.nearest_neighbor_clusters())
            operators.append(data_ops.summarize_clusters)

    if v2_feature:
        operators.append(data_ops.make_msa_feat_v2)
    else:
        operators.append(data_ops.make_msa_feat)
    # Crop after creating the cluster profiles.
    if max_extra_msa:
        if v2_feature:
            operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
        else:
            operators.append(data_ops.crop_extra_msa(max_extra_msa))
    else:
        operators.append(data_ops.delete_extra_msa)
    # operators.append(data_operators.select_feat(common_cfg.recycling_features))
    return operators


def process_features(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    is_distillation = bool(tensors.get('is_distillation', 0))
    multimer_mode = common_cfg.is_multimer
    crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed'])
    crop_fn = crop_and_fix_size_fns(
        common_cfg,
        mode_cfg,
        crop_and_fix_size_seed,
    )

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_fns(
            common_cfg,
            mode_cfg,
        )
        new_d = compose(fns)(d)
        if not multimer_mode or is_distillation:
            new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
            return compose(crop_fn)(new_d)
        else:  # select after crop for spatial cropping
            d = compose(crop_fn)(d)
            d = data_ops.select_feat(common_cfg.recycling_features)(d)
            return d

    nonensembled = nonensembled_fns(common_cfg, mode_cfg)

    if mode_cfg.supervised and (not multimer_mode or is_distillation):
        nonensembled.extend(label_transform_fn())

    tensors = compose(nonensembled)(tensors)

    num_recycling = int(tensors['num_recycling_iters']) + 1
    num_ensembles = mode_cfg.num_ensembles

    ensemble_tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x),
        torch.arange(num_recycling * num_ensembles),
    )
    tensors = compose(crop_fn)(tensors)
    # add a dummy dim to align with recycling features
    tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors}
    tensors.update(ensemble_tensors)
    return tensors


@data_ops.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x
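
# Editor's sketch (illustration only): `curry1` turns `compose(x, fs)`
# into `compose(fs)(x)`, so a list of transforms can be applied as a
# pipeline. The local `_curry1` below is a stand-in for `data_ops.curry1`.
def _compose_demo():
    import functools

    def _curry1(f):
        @functools.wraps(f)
        def fc(*args, **kwargs):
            return lambda x: f(x, *args, **kwargs)
        return fc

    @_curry1
    def _compose(x, fs):
        for f in fs:
            x = f(x)
        return x

    pipeline = [lambda d: {**d, 'a': d['a'] + 1}, lambda d: {**d, 'b': 2}]
    assert _compose(pipeline)({'a': 0}) == {'a': 1, 'b': 2}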


def pad_then_stack(values):
    if len(values[0].shape) >= 1:
        size = max(v.shape[0] for v in values)
        new_values = []
        for v in values:
            if v.shape[0] < size:
                res = values[0].new_zeros(size, *v.shape[1:])
                res[:v.shape[0], ...] = v
            else:
                res = v
            new_values.append(res)
    else:
        new_values = values
    return torch.stack(new_values, dim=0)


def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = pad_then_stack(
            [dict_i[feat] for dict_i in ensembles])
    return ensembled_dict
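
# Editor's sketch: `map_fn` runs a per-iteration transform and stacks the
# per-feature results along a leading (recycling/ensemble) dimension,
# padding ragged first dims via `pad_then_stack`. A toy usage:
def _map_fn_demo():
    import torch

    def fake_iteration(i):
        # ragged first dimension across iterations
        return {'msa': torch.ones(int(i) + 1, 4)}

    out = map_fn(fake_iteration, torch.arange(3))
    assert out['msa'].shape == (3, 3, 4)  # padded to the max length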


def process_single_label(label: dict,
                         num_ensemble: Optional[int] = None) -> dict:
    assert 'aatype' in label
    assert 'all_atom_positions' in label
    assert 'all_atom_mask' in label
    label = compose(label_transform_fn())(label)
    if num_ensemble is not None:
        label = {
            k: torch.stack([v for _ in range(num_ensemble)])
            for k, v in label.items()
        }
    return label


def process_labels(labels_list, num_ensemble: Optional[int] = None):
    return [process_single_label(ll, num_ensemble) for ll in labels_list]


def label_transform_fn():
    return [
        data_ops.make_atom14_masks,
        data_ops.make_atom14_positions,
        data_ops.atom37_to_frames,
        data_ops.atom37_to_torsion_angles(''),
        data_ops.make_pseudo_beta(''),
        data_ops.get_backbone_frames,
        data_ops.get_chi_angles,
    ]
417  modelscope/models/science/unifold/data/process_multimer.py  Normal file
@@ -0,0 +1,417 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data."""

import collections
from typing import Iterable, List, MutableMapping

import numpy as np

from modelscope.models.science.unifold.data import (msa_pairing,
                                                    residue_constants)
from .utils import correct_template_restypes

FeatureDict = MutableMapping[str, np.ndarray]

REQUIRED_FEATURES = frozenset({
    'aatype',
    'all_atom_mask',
    'all_atom_positions',
    'all_chains_entity_ids',
    'all_crops_all_chains_mask',
    'all_crops_all_chains_positions',
    'all_crops_all_chains_residue_ids',
    'assembly_num_chains',
    'asym_id',
    'bert_mask',
    'cluster_bias_mask',
    'deletion_matrix',
    'deletion_mean',
    'entity_id',
    'entity_mask',
    'mem_peak',
    'msa',
    'msa_mask',
    'num_alignments',
    'num_templates',
    'queue_size',
    'residue_index',
    'resolution',
    'seq_length',
    'seq_mask',
    'sym_id',
    'template_aatype',
    'template_all_atom_mask',
    'template_all_atom_positions',
    # zy added:
    'asym_len',
    'template_sum_probs',
    'num_sym',
    'msa_chains',
})

MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example."""
    # Note that an entity_id of 0 indicates padding.
    num_unique_chains = len(
        np.unique(
            np.concatenate([
                np.unique(chain['entity_id'][chain['entity_id'] > 0])
                for chain in chains
            ])))
    return num_unique_chains == 1
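
# Editor's sketch: entity_id is 0 for padding and identical for chains
# that share a sequence, so a homodimer yields a single unique entity:
def _homomer_demo():
    import numpy as np
    chains = [{'entity_id': np.array([1, 1, 0])},
              {'entity_id': np.array([1, 1])}]
    assert _is_homomer_or_monomer(chains)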


def pair_and_merge(all_chain_features: List[FeatureDict]) -> FeatureDict:
    """Runs processing on features to augment, pair and merge.

    Args:
      all_chain_features: A list of feature dictionaries, one per chain.

    Returns:
      A dictionary of features.
    """

    process_unmerged_features(all_chain_features)

    np_chains_list = all_chain_features

    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

    if pair_msa_sequences:
        np_chains_list = msa_pairing.create_paired_features(
            chains=np_chains_list)
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
            np_chains_list)
    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    np_example = process_final(np_example)
    return np_example


def crop_chains(
    chains_list: List[FeatureDict],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int,
) -> List[FeatureDict]:
    """Crops the MSAs for a set of chains.

    Args:
      chains_list: A list of chains to be cropped.
      msa_crop_size: The total number of sequences to crop from the MSA.
      pair_msa_sequences: Whether we are operating in sequence-pairing mode.
      max_templates: The maximum number of templates to use per chain.

    Returns:
      The cropped chains.
    """

    # Apply the cropping.
    cropped_chains = []
    for chain in chains_list:
        cropped_chain = _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates,
        )
        cropped_chains.append(cropped_chain)

    return cropped_chains


def _crop_single_chain(chain: FeatureDict, msa_crop_size: int,
                       pair_msa_sequences: bool,
                       max_templates: int) -> FeatureDict:
    """Crops msa sequences to `msa_crop_size`."""
    msa_size = chain['num_alignments']

    if pair_msa_sequences:
        msa_size_all_seq = chain['num_alignments_all_seq']
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq,
                                           msa_crop_size // 2)

        # We reduce the number of un-paired sequences by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This
        # keeps the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
        num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
                                          msa_crop_size_all_seq)

        # Restrict the unpaired crop size so that paired+unpaired sequences do
        # not exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(
            templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32)
    return chain
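
# Editor's note: a small numeric illustration of the MSA budget above.
# With msa_crop_size=2048 the paired MSA gets at most 1024 rows; if 600
# of those rows are non-gap, the unpaired MSA may use the remaining 1448.
def _msa_budget_demo():
    import numpy as np
    msa_crop_size = 2048
    msa_crop_size_all_seq = np.minimum(5000, msa_crop_size // 2)  # -> 1024
    num_non_gapped_pairs = np.minimum(600, msa_crop_size_all_seq)  # -> 600
    max_unpaired = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
    assert int(max_unpaired) == 1448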


def process_final(np_example: FeatureDict) -> FeatureDict:
    """Final processing steps in data pipeline, after merging and pairing."""
    np_example = _make_seq_mask(np_example)
    np_example = _make_msa_mask(np_example)
    np_example = _filter_features(np_example)
    return np_example


def _make_seq_mask(np_example):
    np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded."""

    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8)

    seq_mask = (np_example['entity_id'] > 0).astype(np.int8)
    np_example['msa_mask'] *= seq_mask[None]

    return np_example


def _filter_features(np_example: FeatureDict) -> FeatureDict:
    """Filters features of example to only those requested."""
    return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(all_chain_features: List[FeatureDict]):
    """Postprocessing stage for per-chain features before merging."""
    num_chains = len(all_chain_features)
    for chain_features in all_chain_features:
        # Convert deletion matrices to float.
        if 'deletion_matrix_int' in chain_features:
            chain_features['deletion_matrix'] = np.asarray(
                chain_features.pop('deletion_matrix_int'), dtype=np.float32)
        if 'deletion_matrix_int_all_seq' in chain_features:
            chain_features['deletion_matrix_all_seq'] = np.asarray(
                chain_features.pop('deletion_matrix_int_all_seq'),
                dtype=np.float32)

        chain_features['deletion_mean'] = np.mean(
            chain_features['deletion_matrix'], axis=0)

        if 'all_atom_positions' not in chain_features:
            # Add all_atom_mask and dummy all_atom_positions based on aatype.
            all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
                chain_features['aatype']]
            chain_features['all_atom_mask'] = all_atom_mask
            chain_features['all_atom_positions'] = np.zeros(
                list(all_atom_mask.shape) + [3])

        # Add assembly_num_chains.
        chain_features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for chain_features in all_chain_features:
        chain_features['entity_mask'] = (chain_features['entity_id']
                                         != 0).astype(np.int32)


def empty_template_feats(n_res):
    return {
        'template_aatype':
        np.zeros((0, n_res)).astype(np.int64),
        'template_all_atom_positions':
        np.zeros((0, n_res, 37, 3)).astype(np.float32),
        'template_sum_probs':
        np.zeros((0, 1)).astype(np.float32),
        'template_all_atom_mask':
        np.zeros((0, n_res, 37)).astype(np.float32),
    }


def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict:
    """Reshapes and modifies monomer features for multimer models."""
    if monomer_features['template_aatype'].shape[0] == 0:
        monomer_features.update(
            empty_template_feats(monomer_features['aatype'].shape[0]))
    converted = {}
    unnecessary_leading_dim_feats = {
        'sequence',
        'domain_name',
        'num_alignments',
        'seq_length',
    }
    for feature_name, feature in monomer_features.items():
        if feature_name in unnecessary_leading_dim_feats:
            # asarray ensures it's a np.ndarray.
            feature = np.asarray(feature[0], dtype=feature.dtype)
        elif feature_name == 'aatype':
            # The multimer model performs the one-hot operation itself.
            feature = np.argmax(feature, axis=-1).astype(np.int32)
        elif feature_name == 'template_aatype':
            if feature.shape[0] > 0:
                feature = correct_template_restypes(feature)
        elif feature_name == 'template_all_atom_masks':
            feature_name = 'template_all_atom_mask'
        elif feature_name == 'msa':
            feature = feature.astype(np.uint8)

        if feature_name.endswith('_mask'):
            feature = feature.astype(np.float32)

        converted[feature_name] = feature

    if 'deletion_matrix_int' in monomer_features:
        monomer_features['deletion_matrix'] = monomer_features.pop(
            'deletion_matrix_int').astype(np.float32)

    # zy: this input arrives with an inconsistent shape.
    # TODO: figure out why and fix it upstream.
    converted.pop('template_sum_probs')
    return converted


def int_id_to_str_id(num: int) -> str:
    """Encodes a number as a string, using reverse spreadsheet style naming.

    Args:
      num: A positive integer.

    Returns:
      A string that encodes the positive integer using reverse spreadsheet
      style naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ...
      This is the usual way to encode chain IDs in mmCIF files.
    """
    if num <= 0:
        raise ValueError(f'Only positive integers allowed, got {num}.')

    num = num - 1  # 1-based indexing.
    output = []
    while num >= 0:
        output.append(chr(num % 26 + ord('A')))
        num = num // 26 - 1
    return ''.join(output)
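
# Editor's sketch: quick sanity checks of the reverse-spreadsheet naming.
def _chain_id_demo():
    assert int_id_to_str_id(1) == 'A'
    assert int_id_to_str_id(26) == 'Z'
    assert int_id_to_str_id(27) == 'AA'
    assert int_id_to_str_id(28) == 'BA'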


def add_assembly_features(all_chain_features):
    """Add features to distinguish between chains.

    Args:
      all_chain_features: A list of feature dictionaries, one per chain.

    Returns:
      The list of feature dictionaries with assembly ids added. E.g. two
      chains from a homodimer share one entity_id but get sym_id 1 and 2,
      while two chains from a heterodimer get entity_id 1 and 2.
    """
    # Group the chains by sequence.
    seq_to_entity_id = {}
    grouped_chains = collections.defaultdict(list)
    for chain_features in all_chain_features:
        assert 'sequence' in chain_features
        seq = str(chain_features['sequence'])
        if seq not in seq_to_entity_id:
            seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
        grouped_chains[seq_to_entity_id[seq]].append(chain_features)

    new_all_chain_features = []
    chain_id = 1
    for entity_id, group_chain_features in grouped_chains.items():
        num_sym = len(group_chain_features)  # zy
        for sym_id, chain_features in enumerate(group_chain_features, start=1):
            seq_length = chain_features['seq_length']
            chain_features['asym_id'] = chain_id * np.ones(seq_length)
            chain_features['sym_id'] = sym_id * np.ones(seq_length)
            chain_features['entity_id'] = entity_id * np.ones(seq_length)
            chain_features['num_sym'] = num_sym * np.ones(seq_length)
            chain_id += 1
            new_all_chain_features.append(chain_features)

    return new_all_chain_features
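
# Editor's sketch: for a homodimer (two chains, one sequence), the added
# ids distinguish the copies while sharing one entity:
def _assembly_demo():
    chains = [{'sequence': 'ACDE', 'seq_length': 4},
              {'sequence': 'ACDE', 'seq_length': 4}]
    out = add_assembly_features(chains)
    assert out[0]['asym_id'][0] == 1 and out[1]['asym_id'][0] == 2
    assert out[0]['entity_id'][0] == out[1]['entity_id'][0] == 1
    assert out[0]['sym_id'][0] == 1 and out[1]['sym_id'][0] == 2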


def pad_msa(np_example, min_num_seq):
    np_example = dict(np_example)
    num_seq = np_example['msa'].shape[0]
    if num_seq < min_num_seq:
        for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask',
                     'msa_chains'):
            np_example[feat] = np.pad(np_example[feat],
                                      ((0, min_num_seq - num_seq), (0, 0)))
        np_example['cluster_bias_mask'] = np.pad(
            np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), ))
    return np_example


def post_process(np_example):
    np_example = pad_msa(np_example, 512)
    no_dim_keys = [
        'num_alignments',
        'assembly_num_chains',
        'num_templates',
        'seq_length',
        'resolution',
    ]
    for k in no_dim_keys:
        if k in np_example:
            np_example[k] = np_example[k].reshape(-1)
    return np_example


def merge_msas(msa, del_mat, new_msa, new_del_mat):
    cur_msa_set = set(tuple(m) for m in msa)
    new_rows = []
    for i, s in enumerate(new_msa):
        if tuple(s) not in cur_msa_set:
            new_rows.append(i)
    ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0)
    ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0)
    return ret_msa, ret_del_mat
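
# Editor's sketch: duplicate rows in the new MSA are dropped before
# concatenation, and the deletion-matrix rows follow their MSA rows:
def _merge_msas_demo():
    import numpy as np
    msa = np.array([[1, 2], [3, 4]])
    dm = np.zeros_like(msa)
    new_msa = np.array([[3, 4], [5, 6]])  # first row is already present
    merged, merged_dm = merge_msas(msa, dm, new_msa, np.ones_like(new_msa))
    assert merged.shape == (3, 2) and merged_dm.shape == (3, 2)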
322  modelscope/models/science/unifold/data/protein.py  Normal file
@@ -0,0 +1,322 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional

import numpy as np
from Bio.PDB import PDBParser

from modelscope.models.science.unifold.data import residue_constants

FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, C.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain in the protein that this residue
    # belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
            raise ValueError(
                f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
                'because these cannot be written to PDB format.')


def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
      non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('none', pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.'
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            res_shortname = residue_constants.restype_3to1.get(
                res.resname, 'X')
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num)
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num, ))
            res_b_factors = np.zeros((residue_constants.atom_type_num, ))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[residue_constants.atom_order[
                    atom.name]] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )


def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    chain_end = 'TER'
    return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
            f'{chain_name:>1}{residue_index:>4}')
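
# Editor's sketch: the TER record is fixed-width, like all PDB records;
# the fields still tokenize cleanly on whitespace:
def _ter_demo():
    line = _chain_end(100, 'GLY', 'A', 42)
    assert line.startswith('TER')
    assert line.split() == ['TER', '100', 'GLY', 'A', '42']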


def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string.
    """
    restypes = residue_constants.restypes + ['X']

    def res_1to3(r):
        return residue_constants.restype_1to3.get(restypes[r], 'UNK')

    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError('Invalid aatypes.')

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append('MODEL     1')
    atom_index = 1
    last_chain_index = chain_index[0]
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                ))
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(atom_types,
                                                  atom_positions[i],
                                                  atom_mask[i], b_factors[i]):
            if mask < 0.5:
                continue

            record_type = 'ATOM'
            name = atom_name if len(atom_name) == 4 else f' {atom_name}'
            alt_loc = ''
            insertion_code = ''
            occupancy = 1.00
            # Proteins here contain only C, N, O and S atoms, so a single
            # character suffices for the element column.
            element = atom_name[0]
            charge = ''
            # PDB is a columnar format, every space matters here!
            atom_line = (
                f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                f'{residue_index[i]:>4}{insertion_code:>1}   '
                f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                f'{element:>2}{charge:>2}')
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain.
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        ))
    pdb_lines.append('ENDMDL')
    pdb_lines.append('END')

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.


def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(features: FeatureDict,
                    result: ModelOutput,
                    b_factors: Optional[np.ndarray] = None) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
      b_factors: (Optional) B-factors to use for the protein.

    Returns:
      A protein instance.
    """

    if 'asym_id' in features:
        chain_index = features['asym_id'] - 1
    else:
        chain_index = np.zeros_like(features['aatype'])

    if b_factors is None:
        b_factors = np.zeros_like(result['final_atom_mask'])

    return Protein(
        aatype=features['aatype'],
        atom_positions=result['final_atom_positions'],
        atom_mask=result['final_atom_mask'],
        residue_index=features['residue_index'] + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )


def from_feature(features: FeatureDict,
                 b_factors: Optional[np.ndarray] = None) -> Protein:
    """Assembles a standard pdb from input atom positions & mask.

    Args:
      features: Dictionary holding model inputs.
      b_factors: (Optional) B-factors to use for the protein.

    Returns:
      A protein instance.
    """

    if 'asym_id' in features:
        chain_index = features['asym_id'] - 1
    else:
        chain_index = np.zeros_like(features['aatype'])

    if b_factors is None:
        b_factors = np.zeros_like(features['all_atom_mask'])

    return Protein(
        aatype=features['aatype'],
        atom_positions=features['all_atom_positions'],
        atom_mask=features['all_atom_mask'],
        residue_index=features['residue_index'] + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
1212  modelscope/models/science/unifold/data/residue_constants.py  Normal file
File diff suppressed because it is too large
345  modelscope/models/science/unifold/data/stereo_chemical_props.txt  Normal file
@@ -0,0 +1,345 @@
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-

Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-

Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-
161  modelscope/models/science/unifold/data/utils.py  Normal file
@@ -0,0 +1,161 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy as copy_lib
import functools
import gzip
import pickle
from typing import Any, Dict

import json
import numpy as np
from scipy import sparse as sp

from . import residue_constants as rc
from .data_ops import NumpyDict


def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False):
    if deepcopy:

        def decorator(f):
            cached_func = functools.lru_cache(maxsize, typed)(f)

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                return copy_lib.deepcopy(cached_func(*args, **kwargs))

            return wrapper

    elif copy:

        def decorator(f):
            cached_func = functools.lru_cache(maxsize, typed)(f)

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                return copy_lib.copy(cached_func(*args, **kwargs))

            return wrapper

    else:
        decorator = functools.lru_cache(maxsize, typed)
    return decorator
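
# Editor's sketch: with copy=True / deepcopy=True the cache hands back
# copies, so callers may mutate results without corrupting cached values:
def _lru_copy_demo():
    @lru_cache(maxsize=2, deepcopy=True)
    def make(key):
        return {'key': key, 'payload': [1, 2, 3]}

    a = make('x')
    a['payload'].append(4)  # mutate the returned copy
    b = make('x')           # cache hit returns a fresh deep copy
    assert b['payload'] == [1, 2, 3]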


@lru_cache(maxsize=8, deepcopy=True)
def load_pickle_safe(path: str) -> Dict[str, Any]:

    def load(path):
        assert path.endswith('.pkl') or path.endswith(
            '.pkl.gz'), f'bad suffix in {path} as pickle file.'
        open_fn = gzip.open if path.endswith('.gz') else open
        with open_fn(path, 'rb') as f:
            return pickle.load(f)

    ret = load(path)
    ret = uncompress_features(ret)
    return ret


@lru_cache(maxsize=8, copy=True)
def load_pickle(path: str) -> Dict[str, Any]:

    def load(path):
        assert path.endswith('.pkl') or path.endswith(
            '.pkl.gz'), f'bad suffix in {path} as pickle file.'
        open_fn = gzip.open if path.endswith('.gz') else open
        with open_fn(path, 'rb') as f:
            return pickle.load(f)

    ret = load(path)
    ret = uncompress_features(ret)
    return ret


def correct_template_restypes(feature):
    """Correct template restype to have the same order as residue_constants."""
    feature = np.argmax(feature, axis=-1).astype(np.int32)
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
    return feature


def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict:
    feature['msa'] = feature['msa'].astype(np.uint8)
    if 'num_alignments' in feature:
        feature.pop('num_alignments')

    def make_all_seq_key(k):
        if not k.endswith('_all_seq'):
            return f'{k}_all_seq'
        return k

    return {make_all_seq_key(k): v for k, v in feature.items()}


def to_dense_matrix(spmat_dict: NumpyDict):
    spmat = sp.coo_matrix(
        (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])),
        shape=spmat_dict['shape'],
        dtype=np.float32,
    )
    return spmat.toarray()
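
# Editor's sketch: round-tripping a small matrix through the sparse (COO)
# dict layout used for compressed deletion matrices:
def _sparse_roundtrip_demo():
    dense = np.array([[0, 3], [0, 0]], dtype=np.float32)
    coo = sp.coo_matrix(dense)
    spmat_dict = {'shape': coo.shape, 'row': coo.row,
                  'col': coo.col, 'data': coo.data}
    assert np.array_equal(to_dense_matrix(spmat_dict), dense)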


FEATS_DTYPE = {'msa': np.int32}


def uncompress_features(feats: NumpyDict) -> NumpyDict:
    if 'sparse_deletion_matrix_int' in feats:
        v = feats.pop('sparse_deletion_matrix_int')
        v = to_dense_matrix(v)
        feats['deletion_matrix'] = v
    return feats


def filter(feature: NumpyDict, **kwargs) -> NumpyDict:
    assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}'
    if 'desired_keys' in kwargs:
        feature = {
            k: v
            for k, v in feature.items() if k in kwargs['desired_keys']
        }
    elif 'required_keys' in kwargs:
        for k in kwargs['required_keys']:
            assert k in feature, f'cannot find required key {k}.'
    elif 'ignored_keys' in kwargs:
        feature = {
            k: v
            for k, v in feature.items() if k not in kwargs['ignored_keys']
        }
    else:
        raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}')
    return feature


def compress_features(features: NumpyDict):
    change_dtype = {
        'msa': np.uint8,
    }
    sparse_keys = ['deletion_matrix_int']

    compressed_features = {}
    for k, v in features.items():
        if k in change_dtype:
            v = v.astype(change_dtype[k])
        if k in sparse_keys:
            v = sp.coo_matrix(v, dtype=v.dtype)
            sp_v = {
                'shape': v.shape,
                'row': v.row,
                'col': v.col,
                'data': v.data
            }
            k = f'sparse_{k}'
            v = sp_v
        compressed_features[k] = v
    return compressed_features
514  modelscope/models/science/unifold/dataset.py  Normal file
@@ -0,0 +1,514 @@
import copy
import logging
import os
from typing import Dict, Iterable, List, Optional, Tuple, Union

import json
import ml_collections as mlc
import numpy as np
import torch
from unicore.data import UnicoreDataset, data_utils
from unicore.distributed import utils as distributed_utils

from .data import utils
from .data.data_ops import NumpyDict, TorchDict
from .data.process import process_features, process_labels
from .data.process_multimer import (add_assembly_features,
                                    convert_monomer_features, merge_msas,
                                    pair_and_merge, post_process)

Rotation = Iterable[Iterable]
Translation = Iterable
Operation = Union[str, Tuple[Rotation, Translation]]
NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]]
TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]]

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def make_data_config(
    config: mlc.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[mlc.ConfigDict, List[str]]:
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res
    feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    if cfg.common.is_multimer:
        feature_names += cfg.common.multimer_features
    if cfg[mode].supervised:
        feature_names += cfg.supervised.supervised_features

    return cfg, feature_names


def process_label(all_atom_positions: np.ndarray,
                  operation: Operation) -> np.ndarray:
    if operation == 'I':
        return all_atom_positions
    rot, trans = operation
    rot = np.array(rot).reshape(3, 3)
    trans = np.array(trans).reshape(3)
    return all_atom_positions @ rot.T + trans
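
# Editor's sketch: applying a 90-degree rotation about z plus a unit
# translation along z to a single CA position with `process_label`:
def _symmetry_op_demo():
    rot = [0, -1, 0, 1, 0, 0, 0, 0, 1]  # row-major 3x3 rotation matrix
    trans = [0, 0, 1]
    pos = np.array([[1.0, 0.0, 0.0]])
    out = process_label(pos, (rot, trans))
    assert np.allclose(out, [[0.0, 1.0, 1.0]])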


@utils.lru_cache(maxsize=8, copy=True)
def load_single_feature(
    sequence_id: str,
    monomer_feature_dir: str,
    uniprot_msa_dir: Optional[str] = None,
    is_monomer: bool = False,
) -> NumpyDict:

    monomer_feature = utils.load_pickle(
        os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz'))
    monomer_feature = convert_monomer_features(monomer_feature)
    chain_feature = {**monomer_feature}

    if uniprot_msa_dir is not None:
        all_seq_feature = utils.load_pickle(
            os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz'))
        if is_monomer:
            chain_feature['msa'], chain_feature[
                'deletion_matrix'] = merge_msas(
                    chain_feature['msa'],
                    chain_feature['deletion_matrix'],
                    all_seq_feature['msa'],
                    all_seq_feature['deletion_matrix'],
                )
        else:
            all_seq_feature = utils.convert_all_seq_feature(all_seq_feature)
            for key in [
                    'msa_all_seq',
                    'msa_species_identifiers_all_seq',
                    'deletion_matrix_all_seq',
            ]:
                chain_feature[key] = all_seq_feature[key]

    return chain_feature


def load_single_label(
    label_id: str,
    label_dir: str,
    symmetry_operation: Optional[Operation] = None,
) -> NumpyDict:
    label = utils.load_pickle(
        os.path.join(label_dir, f'{label_id}.label.pkl.gz'))
    if symmetry_operation is not None:
        label['all_atom_positions'] = process_label(
            label['all_atom_positions'], symmetry_operation)
    label = {
        k: v
        for k, v in label.items() if k in
        ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution']
    }
    return label


def load(
    sequence_ids: List[str],
    monomer_feature_dir: str,
    uniprot_msa_dir: Optional[str] = None,
    label_ids: Optional[List[str]] = None,
    label_dir: Optional[str] = None,
    symmetry_operations: Optional[List[Operation]] = None,
    is_monomer: bool = False,
) -> NumpyExample:

    all_chain_features = [
        load_single_feature(s, monomer_feature_dir, uniprot_msa_dir,
                            is_monomer) for s in sequence_ids
    ]

    if label_ids is not None:
        # load labels
        assert len(label_ids) == len(sequence_ids)
        assert label_dir is not None
        if symmetry_operations is None:
            symmetry_operations = ['I' for _ in label_ids]
        all_chain_labels = [
            load_single_label(ll, label_dir, o)
            for ll, o in zip(label_ids, symmetry_operations)
        ]
        # update labels into features to calculate spatial cropping etc.
        for f, ll in zip(all_chain_features, all_chain_labels):
            f.update(ll)

    all_chain_features = add_assembly_features(all_chain_features)

    # get labels back from features, as add_assembly_features may alter the order of inputs.
    if label_ids is not None:
        all_chain_labels = [{
            k: f[k]
            for k in
            ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution']
        } for f in all_chain_features]
    else:
        all_chain_labels = None

    asym_len = np.array([c['seq_length'] for c in all_chain_features],
                        dtype=np.int64)
    if is_monomer:
        all_chain_features = all_chain_features[0]
    else:
        all_chain_features = pair_and_merge(all_chain_features)
        all_chain_features = post_process(all_chain_features)
    all_chain_features['asym_len'] = asym_len

    return all_chain_features, all_chain_labels


def process(
    config: mlc.ConfigDict,
    mode: str,
    features: NumpyDict,
    labels: Optional[List[NumpyDict]] = None,
    seed: int = 0,
    batch_idx: Optional[int] = None,
    data_idx: Optional[int] = None,
    is_distillation: bool = False,
) -> TorchExample:

    if mode == 'train':
        assert batch_idx is not None
        with data_utils.numpy_seed(seed, batch_idx, key='recycling'):
            num_iters = np.random.randint(
                0, config.common.max_recycling_iters + 1)
            use_clamped_fape = np.random.rand(
            ) < config[mode].use_clamped_fape_prob
    else:
        num_iters = config.common.max_recycling_iters
        use_clamped_fape = 1

    features['num_recycling_iters'] = int(num_iters)
    features['use_clamped_fape'] = int(use_clamped_fape)
    features['is_distillation'] = int(is_distillation)
    if is_distillation and 'msa_chains' in features:
        features.pop('msa_chains')

    num_res = int(features['seq_length'])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if labels is not None:
        features['resolution'] = labels[0]['resolution'].reshape(-1)

    with data_utils.numpy_seed(seed, data_idx, key='protein_feature'):
        features['crop_and_fix_size_seed'] = np.random.randint(0, 63355)
    features = utils.filter(features, desired_keys=feature_names)
    features = {k: torch.tensor(v) for k, v in features.items()}
    with torch.no_grad():
        features = process_features(features, cfg.common, cfg[mode])

    if labels is not None:
        labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels]
        with torch.no_grad():
            labels = process_labels(labels)

    return features, labels


def load_and_process(
    config: mlc.ConfigDict,
    mode: str,
    seed: int = 0,
    batch_idx: Optional[int] = None,
    data_idx: Optional[int] = None,
    is_distillation: bool = False,
    **load_kwargs,
):
    is_monomer = (
        is_distillation
        if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer'))
    features, labels = load(**load_kwargs, is_monomer=is_monomer)
    features, labels = process(config, mode, features, labels, seed, batch_idx,
                               data_idx, is_distillation)
    return features, labels


class UnifoldDataset(UnicoreDataset):

    def __init__(
        self,
        args,
        seed,
        config,
        data_path,
        mode='train',
        max_step=None,
        disable_sd=False,
        json_prefix='',
    ):
        self.path = data_path

        def load_json(filename):
            with open(filename, 'r') as f:
                return json.load(f)

        sample_weight = load_json(
            os.path.join(self.path,
                         json_prefix + mode + '_sample_weight.json'))
        self.multi_label = load_json(
            os.path.join(self.path, json_prefix + mode + '_multi_label.json'))
        self.inverse_multi_label = self._inverse_map(self.multi_label)
        self.sample_weight = {}
        for chain in self.inverse_multi_label:
            entity = self.inverse_multi_label[chain]
            self.sample_weight[chain] = sample_weight[entity]
        self.seq_sample_weight = sample_weight
        logger.info('load {} chains (unique {} sequences)'.format(
            len(self.sample_weight), len(self.seq_sample_weight)))
        self.feature_path = os.path.join(self.path, 'pdb_features')
        self.label_path = os.path.join(self.path, 'pdb_labels')
        sd_sample_weight_path = os.path.join(
            self.path, json_prefix + 'sd_train_sample_weight.json')
        if mode == 'train' and os.path.isfile(
                sd_sample_weight_path) and not disable_sd:
            self.sd_sample_weight = load_json(sd_sample_weight_path)
            logger.info('load {} self-distillation samples.'.format(
                len(self.sd_sample_weight)))
            self.sd_feature_path = os.path.join(self.path, 'sd_features')
            self.sd_label_path = os.path.join(self.path, 'sd_labels')
        else:
            self.sd_sample_weight = None
        self.batch_size = (
            args.batch_size * distributed_utils.get_data_parallel_world_size()
            * args.update_freq[0])
        self.data_len = (
            max_step * self.batch_size
            if max_step is not None else len(self.sample_weight))
        self.mode = mode
        self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight(
            self.seq_sample_weight)
        self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
            self.sample_weight)
        if self.sd_sample_weight is not None:
            (
                self.sd_num_chain,
                self.sd_chain_keys,
                self.sd_sample_prob,
            ) = self.cal_sample_weight(self.sd_sample_weight)
        self.config = config.data
        self.seed = seed
        self.sd_prob = args.sd_prob
|
||||
def cal_sample_weight(self, sample_weight):
|
||||
prot_keys = list(sample_weight.keys())
|
||||
sum_weight = sum(sample_weight.values())
|
||||
sample_prob = [sample_weight[k] / sum_weight for k in prot_keys]
|
||||
num_prot = len(prot_keys)
|
||||
return num_prot, prot_keys, sample_prob
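
    # For illustration (hypothetical weights): {'A': 1.0, 'B': 3.0} yields
    # prot_keys ['A', 'B'] and sample_prob [0.25, 0.75], i.e. the raw weights
    # are normalized into a categorical distribution for np.random.choice.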

    def sample_chain(self, idx, sample_by_seq=False):
        is_distillation = False
        if self.mode == 'train':
            with data_utils.numpy_seed(self.seed, idx, key='data_sample'):
                is_distillation = ((np.random.rand(1)[0] < self.sd_prob)
                                   if self.sd_sample_weight is not None else
                                   False)
                if is_distillation:
                    prot_idx = np.random.choice(
                        self.sd_num_chain, p=self.sd_sample_prob)
                    label_name = self.sd_chain_keys[prot_idx]
                    seq_name = label_name
                else:
                    if not sample_by_seq:
                        prot_idx = np.random.choice(
                            self.num_chain, p=self.sample_prob)
                        label_name = self.chain_keys[prot_idx]
                        seq_name = self.inverse_multi_label[label_name]
                    else:
                        seq_idx = np.random.choice(
                            self.num_seq, p=self.seq_sample_prob)
                        seq_name = self.seq_keys[seq_idx]
                        label_name = np.random.choice(
                            self.multi_label[seq_name])
        else:
            label_name = self.chain_keys[idx]
            seq_name = self.inverse_multi_label[label_name]
        return seq_name, label_name, is_distillation

    def __getitem__(self, idx):
        sequence_id, label_id, is_distillation = self.sample_chain(
            idx, sample_by_seq=True)
        feature_dir, label_dir = ((self.feature_path,
                                   self.label_path) if not is_distillation else
                                  (self.sd_feature_path, self.sd_label_path))
        features, _ = load_and_process(
            self.config,
            self.mode,
            self.seed,
            batch_idx=(idx // self.batch_size),
            data_idx=idx,
            is_distillation=is_distillation,
            sequence_ids=[sequence_id],
            monomer_feature_dir=feature_dir,
            uniprot_msa_dir=None,
            label_ids=[label_id],
            label_dir=label_dir,
            symmetry_operations=None,
            is_monomer=True,
        )
        return features

    def __len__(self):
        return self.data_len

    @staticmethod
    def collater(samples):
        # the first dim is recycling; the batch size is at the 2nd dim
        return data_utils.collate_dict(samples, dim=1)

    @staticmethod
    def _inverse_map(mapping: Dict[str, List[str]]):
        inverse_mapping = {}
        for ent, refs in mapping.items():
            for ref in refs:
                if ref in inverse_mapping:  # duplicated ent for this ref.
                    ent_2 = inverse_mapping[ref]
                    assert (
                        ent == ent_2
                    ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.'
                inverse_mapping[ref] = ent
        return inverse_mapping
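
    # A minimal illustration of `_inverse_map` (hypothetical ids): the
    # entity-to-chains map {'1abc_1': ['1abc_A', '1abc_B']} inverts to
    # {'1abc_A': '1abc_1', '1abc_B': '1abc_1'}; a chain listed under two
    # different entities would trip the assertion above.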


class UnifoldMultimerDataset(UnifoldDataset):

    def __init__(
        self,
        args: mlc.ConfigDict,
        seed: int,
        config: mlc.ConfigDict,
        data_path: str,
        mode: str = 'train',
        max_step: Optional[int] = None,
        disable_sd: bool = False,
        json_prefix: str = '',
        **kwargs,
    ):
        super().__init__(args, seed, config, data_path, mode, max_step,
                         disable_sd, json_prefix)
        self.data_path = data_path
        self.pdb_assembly = json.load(
            open(
                os.path.join(self.data_path,
                             json_prefix + 'pdb_assembly.json')))
        self.pdb_chains = self.get_chains(self.inverse_multi_label)
        self.monomer_feature_path = os.path.join(self.data_path,
                                                 'pdb_features')
        self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots')
        self.label_path = os.path.join(self.data_path, 'pdb_labels')
        self.max_chains = args.max_chains
        if self.mode == 'train':
            self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains(
                self.pdb_chains, self.pdb_assembly, self.sample_weight,
                self.max_chains)
            self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
                self.sample_weight)

    def __getitem__(self, idx):
        seq_id, label_id, is_distillation = self.sample_chain(idx)
        if is_distillation:
            label_ids = [label_id]
            sequence_ids = [seq_id]
            monomer_feature_path, uniprot_msa_path, label_path = (
                self.sd_feature_path,
                None,
                self.sd_label_path,
            )
            symmetry_operations = None
        else:
            pdb_id = self.get_pdb_name(label_id)
            if pdb_id in self.pdb_assembly and self.mode == 'train':
                label_ids = [
                    pdb_id + '_' + id
                    for id in self.pdb_assembly[pdb_id]['chains']
                ]
                symmetry_operations = [
                    t for t in self.pdb_assembly[pdb_id]['opers']
                ]
            else:
                label_ids = self.pdb_chains[pdb_id]
                symmetry_operations = None
            sequence_ids = [
                self.inverse_multi_label[chain_id] for chain_id in label_ids
            ]
            monomer_feature_path, uniprot_msa_path, label_path = (
                self.monomer_feature_path,
                self.uniprot_msa_path,
                self.label_path,
            )

        return load_and_process(
            self.config,
            self.mode,
            self.seed,
            batch_idx=(idx // self.batch_size),
            data_idx=idx,
            is_distillation=is_distillation,
            sequence_ids=sequence_ids,
            monomer_feature_dir=monomer_feature_path,
            uniprot_msa_dir=uniprot_msa_path,
            label_ids=label_ids,
            label_dir=label_path,
            symmetry_operations=symmetry_operations,
            is_monomer=False,
        )

    @staticmethod
    def collater(samples):
        # the first dim is recycling; the batch size is at the 2nd dim
        if len(samples) <= 0:  # tackle empty batch
            return None
        feats = [s[0] for s in samples]
        labs = [s[1] for s in samples if s[1] is not None]
        try:
            feats = data_utils.collate_dict(feats, dim=1)
        except BaseException:
            raise ValueError('cannot collate features', feats)
        if not labs:
            labs = None
        return feats, labs

    @staticmethod
    def get_pdb_name(chain):
        return chain.split('_')[0]

    @staticmethod
    def get_chains(canon_chain_map):
        pdb_chains = {}
        for chain in canon_chain_map:
            pdb = UnifoldMultimerDataset.get_pdb_name(chain)
            if pdb not in pdb_chains:
                pdb_chains[pdb] = []
            pdb_chains[pdb].append(chain)
        return pdb_chains

    @staticmethod
    def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight,
                                 max_chains):
        new_pdb_chains = {}
        for chain in pdb_chains:
            if chain in pdb_assembly:
                size = len(pdb_assembly[chain]['chains'])
                if size <= max_chains:
                    new_pdb_chains[chain] = pdb_chains[chain]
            else:
                size = len(pdb_chains[chain])
                if size == 1:
                    new_pdb_chains[chain] = pdb_chains[chain]
        new_sample_weight = {
            k: sample_weight[k]
            for k in sample_weight
            if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains
        }
        logger.info(
            f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs '
            f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) '
            f'by max_chains {max_chains}')
        return new_pdb_chains, new_sample_weight
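
    # Illustration of the filter (hypothetical data): with
    #   pdb_assembly = {'1abc': {'chains': ['A', 'B'], 'opers': [...]}}
    # and max_chains = 1, PDB '1abc' (2 assembly chains) is dropped, and every
    # sample-weight entry whose chain prefix is '1abc' is dropped with it;
    # PDBs absent from pdb_assembly are kept only if they have a single chain.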

modelscope/models/science/unifold/model.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import argparse
import os
from typing import Any

import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .config import model_config
from .modules.alphafold import AlphaFold

__all__ = ['UnifoldForProteinStructrue']


@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold)
class UnifoldForProteinStructrue(TorchModel):

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            '--model-name',
            help='choose the model config',
        )

    def __init__(self, **kwargs):
        super().__init__()
        parser = argparse.ArgumentParser()
        parse_comm = []
        for key in kwargs:
            parser.add_argument(f'--{key}')
            parse_comm.append(f'--{key}')
            parse_comm.append(kwargs[key])
        args = parser.parse_args(parse_comm)
        base_architecture(args)
        self.args = args
        config = model_config(
            self.args.model_name,
            train=True,
        )
        self.model = AlphaFold(config)
        self.config = config

        # load model state dict
        param_path = os.path.join(kwargs['model_dir'],
                                  ModelFile.TORCH_MODEL_BIN_FILE)
        state_dict = torch.load(param_path)['ema']['params']
        state_dict = {
            '.'.join(k.split('.')[1:]): v
            for k, v in state_dict.items()
        }
        self.model.load_state_dict(state_dict)

    def half(self):
        self.model = self.model.half()
        return self

    def bfloat16(self):
        self.model = self.model.bfloat16()
        return self

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args)

    def forward(self, batch, **kwargs):
        outputs = self.model.forward(batch)
        return outputs, self.config.loss


def base_architecture(args):
    args.model_name = getattr(args, 'model_name', 'model_2')
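
# The constructor round-trips its **kwargs through argparse so downstream code
# can treat them as a Namespace, with base_architecture filling in the default
# model name. A minimal sketch (hypothetical values):
#
#   model = UnifoldForProteinStructrue(
#       model_name='model_2', model_dir='/path/to/checkpoint_dir')
#   # internally equivalent to parsing:
#   #   --model_name model_2 --model_dir /path/to/checkpoint_dir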

modelscope/models/science/unifold/modules/alphafold.py (new file, 450 lines)
@@ -0,0 +1,450 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import torch
import torch.nn as nn
from unicore.utils import tensor_tree_map

from ..data import residue_constants
from .attentions import gen_msa_attn_mask, gen_tri_attn_mask
from .auxillary_heads import AuxiliaryHeads
from .common import residual
from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder,
                        TemplateAngleEmbedder, TemplatePairEmbedder)
from .evoformer import EvoformerStack, ExtraMSAStack
from .featurization import (atom14_to_atom37, build_extra_msa_feat,
                            build_template_angle_feat,
                            build_template_pair_feat,
                            build_template_pair_feat_v2, pseudo_beta_fn)
from .structure_module import StructureModule
from .template import (TemplatePairStack, TemplatePointwiseAttention,
                       TemplateProjection)


class AlphaFold(nn.Module):

    def __init__(self, config):
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        self.input_embedder = InputEmbedder(
            **config['input_embedder'],
            use_chain_relative=config.is_multimer,
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config['recycling_embedder'], )
        if config.template.enabled:
            self.template_angle_embedder = TemplateAngleEmbedder(
                **template_config['template_angle_embedder'], )
            self.template_pair_embedder = TemplatePairEmbedder(
                **template_config['template_pair_embedder'], )
            self.template_pair_stack = TemplatePairStack(
                **template_config['template_pair_stack'], )
        else:
            self.template_pair_stack = None
        self.enable_template_pointwise_attention = template_config[
            'template_pointwise_attention'].enabled
        if self.enable_template_pointwise_attention:
            self.template_pointwise_att = TemplatePointwiseAttention(
                **template_config['template_pointwise_attention'], )
        else:
            self.template_proj = TemplateProjection(
                **template_config['template_pointwise_attention'], )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config['extra_msa_embedder'], )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config['extra_msa_stack'], )
        self.evoformer = EvoformerStack(**config['evoformer_stack'], )
        self.structure_module = StructureModule(**config['structure_module'], )

        self.aux_heads = AuxiliaryHeads(config['heads'], )

        self.config = config
        self.dtype = torch.float
        self.inf = self.globals.inf
        if self.globals.alphafold_original_mode:
            self.alphafold_original_mode()

    def __make_input_float__(self):
        self.input_embedder = self.input_embedder.float()
        self.recycling_embedder = self.recycling_embedder.float()

    def half(self):
        super().half()
        if (not getattr(self, 'inference', False)):
            self.__make_input_float__()
        self.dtype = torch.half
        return self

    def bfloat16(self):
        super().bfloat16()
        if (not getattr(self, 'inference', False)):
            self.__make_input_float__()
        self.dtype = torch.bfloat16
        return self

    def alphafold_original_mode(self):

        def set_alphafold_original_mode(module):
            if hasattr(module, 'apply_alphafold_original_mode'):
                module.apply_alphafold_original_mode()
            if hasattr(module, 'act'):
                module.act = nn.ReLU()

        self.apply(set_alphafold_original_mode)

    def inference_mode(self):

        def set_inference_mode(module):
            setattr(module, 'inference', True)

        self.apply(set_inference_mode)

    def __convert_input_dtype__(self, batch):
        for key in batch:
            # only convert features with mask
            if batch[key].dtype != self.dtype and 'mask' in key:
                batch[key] = batch[key].type(self.dtype)
        return batch

    def embed_templates_pair_core(self, batch, z, pair_mask,
                                  tri_start_attn_mask, tri_end_attn_mask,
                                  templ_dim, multichain_mask_2d):
        if self.config.template.template_pair_embedder.v2_feature:
            t = build_template_pair_feat_v2(
                batch,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                multichain_mask_2d=multichain_mask_2d,
                **self.config.template.distogram,
            )
            num_template = t[0].shape[-4]
            single_templates = [
                self.template_pair_embedder([x[..., ti, :, :, :]
                                             for x in t], z)
                for ti in range(num_template)
            ]
        else:
            t = build_template_pair_feat(
                batch,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            )
            single_templates = [
                self.template_pair_embedder(x, z)
                for x in torch.unbind(t, dim=templ_dim)
            ]

        t = self.template_pair_stack(
            single_templates,
            pair_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            templ_dim=templ_dim,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
            return_mean=not self.enable_template_pointwise_attention,
        )
        return t

    def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask,
                             tri_end_attn_mask, templ_dim):
        if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch:
            multichain_mask_2d = (
                batch['asym_id'][..., :, None] == batch['asym_id'][...,
                                                                   None, :])
            multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
        else:
            multichain_mask_2d = None

        if self.training or self.enable_template_pointwise_attention:
            t = self.embed_templates_pair_core(batch, z, pair_mask,
                                               tri_start_attn_mask,
                                               tri_end_attn_mask, templ_dim,
                                               multichain_mask_2d)
            if self.enable_template_pointwise_attention:
                t = self.template_pointwise_att(
                    t,
                    z,
                    template_mask=batch['template_mask'],
                    chunk_size=self.globals.chunk_size,
                )
                t_mask = torch.sum(
                    batch['template_mask'], dim=-1, keepdims=True) > 0
                t_mask = t_mask[..., None, None].type(t.dtype)
                t *= t_mask
            else:
                t = self.template_proj(t, z)
        else:
            template_aatype_shape = batch['template_aatype'].shape
            # template_aatype is either [n_template, n_res] or [1, n_template, n_res]
            batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
            n_templ = batch['template_aatype'].shape[batch_templ_dim]

            if n_templ <= 0:
                t = None
            else:
                template_batch = {
                    k: v
                    for k, v in batch.items() if k.startswith('template_')
                }

                def embed_one_template(i):

                    def slice_template_tensor(t):
                        s = [slice(None) for _ in t.shape]
                        s[batch_templ_dim] = slice(i, i + 1)
                        return t[s]

                    template_feats = tensor_tree_map(
                        slice_template_tensor,
                        template_batch,
                    )
                    t = self.embed_templates_pair_core(
                        template_feats, z, pair_mask, tri_start_attn_mask,
                        tri_end_attn_mask, templ_dim, multichain_mask_2d)
                    return t

                t = embed_one_template(0)
                # iterate templates one by one
                for i in range(1, n_templ):
                    t += embed_one_template(i)
                t /= n_templ
                t = self.template_proj(t, z)
        return t
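
    # Note: in the inference branch above (no pointwise attention), templates
    # are embedded one at a time, summed, and divided by n_templ, rather than
    # stacked and jointly attended over, so peak memory is bounded by a single
    # template's pair activations.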

    def embed_templates_angle(self, batch):
        template_angle_feat, template_angle_mask = build_template_angle_feat(
            batch,
            v2_feature=self.config.template.template_pair_embedder.v2_feature)
        t = self.template_angle_embedder(template_angle_feat)
        return t, template_angle_mask

    def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
        batch_dims = feats['target_feat'].shape[:-2]
        n = feats['target_feat'].shape[-2]
        seq_mask = feats['seq_mask']
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats['msa_mask']

        m, z = self.input_embedder(
            feats['target_feat'],
            feats['msa_feat'],
        )

        if m_1_prev is None:
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.d_msa),
                requires_grad=False,
            )
        if z_prev is None:
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.d_pair),
                requires_grad=False,
            )
        if x_prev is None:
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )
        x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None)

        z += self.recycling_embedder.recyle_pos(x_prev)

        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
        )

        m[..., 0, :, :] += m_1_prev_emb

        z += z_prev_emb

        z += self.input_embedder.relpos_emb(
            feats['residue_index'].long(),
            feats.get('sym_id', None),
            feats.get('asym_id', None),
            feats.get('entity_id', None),
            feats.get('num_sym', None),
        )

        m = m.type(self.dtype)
        z = z.type(self.dtype)
        tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(
            pair_mask, self.inf)

        if self.config.template.enabled:
            template_mask = feats['template_mask']
            if torch.any(template_mask):
                z = residual(
                    z,
                    self.embed_templates_pair(
                        feats,
                        z,
                        pair_mask,
                        tri_start_attn_mask,
                        tri_end_attn_mask,
                        templ_dim=-4,
                    ),
                    self.training,
                )

        if self.config.extra_msa.enabled:
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
            extra_msa_row_mask = gen_msa_attn_mask(
                feats['extra_msa_mask'],
                inf=self.inf,
                gen_col_mask=False,
            )
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats['extra_msa_mask'],
                chunk_size=self.globals.chunk_size,
                block_size=self.globals.block_size,
                pair_mask=pair_mask,
                msa_row_attn_mask=extra_msa_row_mask,
                msa_col_attn_mask=None,
                tri_start_attn_mask=tri_start_attn_mask,
                tri_end_attn_mask=tri_end_attn_mask,
            )

        if self.config.template.embed_angles:
            template_1d_feat, template_1d_mask = self.embed_templates_angle(
                feats)
            m = torch.cat([m, template_1d_feat], dim=-3)
            msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2)

        msa_row_mask, msa_col_mask = gen_msa_attn_mask(
            msa_mask,
            inf=self.inf,
        )

        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            msa_row_attn_mask=msa_row_mask,
            msa_col_attn_mask=msa_col_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
        )
        return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb

    def iteration_evoformer_structure_module(self,
                                             batch,
                                             m_1_prev,
                                             z_prev,
                                             x_prev,
                                             cycle_no,
                                             num_recycling,
                                             num_ensembles=1):
        z, s = 0, 0
        n_seq = batch['msa_feat'].shape[-3]
        assert num_ensembles >= 1
        for ensemble_no in range(num_ensembles):
            idx = cycle_no * num_ensembles + ensemble_no

            # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
            def fetch_cur_batch(t):
                return t[min(t.shape[0] - 1, idx), ...]

            feats = tensor_tree_map(fetch_cur_batch, batch)
            m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
                feats, m_1_prev, z_prev, x_prev)
            z += z0
            s += s0
            del z0, s0
        if num_ensembles > 1:
            z /= float(num_ensembles)
            s /= float(num_ensembles)

        outputs = {}

        outputs['msa'] = m[..., :n_seq, :, :]
        outputs['pair'] = z
        outputs['single'] = s

        # norm loss
        if (not getattr(self, 'inference',
                        False)) and num_recycling == (cycle_no + 1):
            delta_msa = m
            delta_msa[...,
                      0, :, :] = delta_msa[...,
                                           0, :, :] - m_1_prev_emb.detach()
            delta_pair = z - z_prev_emb.detach()
            outputs['delta_msa'] = delta_msa
            outputs['delta_pair'] = delta_pair
            outputs['msa_norm_mask'] = msa_mask

        outputs['sm'] = self.structure_module(
            s,
            z,
            feats['aatype'],
            mask=feats['seq_mask'],
        )
        outputs['final_atom_positions'] = atom14_to_atom37(
            outputs['sm']['positions'], feats)
        outputs['final_atom_mask'] = feats['atom37_atom_exists']
        outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1]

        # use float32 for numerical stability
        if (not getattr(self, 'inference', False)):
            m_1_prev = m[..., 0, :, :].float()
            z_prev = z.float()
            x_prev = outputs['final_atom_positions'].float()
        else:
            m_1_prev = m[..., 0, :, :]
            z_prev = z
            x_prev = outputs['final_atom_positions']

        return outputs, m_1_prev, z_prev, x_prev

    def forward(self, batch):

        m_1_prev = batch.get('m_1_prev', None)
        z_prev = batch.get('z_prev', None)
        x_prev = batch.get('x_prev', None)

        is_grad_enabled = torch.is_grad_enabled()

        num_iters = int(batch['num_recycling_iters']) + 1
        num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters
        if self.training:
            # don't use ensemble during training
            assert num_ensembles == 1

        # convert dtypes in batch
        batch = self.__convert_input_dtype__(batch)
        for cycle_no in range(num_iters):
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                (
                    outputs,
                    m_1_prev,
                    z_prev,
                    x_prev,
                ) = self.iteration_evoformer_structure_module(
                    batch,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    cycle_no=cycle_no,
                    num_recycling=num_iters,
                    num_ensembles=num_ensembles,
                )
            if not is_final_iter:
                del outputs

        if 'asym_id' in batch:
            outputs['asym_id'] = batch['asym_id'][0, ...]
        outputs.update(self.aux_heads(outputs))
        return outputs
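
# Recycling sketch: with batch['num_recycling_iters'] == 3 the loop above runs
# 4 passes; torch.set_grad_enabled(is_grad_enabled and is_final_iter) builds a
# graph only on the last pass, so training memory does not grow with the
# number of recycling iterations.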

modelscope/models/science/unifold/modules/attentions.py (new file, 430 lines)
@@ -0,0 +1,430 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partialmethod
from typing import List, Optional

import torch
import torch.nn as nn
from unicore.modules import LayerNorm, softmax_dropout
from unicore.utils import permute_final_dims

from .common import Linear, chunk_layer


def gen_attn_mask(mask, neg_inf):
    assert neg_inf < -1e4
    attn_mask = torch.zeros_like(mask)
    attn_mask[mask == 0] = neg_inf
    return attn_mask
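
# For example, a key mask [1., 1., 0.] becomes the additive mask
# [0., 0., neg_inf]; added to the attention logits before the softmax, it
# drives the masked position's attention weight to zero.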


class Attention(nn.Module):

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot')
        self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot')
        self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot')
        self.linear_o = Linear(total_dim, q_dim, init='final')
        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(q_dim, total_dim, init='gating')
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        g = None
        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)

        q = self.linear_q(q)
        q *= self.norm
        k = self.linear_k(k)
        v = self.linear_v(v)

        q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(
            -2, -3).contiguous()
        k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(
            -2, -3).contiguous()
        v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)

        attn = torch.matmul(q, k.transpose(-1, -2))
        del q, k

        attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
        o = torch.matmul(attn, v)
        del attn, v

        o = o.transpose(-2, -3).contiguous()
        o = o.view(*o.shape[:-2], -1)

        if g is not None:
            o = torch.sigmoid(g) * o

        # merge heads
        o = nn.functional.linear(o, self.linear_o.weight)
        return o

    def get_output_bias(self):
        return self.linear_o.bias
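
    # Design note: the forward pass applies only the output projection's
    # weight; the bias exposed by get_output_bias() is added later inside the
    # fused bias-dropout-residual helpers (see common.py), saving one
    # element-wise pass over the activations.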


class GlobalAttention(nn.Module):

    def __init__(self, input_dim, head_dim, num_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.num_heads = num_heads
        self.inf = inf
        self.eps = eps
        self.linear_q = Linear(
            input_dim, head_dim * num_heads, bias=False, init='glorot')
        self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot')
        self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot')
        self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating')
        self.linear_o = Linear(head_dim * num_heads, input_dim, init='final')
        self.sigmoid = nn.Sigmoid()
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

        # gating
        g = self.sigmoid(self.linear_g(x))

        k = self.linear_k(x)
        v = self.linear_v(x)

        q = torch.sum(
            x * mask.unsqueeze(-1), dim=-2) / (
                torch.sum(mask, dim=-1, keepdims=True) + self.eps)
        q = self.linear_q(q)
        q *= self.norm
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        attn = torch.matmul(q, k.transpose(-1, -2))
        del q, k

        attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
        attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)

        o = torch.matmul(
            attn,
            v,
        )
        del attn, v

        g = g.view(g.shape[:-1] + (self.num_heads, -1))
        o = o.unsqueeze(-3) * g
        del g

        # merge heads
        o = o.reshape(o.shape[:-2] + (-1, ))
        return self.linear_o(o)


def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
    row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
    if gen_col_mask:
        col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
                                                               None, :]
        return row_mask, col_mask
    else:
        return row_mask


class MSAAttention(nn.Module):

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        pair_bias=False,
        d_pair=None,
    ):
        super(MSAAttention, self).__init__()

        self.pair_bias = pair_bias
        self.layer_norm_m = LayerNorm(d_in)
        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(d_pair)
            self.linear_z = Linear(
                d_pair, num_heads, bias=False, init='normal')

        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: int = None,
    ) -> torch.Tensor:

        return chunk_layer(
            self._attn_forward,
            {
                'm': m,
                'mask': mask,
                'bias': bias
            },
            chunk_size=chunk_size,
            num_batch_dims=len(m.shape[:-2]),
        )

    @torch.jit.ignore
    def _attn_chunk_forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = 2560,
    ) -> torch.Tensor:
        m = self.layer_norm_m(m)
        num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size
        outputs = []
        for i in range(num_chunk):
            chunk_start = i * chunk_size
            chunk_end = min(m.shape[-3], chunk_start + chunk_size)
            cur_m = m[..., chunk_start:chunk_end, :, :]
            cur_mask = (
                mask[..., chunk_start:chunk_end, :, :, :]
                if mask is not None else None)
            outputs.append(
                self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
        return torch.concat(outputs, dim=-3)

    def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
        m = self.layer_norm_m(m)
        return self.mha(q=m, k=m, v=m, mask=mask, bias=bias)

    def forward(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:

        bias = None
        if self.pair_bias:
            z = self.layer_norm_z(z)
            bias = (
                permute_final_dims(self.linear_z(z),
                                   (2, 0, 1)).unsqueeze(-4).contiguous())

        if chunk_size is not None:
            m = self._chunk(m, attn_mask, bias, chunk_size)
        else:
            attn_chunk_size = 2560
            if m.shape[-3] <= attn_chunk_size:
                m = self._attn_forward(m, attn_mask, bias)
            else:
                # reduce the peak memory cost in extra_msa_stack
                return self._attn_chunk_forward(
                    m, attn_mask, bias, chunk_size=attn_chunk_size)

        return m

    def get_output_bias(self):
        return self.mha.get_output_bias()


class MSARowAttentionWithPairBias(MSAAttention):

    def __init__(self, d_msa, d_pair, d_hid, num_heads):
        super(MSARowAttentionWithPairBias, self).__init__(
            d_msa,
            d_hid,
            num_heads,
            pair_bias=True,
            d_pair=d_pair,
        )


class MSAColumnAttention(MSAAttention):

    def __init__(self, d_msa, d_hid, num_heads):
        super(MSAColumnAttention, self).__init__(
            d_in=d_msa,
            d_hid=d_hid,
            num_heads=num_heads,
            pair_bias=False,
            d_pair=None,
        )

    def forward(
        self,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        m = m.transpose(-2, -3)
        m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size)
        m = m.transpose(-2, -3)

        return m


class MSAColumnGlobalAttention(nn.Module):

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        inf=1e9,
        eps=1e-10,
    ):
        super(MSAColumnGlobalAttention, self).__init__()

        self.layer_norm_m = LayerNorm(d_in)
        self.global_attention = GlobalAttention(
            d_in,
            d_hid,
            num_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._attn_forward,
            {
                'm': m,
                'mask': mask
            },
            chunk_size=chunk_size,
            num_batch_dims=len(m.shape[:-2]),
        )

    def _attn_forward(self, m, mask):
        m = self.layer_norm_m(m)
        return self.global_attention(m, mask=mask)

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:

        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._attn_forward(m, mask=mask)

        m = m.transpose(-2, -3)
        return m


def gen_tri_attn_mask(mask, inf):
    start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
    end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
                                                           None, :]
    return start_mask, end_mask


class TriangleAttention(nn.Module):

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        starting,
    ):
        super(TriangleAttention, self).__init__()
        self.starting = starting
        self.layer_norm = LayerNorm(d_in)
        self.linear = Linear(d_in, num_heads, bias=False, init='normal')
        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: int = None,
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
            {
                'q': x,
                'k': x,
                'v': x,
                'mask': mask,
                'bias': bias
            },
            chunk_size=chunk_size,
            num_batch_dims=len(x.shape[:-2]),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        if not self.starting:
            x = x.transpose(-2, -3)

        x = self.layer_norm(x)
        triangle_bias = (
            permute_final_dims(self.linear(x),
                               (2, 0, 1)).unsqueeze(-4).contiguous())

        if chunk_size is not None:
            x = self._chunk(x, attn_mask, triangle_bias, chunk_size)
        else:
            x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)

        if not self.starting:
            x = x.transpose(-2, -3)
        return x

    def get_output_bias(self):
        return self.mha.get_output_bias()


class TriangleAttentionStarting(TriangleAttention):
    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEnding(TriangleAttention):
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)

modelscope/models/science/unifold/modules/auxillary_heads.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict

import torch.nn as nn
from unicore.modules import LayerNorm

from .common import Linear
from .confidence import (predicted_aligned_error, predicted_lddt,
                         predicted_tm_score)


class AuxiliaryHeads(nn.Module):

    def __init__(self, config):
        super(AuxiliaryHeads, self).__init__()

        self.plddt = PredictedLDDTHead(**config['plddt'], )

        self.distogram = DistogramHead(**config['distogram'], )

        self.masked_msa = MaskedMSAHead(**config['masked_msa'], )

        if config.experimentally_resolved.enabled:
            self.experimentally_resolved = ExperimentallyResolvedHead(
                **config['experimentally_resolved'], )

        if config.pae.enabled:
            self.pae = PredictedAlignedErrorHead(**config.pae, )

        self.config = config

    def forward(self, outputs):
        aux_out = {}
        plddt_logits = self.plddt(outputs['sm']['single'])
        aux_out['plddt_logits'] = plddt_logits

        aux_out['plddt'] = predicted_lddt(plddt_logits.detach())

        distogram_logits = self.distogram(outputs['pair'])
        aux_out['distogram_logits'] = distogram_logits

        masked_msa_logits = self.masked_msa(outputs['msa'])
        aux_out['masked_msa_logits'] = masked_msa_logits

        if self.config.experimentally_resolved.enabled:
            exp_res_logits = self.experimentally_resolved(outputs['single'])
            aux_out['experimentally_resolved_logits'] = exp_res_logits

        if self.config.pae.enabled:
            pae_logits = self.pae(outputs['pair'])
            aux_out['pae_logits'] = pae_logits
            pae_logits = pae_logits.detach()
            aux_out.update(
                predicted_aligned_error(
                    pae_logits,
                    **self.config.pae,
                ))
            aux_out['ptm'] = predicted_tm_score(
                pae_logits, interface=False, **self.config.pae)

            iptm_weight = self.config.pae.get('iptm_weight', 0.0)
            if iptm_weight > 0.0:
                aux_out['iptm'] = predicted_tm_score(
                    pae_logits,
                    interface=True,
                    asym_id=outputs['asym_id'],
                    **self.config.pae,
                )
                aux_out['iptm+ptm'] = (
                    iptm_weight * aux_out['iptm'] +  # noqa W504
                    (1.0 - iptm_weight) * aux_out['ptm'])

        return aux_out


class PredictedLDDTHead(nn.Module):

    def __init__(self, num_bins, d_in, d_hid):
        super(PredictedLDDTHead, self).__init__()

        self.num_bins = num_bins
        self.d_in = d_in
        self.d_hid = d_hid

        self.layer_norm = LayerNorm(self.d_in)

        self.linear_1 = Linear(self.d_in, self.d_hid, init='relu')
        self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu')
        self.act = nn.GELU()
        self.linear_3 = Linear(self.d_hid, self.num_bins, init='final')

    def forward(self, s):
        s = self.layer_norm(s)
        s = self.linear_1(s)
        s = self.act(s)
        s = self.linear_2(s)
        s = self.act(s)
        s = self.linear_3(s)
        return s


class EnhancedHeadBase(nn.Module):

    def __init__(self, d_in, d_out, disable_enhance_head):
        super(EnhancedHeadBase, self).__init__()
        if disable_enhance_head:
            self.layer_norm = None
            self.linear_in = None
        else:
            self.layer_norm = LayerNorm(d_in)
            self.linear_in = Linear(d_in, d_in, init='relu')
        self.act = nn.GELU()
        self.linear = Linear(d_in, d_out, init='final')

    def apply_alphafold_original_mode(self):
        self.layer_norm = None
        self.linear_in = None

    def forward(self, x):
        if self.layer_norm is not None:
            x = self.layer_norm(x)
            x = self.act(self.linear_in(x))
        logits = self.linear(x)
        return logits


class DistogramHead(EnhancedHeadBase):

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super(DistogramHead, self).__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )

    def forward(self, x):
        logits = super().forward(x)
        logits = logits + logits.transpose(-2, -3)
        return logits
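
    # Symmetrization note: residue-pair distances are order-invariant, so the
    # head adds the transposed logits to enforce logits(i, j) == logits(j, i).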


class PredictedAlignedErrorHead(EnhancedHeadBase):

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super(PredictedAlignedErrorHead, self).__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )


class MaskedMSAHead(EnhancedHeadBase):

    def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs):
        super(MaskedMSAHead, self).__init__(
            d_in=d_msa,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )


class ExperimentallyResolvedHead(EnhancedHeadBase):

    def __init__(self, d_single, d_out, disable_enhance_head, **kwargs):
        super(ExperimentallyResolvedHead, self).__init__(
            d_in=d_single,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )

modelscope/models/science/unifold/modules/common.py (new file, 387 lines)
@@ -0,0 +1,387 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from unicore.modules import LayerNorm
from unicore.utils import tensor_tree_map


class Linear(nn.Linear):

    def __init__(
        self,
        d_in: int,
        d_out: int,
        bias: bool = True,
        init: str = 'default',
    ):
        super(Linear, self).__init__(d_in, d_out, bias=bias)

        self.use_bias = bias

        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init == 'default':
            self._trunc_normal_init(1.0)
        elif init == 'relu':
            self._trunc_normal_init(2.0)
        elif init == 'glorot':
            self._glorot_uniform_init()
        elif init == 'gating':
            self._zero_init(self.use_bias)
        elif init == 'normal':
            self._normal_init()
        elif init == 'final':
            self._zero_init(False)
        else:
            raise ValueError('Invalid init method.')

    def _trunc_normal_init(self, scale=1.0):
        # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
        TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
        _, fan_in = self.weight.shape
        scale = scale / max(1, fan_in)
        std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std)

    def _glorot_uniform_init(self):
        nn.init.xavier_uniform_(self.weight, gain=1)

    def _zero_init(self, use_bias=True):
        with torch.no_grad():
            self.weight.fill_(0.0)
            if use_bias:
                with torch.no_grad():
                    self.bias.fill_(1.0)

    def _normal_init(self):
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear')


class Transition(nn.Module):

    def __init__(self, d_in, n):

        super(Transition, self).__init__()

        self.d_in = d_in
        self.n = n

        self.layer_norm = LayerNorm(self.d_in)
        self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu')
        self.act = nn.GELU()
        self.linear_2 = Linear(self.n * self.d_in, d_in, init='final')

    def _transition(self, x):
        x = self.layer_norm(x)
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        return x

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {'x': x},
            chunk_size=chunk_size,
            num_batch_dims=len(x.shape[:-2]),
        )

    def forward(
        self,
        x: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:

        if chunk_size is not None:
            x = self._chunk(x, chunk_size)
        else:
            x = self._transition(x=x)

        return x


class OuterProductMean(nn.Module):

    def __init__(self, d_msa, d_pair, d_hid, eps=1e-3):
        super(OuterProductMean, self).__init__()

        self.d_msa = d_msa
        self.d_pair = d_pair
        self.d_hid = d_hid
        self.eps = eps

        self.layer_norm = LayerNorm(d_msa)
        self.linear_1 = Linear(d_msa, d_hid)
        self.linear_2 = Linear(d_msa, d_hid)
        self.linear_out = Linear(d_hid**2, d_pair, init='relu')
        self.act = nn.GELU()
        self.linear_z = Linear(self.d_pair, self.d_pair, init='final')
        self.layer_norm_out = LayerNorm(self.d_pair)

    def _opm(self, a, b):
        outer = torch.einsum('...bac,...dae->...bdce', a, b)
        outer = outer.reshape(outer.shape[:-2] + (-1, ))
        outer = self.linear_out(outer)
        return outer

    @torch.jit.ignore
    def _chunk(self, a: torch.Tensor, b: torch.Tensor,
               chunk_size: int) -> torch.Tensor:
        a = a.reshape((-1, ) + a.shape[-3:])
        b = b.reshape((-1, ) + b.shape[-3:])
        out = []
        # TODO: optimize this
        for a_prime, b_prime in zip(a, b):
            outer = chunk_layer(
                partial(self._opm, b=b_prime),
                {'a': a_prime},
                chunk_size=chunk_size,
                num_batch_dims=1,
            )
            out.append(outer)
        if len(out) == 1:
            outer = out[0].unsqueeze(0)
        else:
            outer = torch.stack(out, dim=0)
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer

    def apply_alphafold_original_mode(self):
        self.linear_z = None
        self.layer_norm_out = None

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:

        m = self.layer_norm(m)
        mask = mask.unsqueeze(-1)
        if self.layer_norm_out is not None:
            # for numerical stability
            mask = mask * (mask.size(-2)**-0.5)
        a = self.linear_1(m)
        b = self.linear_2(m)
        if self.training:
            a = a * mask
            b = b * mask
        else:
            a *= mask
            b *= mask

        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        if chunk_size is not None:
            z = self._chunk(a, b, chunk_size)
        else:
            z = self._opm(a, b)

        norm = torch.einsum('...abc,...adc->...bdc', mask, mask)
        z /= self.eps + norm
        if self.layer_norm_out is not None:
            z = self.act(z)
            z = self.layer_norm_out(z)
            z = self.linear_z(z)
        return z


def residual(residual, x, training):
    if training:
        return x + residual
    else:
        residual += x
        return residual
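
# Design note: at inference `residual += x` accumulates in place to avoid
# allocating a second activation-sized tensor; during training the
# out-of-place `x + residual` is kept so autograd records a fresh node, e.g.
#   z = residual(z, update, training=self.training)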


@torch.jit.script
def fused_bias_dropout_add(
    x: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
    dropmask: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual


@torch.jit.script
def fused_bias_dropout_add_inference(
    x: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
) -> torch.Tensor:
    residual += bias + x
    return residual


def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob,
                          training):
    bias = module.get_output_bias()
    if training:
        shape = list(x.shape)
        shape[dropout_shared_dim] = 1
        with torch.no_grad():
            mask = x.new_ones(shape)
        return fused_bias_dropout_add(x, bias, residual, mask, prob)
    else:
        return fused_bias_dropout_add_inference(x, bias, residual)


@torch.jit.script
def fused_bias_gated_dropout_add(
    x: torch.Tensor,
    bias: torch.Tensor,
    g: torch.Tensor,
    g_bias: torch.Tensor,
    residual: torch.Tensor,
    dropout_mask: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout(
        dropout_mask,
        p=prob,
        training=True,
    ) + residual


def tri_mul_residual(
    module,
    residual,
    outputs,
    dropout_shared_dim,
    prob,
    training,
    block_size,
):
    if training:
        x, g = outputs
        bias, g_bias = module.get_output_bias()
        shape = list(x.shape)
        shape[dropout_shared_dim] = 1
        with torch.no_grad():
            mask = x.new_ones(shape)
        return fused_bias_gated_dropout_add(
            x,
            bias,
            g,
            g_bias,
            residual,
            mask,
            prob,
        )
    elif block_size is None:
        x, g = outputs
        bias, g_bias = module.get_output_bias()
        residual += (torch.sigmoid(g + g_bias) * (x + bias))
        return residual
    else:
        # gating is not used here
        residual += outputs
        return residual


class SimpleModuleList(nn.ModuleList):

    def __repr__(self):
        return str(len(self)) + ' X ...\n' + self[0].__repr__()


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    num_batch_dims: int,
) -> Any:
    # TODO: support inplace add to output
    if not (len(inputs) > 0):
        raise ValueError('Must provide at least one input')

    def _dict_get_shapes(input):
        shapes = []
        if type(input) is torch.Tensor:
            shapes.append(input.shape)
        elif type(input) is dict:
            for v in input.values():
                shapes.extend(_dict_get_shapes(v))
        elif isinstance(input, Iterable):
            for v in input:
                shapes.extend(_dict_get_shapes(v))
        else:
            raise ValueError('Not supported')

        return shapes

    inputs = {k: v for k, v in inputs.items() if v is not None}
    initial_dims = [
        shape[:num_batch_dims] for shape in _dict_get_shapes(inputs)
    ]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d
    num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size

    def _flat_inputs(t):
        t = t.view(-1, *t.shape[num_batch_dims:])
        assert (
            t.shape[0] == flat_batch_dim or t.shape[0] == 1
        ), 'batch dimension must be 1 or equal to the flat batch dimension'
        return t

    flat_inputs = tensor_tree_map(_flat_inputs, inputs)

    out = None
    for i in range(num_chunks):
        chunk_start = i * chunk_size
        chunk_end = min((i + 1) * chunk_size, flat_batch_dim)

        def select_chunk(t):
            if t.shape[0] == 1:
                return t[0:1]
            else:
                return t[chunk_start:chunk_end]

        chunks = tensor_tree_map(select_chunk, flat_inputs)

        output_chunk = layer(**chunks)

        if out is None:
            out = tensor_tree_map(
                lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]),
                output_chunk)

        out_type = type(output_chunk)
        if out_type is tuple:
            for x, y in zip(out, output_chunk):
                x[chunk_start:chunk_end] = y
        elif out_type is torch.Tensor:
            out[chunk_start:chunk_end] = output_chunk
        else:
            raise ValueError('Not supported')

    # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    def reshape(t):
        return t.view(orig_batch_dims + t.shape[1:])

    out = tensor_tree_map(reshape, out)

    return out
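
# Minimal usage sketch of chunk_layer (hypothetical layer): the flattened
# batch of 2 * 8 = 16 rows is processed in slices of 4 and re-assembled:
#
#   def double(x):
#       return x * 2
#
#   y = chunk_layer(double, {'x': torch.ones(2, 8, 3)},
#                   chunk_size=4, num_batch_dims=2)
#   assert y.shape == (2, 8, 3)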

modelscope/models/science/unifold/modules/confidence.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict, Optional

import torch


def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
    """Computes per-residue pLDDT from logits.

    Args:
        plddt_logits: [num_res, num_bins] output from the PredictedLDDTHead.

    Returns:
        plddt: [num_res] per-residue pLDDT.
    """
    num_bins = plddt_logits.shape[-1]
    bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width,
        end=1.0,
        step=bin_width,
        device=plddt_logits.device)
    plddt = torch.sum(
        bin_probs
        * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return plddt
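
# Worked example: with num_bins = 50, the bin centers are 0.01, 0.03, ...,
# 0.99, and pLDDT is the expectation of those centers under the softmax
# distribution -- logits putting all mass in the last bin yield ~0.99.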


def compute_bin_values(breaks: torch.Tensor):
    """Gets the bin centers from the bin edges.

    Args:
        breaks: [num_bins - 1] the error bin edges.

    Returns:
        bin_values: [num_bins] the error bin centers.
    """
    step = breaks[1] - breaks[0]
    bin_values = breaks + step / 2
    bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)],
                           dim=0)
    return bin_values
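
# e.g. evenly spaced edges tensor([0., 1., 2.]) yield centers
# tensor([0.5, 1.5, 2.5, 3.5]): each edge is shifted by half a step, and one
# extra center is appended for the open-ended last bin.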


def compute_predicted_aligned_error(
    bin_edges: torch.Tensor,
    bin_probs: torch.Tensor,
) -> torch.Tensor:
    """Calculates expected aligned distance errors for every pair of residues.

    Args:
        bin_edges: [num_bins - 1] the error bin edges.
        bin_probs: [num_res, num_res, num_bins] the predicted probs for each
            error bin, for each pair of residues.

    Returns:
        predicted_aligned_error: [num_res, num_res] the expected aligned
            distance error for each pair of residues.
    """
    bin_values = compute_bin_values(bin_edges)
    return torch.sum(bin_probs * bin_values, dim=-1)


def predicted_aligned_error(
    pae_logits: torch.Tensor,
    max_bin: int = 31,
    num_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        pae_logits: [num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: the upper edge of the last finite error bin.
        num_bins: the number of error bins.

    Returns:
        aligned_error_probs_per_bin: [num_res, num_res, num_bins] the
            predicted aligned error probabilities over bins for each residue
            pair.
        predicted_aligned_error: [num_res, num_res] the expected aligned
            distance error for each pair of residues.
    """
    bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1)
    bin_edges = torch.linspace(
        0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

    predicted_aligned_error = compute_predicted_aligned_error(
        bin_edges=bin_edges,
        bin_probs=bin_probs,
    )

    return {
        'aligned_error_probs_per_bin': bin_probs,
        'predicted_aligned_error': predicted_aligned_error,
    }
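A standalone sketch of the same expectation over error bins, using the default max_bin=31 and num_bins=64 from the signature above (the last bin center extrapolates half a step past the final edge, as in compute_bin_values):

import torch

num_res, num_bins, max_bin = 4, 64, 31
logits = torch.randn(num_res, num_res, num_bins)
probs = torch.softmax(logits, dim=-1)
edges = torch.linspace(0, max_bin, steps=num_bins - 1)
step = edges[1] - edges[0]
centers = torch.cat([edges + step / 2, (edges[-1] + 1.5 * step).unsqueeze(0)])
pae = (probs * centers).sum(dim=-1)  # [num_res, num_res]
assert pae.min() >= 0 and pae.max() <= centers[-1]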


def predicted_tm_score(
    pae_logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    num_bins: int = 64,
    eps: float = 1e-8,
    asym_id: Optional[torch.Tensor] = None,
    interface: bool = False,
    **kwargs,
) -> torch.Tensor:
    """Computes the predicted TM (pTM) or predicted interface TM (ipTM) score.

    Args:
        pae_logits: [num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        residue_weights: [num_res] the per residue weights to use for the
            expectation.
        max_bin: the upper edge of the last finite error bin.
        num_bins: the number of error bins.
        asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed
            for ipTM calculation, i.e. when interface=True.
        interface: If True, the interface predicted TM score is computed.

    Returns:
        ptm_score: The predicted TM score, or the predicted ipTM score.
    """
    pae_logits = pae_logits.float()
    if residue_weights is None:
        residue_weights = pae_logits.new_ones(pae_logits.shape[:-2])

    breaks = torch.linspace(
        0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

    def tm_kernel(nres):
        clipped_n = max(nres, 19)
        d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8
        return lambda x: 1.0 / (1.0 + (x / d0)**2)

    def rmsd_kernel(eps):  # kept for a future pRMSD computation
        return lambda x: 1. / (x + eps)

    bin_centers = compute_bin_values(breaks)
    probs = torch.nn.functional.softmax(pae_logits, dim=-1)
    tm_per_bin = tm_kernel(nres=pae_logits.shape[-2])(bin_centers)
    # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    # rmsd_per_bin = rmsd_kernel()(bin_centers)
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape)
    if interface:
        assert asym_id is not None, 'must provide asym_id for iptm calculation.'
        pair_mask *= asym_id[..., :, None] != asym_id[..., None, :]

    predicted_tm_term *= pair_mask

    pair_residue_weights = pair_mask * (
        residue_weights[None, :] * residue_weights[:, None])
    normed_residue_mask = pair_residue_weights / (
        eps + pair_residue_weights.sum(dim=-1, keepdim=True))

    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
    weighted = per_alignment * residue_weights
    ret = per_alignment.gather(
        dim=-1,
        index=weighted.max(dim=-1, keepdim=True).indices).squeeze(dim=-1)
    return ret
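A rough usage sketch, assuming the functions in this file are importable (the tensors below are random placeholders): pTM needs only the PAE logits, while ipTM additionally takes per-residue chain ids and interface=True:

import torch

num_res = 10
pae_logits = torch.randn(num_res, num_res, 64)
asym_id = torch.tensor([0] * 5 + [1] * 5)  # two chains of 5 residues

ptm = predicted_tm_score(pae_logits)
iptm = predicted_tm_score(pae_logits, asym_id=asym_id, interface=True)
print(float(ptm), float(iptm))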

modelscope/models/science/unifold/modules/embedders.py (new file, 290 lines)
@@ -0,0 +1,290 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import one_hot

from .common import Linear, SimpleModuleList, residual


class InputEmbedder(nn.Module):

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        d_pair: int,
        d_msa: int,
        relpos_k: int,
        use_chain_relative: bool = False,
        max_relative_chain: Optional[int] = None,
        **kwargs,
    ):
        super(InputEmbedder, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.d_pair = d_pair
        self.d_msa = d_msa

        self.linear_tf_z_i = Linear(tf_dim, d_pair)
        self.linear_tf_z_j = Linear(tf_dim, d_pair)
        self.linear_tf_m = Linear(tf_dim, d_msa)
        self.linear_msa_m = Linear(msa_dim, d_msa)

        # relative positional encoding (RPE)
        self.relpos_k = relpos_k
        self.use_chain_relative = use_chain_relative
        self.max_relative_chain = max_relative_chain
        if not self.use_chain_relative:
            self.num_bins = 2 * self.relpos_k + 1
        else:
            self.num_bins = 2 * self.relpos_k + 2
            self.num_bins += 1  # entity id
            self.num_bins += 2 * max_relative_chain + 2

        self.linear_relpos = Linear(self.num_bins, d_pair)

    def _relpos_indices(
        self,
        res_id: torch.Tensor,
        sym_id: Optional[torch.Tensor] = None,
        asym_id: Optional[torch.Tensor] = None,
        entity_id: Optional[torch.Tensor] = None,
    ):

        max_rel_res = self.relpos_k
        rp = res_id[..., None] - res_id[..., None, :]
        rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res
        if not self.use_chain_relative:
            return rp
        else:
            asym_id_same = asym_id[..., :, None] == asym_id[..., None, :]
            rp[~asym_id_same] = 2 * max_rel_res + 1
            entity_id_same = entity_id[..., :, None] == entity_id[..., None, :]
            rp_entity_id = entity_id_same.type(rp.dtype)[..., None]

            rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :]

            max_rel_chain = self.max_relative_chain

            clipped_rel_chain = torch.clamp(
                rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain)

            clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1
            return rp, rp_entity_id, clipped_rel_chain

    def relpos_emb(
        self,
        res_id: torch.Tensor,
        sym_id: Optional[torch.Tensor] = None,
        asym_id: Optional[torch.Tensor] = None,
        entity_id: Optional[torch.Tensor] = None,
        num_sym: Optional[torch.Tensor] = None,
    ):

        dtype = self.linear_relpos.weight.dtype
        if not self.use_chain_relative:
            rp = self._relpos_indices(res_id=res_id)
            return self.linear_relpos(
                one_hot(rp, num_classes=self.num_bins, dtype=dtype))
        else:
            rp, rp_entity_id, rp_rel_chain = self._relpos_indices(
                res_id=res_id,
                sym_id=sym_id,
                asym_id=asym_id,
                entity_id=entity_id)
            rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype)
            rp_entity_id = rp_entity_id.type(dtype)
            rp_rel_chain = one_hot(
                rp_rel_chain,
                num_classes=(2 * self.max_relative_chain + 2),
                dtype=dtype)
            return self.linear_relpos(
                torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1))

    def forward(
        self,
        tf: torch.Tensor,
        msa: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # [*, N_res, d_pair]
        if self.tf_dim == 21:
            # the multimer model uses a 21-dim target feature
            tf = tf[..., 1:]
        # convert type if necessary
        tf = tf.type(self.linear_tf_z_i.weight.dtype)
        msa = msa.type(self.linear_tf_z_i.weight.dtype)
        n_clust = msa.shape[-3]

        msa_emb = self.linear_msa_m(msa)
        # broadcast target_feat (aatype) into the msa representation
        tf_m = (
            self.linear_tf_m(tf).unsqueeze(-3).expand(
                ((-1, ) * len(tf.shape[:-2]) +  # noqa W504
                 (n_clust, -1, -1))))
        msa_emb += tf_m

        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]

        return msa_emb, pair_emb
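The relative-position logic above clips residue-index offsets to [-relpos_k, relpos_k] and one-hot encodes the result before the linear projection; a standalone sketch of the single-chain case:

import torch

relpos_k = 2
res_id = torch.arange(5)
rp = res_id[:, None] - res_id[None, :]
rp = rp.clip(-relpos_k, relpos_k) + relpos_k  # values in [0, 2 * relpos_k]
rp_onehot = torch.nn.functional.one_hot(rp, num_classes=2 * relpos_k + 1)
print(rp_onehot.shape)  # torch.Size([5, 5, 5])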


class RecyclingEmbedder(nn.Module):

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        min_bin: float,
        max_bin: float,
        num_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        super(RecyclingEmbedder, self).__init__()

        self.d_msa = d_msa
        self.d_pair = d_pair
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.num_bins = num_bins
        self.inf = inf

        self.squared_bins = None

        self.linear = Linear(self.num_bins, self.d_pair)
        self.layer_norm_m = LayerNorm(self.d_msa)
        self.layer_norm_z = LayerNorm(self.d_pair)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        m_update = self.layer_norm_m(m)
        z_update = self.layer_norm_z(z)

        return m_update, z_update

    def recyle_pos(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:

        if self.squared_bins is None:
            bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.num_bins,
                dtype=torch.float if self.training else x.dtype,
                device=x.device,
                requires_grad=False,
            )
            self.squared_bins = bins**2
        upper = torch.cat(
            [self.squared_bins[1:],
             self.squared_bins.new_tensor([self.inf])],
            dim=-1)
        if self.training:
            x = x.float()
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True)
        d = ((d > self.squared_bins) *  # noqa W504
             (d < upper)).type(self.linear.weight.dtype)
        d = self.linear(d)
        return d
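recyle_pos above turns recycled pseudo-beta coordinates into a one-hot distogram over squared distance bins before projecting to the pair dimension; the binning step in isolation (placeholder bin settings):

import torch

min_bin, max_bin, num_bins, inf = 3.25, 20.75, 15, 1e8
x = torch.randn(7, 3) * 10  # 7 pseudo-beta positions
sq_bins = torch.linspace(min_bin, max_bin, num_bins) ** 2
upper = torch.cat([sq_bins[1:], sq_bins.new_tensor([inf])])
d = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1, keepdim=True)
dgram = ((d > sq_bins) * (d < upper)).float()  # [7, 7, num_bins]
assert (dgram.sum(-1) <= 1).all()  # each distance activates at most one bin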


class TemplateAngleEmbedder(nn.Module):

    def __init__(
        self,
        d_in: int,
        d_out: int,
        **kwargs,
    ):
        super(TemplateAngleEmbedder, self).__init__()

        self.d_out = d_out
        self.d_in = d_in

        self.linear_1 = Linear(self.d_in, self.d_out, init='relu')
        self.act = nn.GELU()
        self.linear_2 = Linear(self.d_out, self.d_out, init='relu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x.type(self.linear_1.weight.dtype))
        x = self.act(x)
        x = self.linear_2(x)
        return x


class TemplatePairEmbedder(nn.Module):

    def __init__(
        self,
        d_in: int,
        v2_d_in: list,
        d_out: int,
        d_pair: int,
        v2_feature: bool = False,
        **kwargs,
    ):
        super(TemplatePairEmbedder, self).__init__()

        self.d_out = d_out
        self.v2_feature = v2_feature
        if self.v2_feature:
            self.d_in = v2_d_in
            self.linear = SimpleModuleList()
            for d_in in self.d_in:
                self.linear.append(Linear(d_in, self.d_out, init='relu'))
            self.z_layer_norm = LayerNorm(d_pair)
            self.z_linear = Linear(d_pair, self.d_out, init='relu')
        else:
            self.d_in = d_in
            self.linear = Linear(self.d_in, self.d_out, init='relu')

    def forward(
        self,
        x,
        z,
    ) -> torch.Tensor:
        if not self.v2_feature:
            x = self.linear(x.type(self.linear.weight.dtype))
            return x
        else:
            dtype = self.z_linear.weight.dtype
            t = self.linear[0](x[0].type(dtype))
            for i, s in enumerate(x[1:]):
                t = residual(t, self.linear[i + 1](s.type(dtype)),
                             self.training)
            t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training)
            return t


class ExtraMSAEmbedder(nn.Module):

    def __init__(
        self,
        d_in: int,
        d_out: int,
        **kwargs,
    ):
        super(ExtraMSAEmbedder, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.linear = Linear(self.d_in, self.d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x.type(self.linear.weight.dtype))

modelscope/models/science/unifold/modules/evoformer.py (new file, 362 lines)
@@ -0,0 +1,362 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
from unicore.utils import checkpoint_sequential

from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention,
                         MSARowAttentionWithPairBias, TriangleAttentionEnding,
                         TriangleAttentionStarting)
from .common import (Linear, OuterProductMean, SimpleModuleList, Transition,
                     bias_dropout_residual, residual, tri_mul_residual)
from .triangle_multiplication import (TriangleMultiplicationIncoming,
                                      TriangleMultiplicationOutgoing)


class EvoformerIteration(nn.Module):

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_hid_msa_att: int,
        d_hid_opm: int,
        d_hid_mul: int,
        d_hid_pair_att: int,
        num_heads_msa: int,
        num_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        outer_product_mean_first: bool,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
    ):
        super(EvoformerIteration, self).__init__()

        self._is_extra_msa_stack = _is_extra_msa_stack
        self.outer_product_mean_first = outer_product_mean_first

        self.msa_att_row = MSARowAttentionWithPairBias(
            d_msa=d_msa,
            d_pair=d_pair,
            d_hid=d_hid_msa_att,
            num_heads=num_heads_msa,
        )

        if _is_extra_msa_stack:
            self.msa_att_col = MSAColumnGlobalAttention(
                d_in=d_msa,
                d_hid=d_hid_msa_att,
                num_heads=num_heads_msa,
                inf=inf,
                eps=eps,
            )
        else:
            self.msa_att_col = MSAColumnAttention(
                d_msa,
                d_hid_msa_att,
                num_heads_msa,
            )

        self.msa_transition = Transition(
            d_in=d_msa,
            n=transition_n,
        )

        self.outer_product_mean = OuterProductMean(
            d_msa,
            d_pair,
            d_hid_opm,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            d_pair,
            d_hid_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            d_pair,
            d_hid_mul,
        )

        self.tri_att_start = TriangleAttentionStarting(
            d_pair,
            d_hid_pair_att,
            num_heads_pair,
        )
        self.tri_att_end = TriangleAttentionEnding(
            d_pair,
            d_hid_pair_att,
            num_heads_pair,
        )

        self.pair_transition = Transition(
            d_in=d_pair,
            n=transition_n,
        )

        self.row_dropout_share_dim = -3
        self.col_dropout_share_dim = -2
        self.msa_dropout = msa_dropout
        self.pair_dropout = pair_dropout

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        msa_row_attn_mask: torch.Tensor,
        msa_col_attn_mask: Optional[torch.Tensor],
        tri_start_attn_mask: torch.Tensor,
        tri_end_attn_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        block_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if self.outer_product_mean_first:
            z = residual(
                z,
                self.outer_product_mean(
                    m, mask=msa_mask, chunk_size=chunk_size), self.training)

        m = bias_dropout_residual(
            self.msa_att_row,
            m,
            self.msa_att_row(
                m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size),
            self.row_dropout_share_dim,
            self.msa_dropout,
            self.training,
        )
        if self._is_extra_msa_stack:
            m = residual(
                m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size),
                self.training)
        else:
            m = bias_dropout_residual(
                self.msa_att_col,
                m,
                self.msa_att_col(
                    m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
                self.col_dropout_share_dim,
                self.msa_dropout,
                self.training,
            )
        m = residual(m, self.msa_transition(m, chunk_size=chunk_size),
                     self.training)
        if not self.outer_product_mean_first:
            z = residual(
                z,
                self.outer_product_mean(
                    m, mask=msa_mask, chunk_size=chunk_size), self.training)

        z = tri_mul_residual(
            self.tri_mul_out,
            z,
            self.tri_mul_out(z, mask=pair_mask, block_size=block_size),
            self.row_dropout_share_dim,
            self.pair_dropout,
            self.training,
            block_size=block_size,
        )

        z = tri_mul_residual(
            self.tri_mul_in,
            z,
            self.tri_mul_in(z, mask=pair_mask, block_size=block_size),
            self.row_dropout_share_dim,
            self.pair_dropout,
            self.training,
            block_size=block_size,
        )

        z = bias_dropout_residual(
            self.tri_att_start,
            z,
            self.tri_att_start(
                z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
            self.row_dropout_share_dim,
            self.pair_dropout,
            self.training,
        )

        z = bias_dropout_residual(
            self.tri_att_end,
            z,
            self.tri_att_end(
                z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
            self.col_dropout_share_dim,
            self.pair_dropout,
            self.training,
        )
        z = residual(z, self.pair_transition(z, chunk_size=chunk_size),
                     self.training)
        return m, z


class EvoformerStack(nn.Module):

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_hid_msa_att: int,
        d_hid_opm: int,
        d_hid_mul: int,
        d_hid_pair_att: int,
        d_single: int,
        num_heads_msa: int,
        num_heads_pair: int,
        num_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        outer_product_mean_first: bool,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        **kwargs,
    ):
        super(EvoformerStack, self).__init__()

        self._is_extra_msa_stack = _is_extra_msa_stack

        self.blocks = SimpleModuleList()

        for _ in range(num_blocks):
            self.blocks.append(
                EvoformerIteration(
                    d_msa=d_msa,
                    d_pair=d_pair,
                    d_hid_msa_att=d_hid_msa_att,
                    d_hid_opm=d_hid_opm,
                    d_hid_mul=d_hid_mul,
                    d_hid_pair_att=d_hid_pair_att,
                    num_heads_msa=num_heads_msa,
                    num_heads_pair=num_heads_pair,
                    transition_n=transition_n,
                    msa_dropout=msa_dropout,
                    pair_dropout=pair_dropout,
                    outer_product_mean_first=outer_product_mean_first,
                    inf=inf,
                    eps=eps,
                    _is_extra_msa_stack=_is_extra_msa_stack,
                ))
        if not self._is_extra_msa_stack:
            self.linear = Linear(d_msa, d_single)
        else:
            self.linear = None

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        msa_row_attn_mask: torch.Tensor,
        msa_col_attn_mask: torch.Tensor,
        tri_start_attn_mask: torch.Tensor,
        tri_end_attn_mask: torch.Tensor,
        chunk_size: int,
        block_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                msa_row_attn_mask=msa_row_attn_mask,
                msa_col_attn_mask=msa_col_attn_mask,
                tri_start_attn_mask=tri_start_attn_mask,
                tri_end_attn_mask=tri_end_attn_mask,
                chunk_size=chunk_size,
                block_size=block_size) for b in self.blocks
        ]

        m, z = checkpoint_sequential(
            blocks,
            input=(m, z),
        )

        s = None
        if not self._is_extra_msa_stack:
            seq_dim = -3
            index = torch.tensor([0], device=m.device)
            s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
            s = s.squeeze(seq_dim)

        return m, z, s
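checkpoint_sequential above comes from unicore; a rough functional equivalent of the pattern using torch.utils.checkpoint (my approximation, not unicore's implementation): each block's activations are recomputed during the backward pass instead of being stored:

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks(blocks, m, z):
    # blocks is a list of callables taking and returning (m, z), e.g. the
    # functools.partial-wrapped EvoformerIteration blocks built above.
    for block in blocks:
        if torch.is_grad_enabled():
            m, z = checkpoint(block, m, z)
        else:
            m, z = block(m, z)
    return m, z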


class ExtraMSAStack(EvoformerStack):

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_hid_msa_att: int,
        d_hid_opm: int,
        d_hid_mul: int,
        d_hid_pair_att: int,
        num_heads_msa: int,
        num_heads_pair: int,
        num_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        outer_product_mean_first: bool,
        inf: float,
        eps: float,
        **kwargs,
    ):
        super(ExtraMSAStack, self).__init__(
            d_msa=d_msa,
            d_pair=d_pair,
            d_hid_msa_att=d_hid_msa_att,
            d_hid_opm=d_hid_opm,
            d_hid_mul=d_hid_mul,
            d_hid_pair_att=d_hid_pair_att,
            d_single=None,
            num_heads_msa=num_heads_msa,
            num_heads_pair=num_heads_pair,
            num_blocks=num_blocks,
            transition_n=transition_n,
            msa_dropout=msa_dropout,
            pair_dropout=pair_dropout,
            outer_product_mean_first=outer_product_mean_first,
            inf=inf,
            eps=eps,
            _is_extra_msa_stack=True,
        )

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        msa_row_attn_mask: Optional[torch.Tensor] = None,
        msa_col_attn_mask: Optional[torch.Tensor] = None,
        tri_start_attn_mask: Optional[torch.Tensor] = None,
        tri_end_attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        block_size: Optional[int] = None,
    ) -> torch.Tensor:
        _, z, _ = super().forward(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            msa_row_attn_mask=msa_row_attn_mask,
            msa_col_attn_mask=msa_col_attn_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            chunk_size=chunk_size,
            block_size=block_size)
        return z

modelscope/models/science/unifold/modules/featurization.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict

import torch
import torch.nn as nn
from unicore.utils import batched_gather, one_hot

from modelscope.models.science.unifold.data import residue_constants as rc
from .frame import Frame


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    is_gly = aatype == rc.restype_order['G']
    ca_idx = rc.atom_order['CA']
    cb_idx = rc.atom_order['CB']
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
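pseudo_beta_fn selects the CB coordinate per residue, except for glycine (which has no CB), where CA is used instead; a tiny standalone check of the selection:

import torch

is_gly = torch.tensor([True, False])
ca = torch.tensor([[0., 0., 0.], [1., 1., 1.]])
cb = torch.tensor([[9., 9., 9.], [2., 2., 2.]])
pseudo_beta = torch.where(is_gly[:, None].expand(-1, 3), ca, cb)
assert pseudo_beta.tolist() == [[0., 0., 0.], [2., 2., 2.]]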


def atom14_to_atom37(atom14, batch):
    atom37_data = batched_gather(
        atom14,
        batch['residx_atom37_to_atom14'],
        dim=-2,
        num_batch_dims=len(atom14.shape[:-2]),
    )

    atom37_data = atom37_data * batch['atom37_atom_exists'][..., None]

    return atom37_data


def build_template_angle_feat(template_feats, v2_feature=False):
    template_aatype = template_feats['template_aatype']
    torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos']
    torsion_angles_mask = template_feats['template_torsion_angles_mask']
    if not v2_feature:
        alt_torsion_angles_sin_cos = template_feats[
            'template_alt_torsion_angles_sin_cos']
        template_angle_feat = torch.cat(
            [
                one_hot(template_aatype, 22),
                torsion_angles_sin_cos.reshape(
                    *torsion_angles_sin_cos.shape[:-2], 14),
                alt_torsion_angles_sin_cos.reshape(
                    *alt_torsion_angles_sin_cos.shape[:-2], 14),
                torsion_angles_mask,
            ],
            dim=-1,
        )
        template_angle_mask = torsion_angles_mask[..., 2]
    else:
        chi_mask = torsion_angles_mask[..., 3:]
        chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask
        chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask
        template_angle_feat = torch.cat(
            [
                one_hot(template_aatype, 22),
                chi_angles_sin,
                chi_angles_cos,
                chi_mask,
            ],
            dim=-1,
        )
        template_angle_mask = chi_mask[..., 0]
    return template_angle_feat, template_angle_mask


def build_template_pair_feat(
    batch,
    min_bin,
    max_bin,
    num_bins,
    eps=1e-20,
    inf=1e8,
):
    template_mask = batch['template_pseudo_beta_mask']
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    tpb = batch['template_pseudo_beta']
    dgram = torch.sum(
        (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot = nn.functional.one_hot(
        batch['template_aatype'],
        rc.restype_num + 2,
    )

    n_res = batch['template_aatype'].shape[-1]
    to_concat.append(aatype_one_hot[..., None, :, :].expand(
        *aatype_one_hot.shape[:-2], n_res, -1, -1))
    to_concat.append(aatype_one_hot[..., None, :].expand(
        *aatype_one_hot.shape[:-2], -1, n_res, -1))

    to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3))
    to_concat.append(template_mask_2d[..., None])

    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]

    return act
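The v1 pair feature above is a channel-wise concatenation; a sketch of the resulting channel budget, assuming 39 distance bins and 22 residue types (20 amino acids plus unknown and gap):

import torch

num_res, num_bins, n_restype = 6, 39, 22
dgram = torch.zeros(num_res, num_res, num_bins)
mask_2d = torch.ones(num_res, num_res, 1)
aatype_rows = torch.zeros(num_res, num_res, n_restype)
aatype_cols = torch.zeros(num_res, num_res, n_restype)
unit_vec = torch.zeros(num_res, num_res, 3)  # zeroed in the v1 path
act = torch.cat([dgram, mask_2d, aatype_rows, aatype_cols, unit_vec, mask_2d],
                dim=-1)
print(act.shape[-1])  # 39 + 1 + 22 + 22 + 3 + 1 = 88 channels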


def build_template_pair_feat_v2(
    batch,
    min_bin,
    max_bin,
    num_bins,
    multichain_mask_2d=None,
    eps=1e-20,
    inf=1e8,
):
    template_mask = batch['template_pseudo_beta_mask']
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
    if multichain_mask_2d is not None:
        template_mask_2d *= multichain_mask_2d

    tpb = batch['template_pseudo_beta']
    dgram = torch.sum(
        (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
    dgram *= template_mask_2d[..., None]
    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot = one_hot(
        batch['template_aatype'],
        rc.restype_num + 2,
    )

    n_res = batch['template_aatype'].shape[-1]
    to_concat.append(aatype_one_hot[..., None, :, :].expand(
        *aatype_one_hot.shape[:-2], n_res, -1, -1))
    to_concat.append(aatype_one_hot[..., None, :].expand(
        *aatype_one_hot.shape[:-2], -1, n_res, -1))

    n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
    rigids = Frame.make_transform_from_reference(
        n_xyz=batch['template_all_atom_positions'][..., n, :],
        ca_xyz=batch['template_all_atom_positions'][..., ca, :],
        c_xyz=batch['template_all_atom_positions'][..., c, :],
        eps=eps,
    )
    points = rigids.get_trans()[..., None, :, :]
    rigid_vec = rigids[..., None].invert_apply(points)

    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

    t_aa_masks = batch['template_all_atom_mask']
    backbone_mask = (
        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c])
    backbone_mask_2d = (
        backbone_mask[..., :, None] * backbone_mask[..., None, :])
    if multichain_mask_2d is not None:
        backbone_mask_2d *= multichain_mask_2d

    inv_distance_scalar = inv_distance_scalar * backbone_mask_2d
    unit_vector_data = rigid_vec * inv_distance_scalar[..., None]
    to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1))
    to_concat.append(backbone_mask_2d[..., None])

    return to_concat


def build_extra_msa_feat(batch):
    msa_1hot = one_hot(batch['extra_msa'], 23)
    msa_feat = [
        msa_1hot,
        batch['extra_msa_has_deletion'].unsqueeze(-1),
        batch['extra_msa_deletion_value'].unsqueeze(-1),
    ]
    return torch.cat(msa_feat, dim=-1)

modelscope/models/science/unifold/modules/frame.py (new file, 562 lines)
@@ -0,0 +1,562 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from __future__ import annotations  # noqa
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple

import numpy as np
import torch


def zero_translation(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = torch.float,
    device: Optional[torch.device] = torch.device('cpu'),
    requires_grad: bool = False,
) -> torch.Tensor:
    trans = torch.zeros((*batch_dims, 3),
                        dtype=dtype,
                        device=device,
                        requires_grad=requires_grad)
    return trans


# pylint: disable=bad-whitespace
_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)

_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # rr
_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]  # ii
_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]]  # jj
_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]  # kk

_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]]  # ij
_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]]  # ik
_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]]  # jk

_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]]  # ir
_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]]  # jr
_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]]  # kr

_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9)
_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT)

_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0],
                           [0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1],
                           [0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0],
                           [0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0],
                           [1, 0, 0, 0]]

_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC)


class Rotation:

    def __init__(
        self,
        mat: torch.Tensor,
    ):
        if mat.shape[-2:] != (3, 3):
            raise ValueError(f'incorrect rotation shape: {mat.shape}')
        self._mat = mat

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = torch.float,
        device: Optional[torch.device] = torch.device('cpu'),
        requires_grad: bool = False,
    ) -> Rotation:
        mat = torch.eye(
            3, dtype=dtype, device=device, requires_grad=requires_grad)
        mat = mat.view(*((1, ) * len(shape)), 3, 3)
        mat = mat.expand(*shape, -1, -1)
        return Rotation(mat)

    @staticmethod
    def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return (a.float() @ b.float()).type(a.dtype)

    @staticmethod
    def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype)

    def __getitem__(self, index: Any) -> Rotation:
        if not isinstance(index, tuple):
            index = (index, )
        return Rotation(mat=self._mat[index + (slice(None), slice(None))])

    def __mul__(self, right: Any) -> Rotation:
        if isinstance(right, (int, float)):
            return Rotation(mat=self._mat * right)
        elif isinstance(right, torch.Tensor):
            return Rotation(mat=self._mat * right[..., None, None])
        else:
            raise TypeError(
                f'multiplicand must be a tensor or a number, got {type(right)}.'
            )

    def __rmul__(self, left: Any) -> Rotation:
        return self.__mul__(left)

    def __matmul__(self, other: Rotation) -> Rotation:
        new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat)
        return Rotation(mat=new_mat)

    @property
    def _inv_mat(self):
        return self._mat.transpose(-1, -2)

    @property
    def rot_mat(self) -> torch.Tensor:
        return self._mat

    def invert(self) -> Rotation:
        return Rotation(mat=self._inv_mat)

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        return Rotation.mat_mul_vec(self._mat, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        return Rotation.mat_mul_vec(self._inv_mat, pts)

    # inherit tensor behaviors
    @property
    def shape(self) -> torch.Size:
        s = self._mat.shape[:-2]
        return s

    @property
    def dtype(self) -> torch.dtype:
        return self._mat.dtype

    @property
    def device(self) -> torch.device:
        return self._mat.device

    @property
    def requires_grad(self) -> bool:
        return self._mat.requires_grad

    def unsqueeze(self, dim: int) -> Rotation:
        if dim >= len(self.shape):
            raise ValueError('Invalid dimension')

        rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2)
        return Rotation(mat=rot_mats)

    def map_tensor_fn(self, fn: Callable[[torch.Tensor],
                                         torch.Tensor]) -> Rotation:
        mat = self._mat.view(self._mat.shape[:-2] + (9, ))
        mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1)
        mat = mat.view(mat.shape[:-1] + (3, 3))
        return Rotation(mat=mat)

    @staticmethod
    def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
        rot_mats = [r.rot_mat for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(mat=rot_mats)

    def cuda(self) -> Rotation:
        return Rotation(mat=self._mat.cuda())

    def to(self, device: Optional[torch.device],
           dtype: Optional[torch.dtype]) -> Rotation:
        return Rotation(mat=self._mat.to(device=device, dtype=dtype))

    def type(self, dtype: Optional[torch.dtype]) -> Rotation:
        return Rotation(mat=self._mat.type(dtype))

    def detach(self) -> Rotation:
        return Rotation(mat=self._mat.detach())


class Frame:

    def __init__(
        self,
        rotation: Optional[Rotation],
        translation: Optional[torch.Tensor],
    ):
        if rotation is None and translation is None:
            rotation = Rotation.identity((0, ))
            translation = zero_translation((0, ))
        elif translation is None:
            translation = zero_translation(rotation.shape, rotation.dtype,
                                           rotation.device,
                                           rotation.requires_grad)

        elif rotation is None:
            rotation = Rotation.identity(
                translation.shape[:-1],
                translation.dtype,
                translation.device,
                translation.requires_grad,
            )

        if (rotation.shape != translation.shape[:-1]) or (
                rotation.device != translation.device):
            raise ValueError('RotationMatrix and translation incompatible')

        self._r = rotation
        self._t = translation

    @staticmethod
    def identity(
        shape: Iterable[int],
        dtype: Optional[torch.dtype] = torch.float,
        device: Optional[torch.device] = torch.device('cpu'),
        requires_grad: bool = False,
    ) -> Frame:
        return Frame(
            Rotation.identity(shape, dtype, device, requires_grad),
            zero_translation(shape, dtype, device, requires_grad),
        )

    def __getitem__(
        self,
        index: Any,
    ) -> Frame:
        if not isinstance(index, tuple):
            index = (index, )

        return Frame(
            self._r[index],
            self._t[index + (slice(None), )],
        )

    def __mul__(
        self,
        right: torch.Tensor,
    ) -> Frame:
        if not isinstance(right, torch.Tensor):
            raise TypeError('The other multiplicand must be a Tensor')

        new_rots = self._r * right
        new_trans = self._t * right[..., None]

        return Frame(new_rots, new_trans)

    def __rmul__(
        self,
        left: torch.Tensor,
    ) -> Frame:
        return self.__mul__(left)

    @property
    def shape(self) -> torch.Size:
        s = self._t.shape[:-1]
        return s

    @property
    def device(self) -> torch.device:
        return self._t.device

    def get_rots(self) -> Rotation:
        return self._r

    def get_trans(self) -> torch.Tensor:
        return self._t

    def compose(
        self,
        other: Frame,
    ) -> Frame:
        new_rot = self._r @ other._r
        new_trans = self._r.apply(other._t) + self._t
        return Frame(new_rot, new_trans)

    def apply(
        self,
        pts: torch.Tensor,
    ) -> torch.Tensor:
        rotated = self._r.apply(pts)
        return rotated + self._t

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        pts = pts - self._t
        return self._r.invert_apply(pts)

    def invert(self) -> Frame:
        rot_inv = self._r.invert()
        trn_inv = rot_inv.apply(self._t)

        return Frame(rot_inv, -1 * trn_inv)

    def map_tensor_fn(self, fn: Callable[[torch.Tensor],
                                         torch.Tensor]) -> Frame:
        new_rots = self._r.map_tensor_fn(fn)
        new_trans = torch.stack(
            list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1)

        return Frame(new_rots, new_trans)

    def to_tensor_4x4(self) -> torch.Tensor:
        tensor = self._t.new_zeros((*self.shape, 4, 4))
        tensor[..., :3, :3] = self._r.rot_mat
        tensor[..., :3, 3] = self._t
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor_4x4(t: torch.Tensor) -> Frame:
        if t.shape[-2:] != (4, 4):
            raise ValueError('Incorrectly shaped input tensor')

        rots = Rotation(mat=t[..., :3, :3])
        trans = t[..., :3, 3]

        return Frame(rots, trans)

    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor,
        origin: torch.Tensor,
        p_xy_plane: torch.Tensor,
        eps: float = 1e-8,
    ) -> Frame:
        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
        origin = torch.unbind(origin, dim=-1)
        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
        e0 = [c / denom for c in e0]
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
        e1 = [c / denom for c in e1]
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        rot_obj = Rotation(mat=rots)

        return Frame(rot_obj, torch.stack(origin, dim=-1))

    def unsqueeze(
        self,
        dim: int,
    ) -> Frame:
        if dim >= len(self.shape):
            raise ValueError('Invalid dimension')
        rots = self._r.unsqueeze(dim)
        trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1)

        return Frame(rots, trans)

    @staticmethod
    def cat(
        Ts: Sequence[Frame],
        dim: int,
    ) -> Frame:
        rots = Rotation.cat([T._r for T in Ts], dim)
        trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1)

        return Frame(rots, trans)

    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame:
        return Frame(fn(self._r), self._t)

    def apply_trans_fn(self, fn: Callable[[torch.Tensor],
                                          torch.Tensor]) -> Frame:
        return Frame(self._r, fn(self._t))

    def scale_translation(self, trans_scale_factor: float) -> Frame:
        # fn = lambda t: t * trans_scale_factor
        def fn(t):
            return t * trans_scale_factor

        return self.apply_trans_fn(fn)

    def stop_rot_gradient(self) -> Frame:
        # fn = lambda r: r.detach()
        def fn(r):
            return r.detach()

        return self.apply_rot_fn(fn)

    @staticmethod
    def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
        input_dtype = ca_xyz.dtype
        n_xyz = n_xyz.float()
        ca_xyz = ca_xyz.float()
        c_xyz = c_xyz.float()
        n_xyz = n_xyz - ca_xyz
        c_xyz = c_xyz - ca_xyz

        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + c_x**2 + c_y**2)
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
        sin_c2 = c_z / norm
        cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm

        c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
        c2_rots[..., 2, 0] = -1 * sin_c2
        c2_rots[..., 2, 2] = cos_c2

        c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots)
        n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz)

        _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + n_y**2 + n_z**2)
        sin_n = -n_z / norm
        cos_n = n_y / norm

        n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        n_rots[..., 0, 0] = 1
        n_rots[..., 1, 1] = cos_n
        n_rots[..., 1, 2] = -1 * sin_n
        n_rots[..., 2, 1] = sin_n
        n_rots[..., 2, 2] = cos_n

        rots = Rotation.mat_mul_mat(n_rots, c_rots)

        rots = rots.transpose(-1, -2)
        rot_obj = Rotation(mat=rots.type(input_dtype))

        return Frame(rot_obj, ca_xyz.type(input_dtype))

    def cuda(self) -> Frame:
        return Frame(self._r.cuda(), self._t.cuda())

    @property
    def dtype(self) -> torch.dtype:
        assert self._r.dtype == self._t.dtype
        return self._r.dtype

    def type(self, dtype) -> Frame:
        return Frame(self._r.type(dtype), self._t.type(dtype))
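A standalone sanity check of the Frame algebra above, assuming the classes in this file are importable: applying a frame and then its inverse is a no-op on points:

import torch

g = Frame(Rotation.identity((1, )), torch.tensor([[1., 2., 3.]]))
pts = torch.randn(1, 3)
assert torch.allclose(g.invert().apply(g.apply(pts)), pts, atol=1e-5)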


class Quaternion:

    def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor):
        if quaternion.shape[-1] != 4:
            raise ValueError(f'incorrect quaternion shape: {quaternion.shape}')
        self._q = quaternion
        self._t = translation

    @staticmethod
    def identity(
        shape: Iterable[int],
        dtype: Optional[torch.dtype] = torch.float,
        device: Optional[torch.device] = torch.device('cpu'),
        requires_grad: bool = False,
    ) -> Quaternion:
        trans = zero_translation(shape, dtype, device, requires_grad)
        quats = torch.zeros((*shape, 4),
                            dtype=dtype,
                            device=device,
                            requires_grad=requires_grad)
        with torch.no_grad():
            quats[..., 0] = 1
        return Quaternion(quats, trans)

    def get_quats(self):
        return self._q

    def get_trans(self):
        return self._t

    def get_rot_mats(self):
        quats = self.get_quats()
        rot_mats = Quaternion.quat_to_rot(quats)
        return rot_mats

    @staticmethod
    def quat_to_rot(normalized_quat):
        global _QUAT_TO_ROT_tensor
        dtype = normalized_quat.dtype
        normalized_quat = normalized_quat.float()
        if _QUAT_TO_ROT_tensor.device != normalized_quat.device:
            _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to(
                normalized_quat.device)
        rot_tensor = torch.sum(
            _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None]
            * normalized_quat[..., None, :, None],
            dim=(-3, -2),
        )
        rot_tensor = rot_tensor.type(dtype)
        rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3)
        return rot_tensor

    @staticmethod
    def normalize_quat(quats):
        dtype = quats.dtype
        quats = quats.float()
        quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
        quats = quats.type(dtype)
        return quats

    @staticmethod
    def quat_multiply_by_vec(quat, vec):
        dtype = quat.dtype
        quat = quat.float()
        vec = vec.float()
        global _QUAT_MULTIPLY_BY_VEC_tensor
        if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device:
            _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to(
                quat.device)
        mat = _QUAT_MULTIPLY_BY_VEC_tensor
        reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape)
        return torch.sum(
            reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None],
            dim=(-3, -2),
        ).type(dtype)

    def compose_q_update_vec(self,
                             q_update_vec: torch.Tensor,
                             normalize_quats: bool = True) -> torch.Tensor:
        quats = self.get_quats()
        new_quats = quats + Quaternion.quat_multiply_by_vec(
            quats, q_update_vec)
        if normalize_quats:
            new_quats = Quaternion.normalize_quat(new_quats)
        return new_quats

    def compose_update_vec(
        self,
        update_vec: torch.Tensor,
        pre_rot_mat: Rotation,
    ) -> Quaternion:
        q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:]
        new_quats = self.compose_q_update_vec(q_vec)

        trans_update = pre_rot_mat.apply(t_vec)
        new_trans = self._t + trans_update

        return Quaternion(new_quats, new_trans)

    def stop_rot_gradient(self) -> Quaternion:
        return Quaternion(self._q.detach(), self._t)
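A quick standalone check of the quaternion tables above: the identity quaternion (w, x, y, z) = (1, 0, 0, 0) maps to the identity rotation:

import torch

q = torch.tensor([1., 0., 0., 0.])
R = Quaternion.quat_to_rot(q)
assert torch.allclose(R, torch.eye(3), atol=1e-6)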

modelscope/models/science/unifold/modules/structure_module.py (new file, 592 lines)
@@ -0,0 +1,592 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import math
from typing import Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm, softmax_dropout
from unicore.utils import dict_multimap, one_hot, permute_final_dims

from modelscope.models.science.unifold.data.residue_constants import (
    restype_atom14_mask, restype_atom14_rigid_group_positions,
    restype_atom14_to_rigid_group, restype_rigid_group_default_frame)
from .attentions import gen_attn_mask
from .common import Linear, SimpleModuleList, residual
from .frame import Frame, Quaternion, Rotation


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        # softplus(0.541324854612918) == 1.0, so each head starts with a
        # point-attention weight of exactly 1 after the softplus.
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)


def torsion_angles_to_frames(
    frame: Frame,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    default_frames: torch.Tensor,
):
    default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...])

    bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha],
                      dim=-2)

    all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_rots = Frame(Rotation(mat=all_rots), None)

    all_frames = default_frame.compose(all_rots)

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Frame.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = frame[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def frames_and_literature_positions_to_atom14_pos(
    frame: Frame,
    aatype: torch.Tensor,
    default_frames,
    group_idx,
    atom_mask,
    lit_positions,
):
    group_mask = group_idx[aatype, ...]
    group_mask = one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    t_atoms_to_global = frame[..., None, :] * group_mask
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1))

    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions


class SideChainAngleResnetIteration(nn.Module):

    def __init__(self, d_hid):
        super(SideChainAngleResnetIteration, self).__init__()

        self.d_hid = d_hid

        self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu')
        self.act = nn.GELU()
        self.linear_2 = Linear(self.d_hid, self.d_hid, init='final')

    def forward(self, s: torch.Tensor) -> torch.Tensor:

        x = self.act(s)
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)

        return residual(s, x, self.training)


class SidechainAngleResnet(nn.Module):

    def __init__(self, d_in, d_hid, num_blocks, num_angles):
        super(SidechainAngleResnet, self).__init__()

        self.linear_in = Linear(d_in, d_hid)
        self.act = nn.GELU()
        self.linear_initial = Linear(d_in, d_hid)

        self.layers = SimpleModuleList()
        for _ in range(num_blocks):
            self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid))

        self.linear_out = Linear(d_hid, num_angles * 2)

    def forward(self, s: torch.Tensor,
                initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        initial_s = self.linear_initial(self.act(initial_s))
        s = self.linear_in(self.act(s))

        s = s + initial_s

        for layer in self.layers:
            s = layer(s)

        s = self.linear_out(self.act(s))

        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s.float()**2, dim=-1, keepdim=True),
                min=1e-12,
            ))
        s = s.float() / norm_denom

        return unnormalized_s, s.type(unnormalized_s.dtype)
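The resnet above emits raw (sin, cos) pairs and projects them onto the unit circle; the normalization step in isolation:

import torch

raw = torch.randn(4, 7, 2)  # 7 torsion angles per residue
denom = torch.sqrt((raw ** 2).sum(-1, keepdim=True).clamp(min=1e-12))
angles = raw / denom
assert torch.allclose((angles ** 2).sum(-1), torch.ones(4, 7), atol=1e-5)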


class InvariantPointAttention(nn.Module):

    def __init__(
        self,
        d_single: int,
        d_pair: int,
        d_hid: int,
        num_heads: int,
        num_qk_points: int,
        num_v_points: int,
        separate_kv: bool = False,
        bias: bool = True,
        eps: float = 1e-8,
    ):
        super(InvariantPointAttention, self).__init__()

        self.d_hid = d_hid
        self.num_heads = num_heads
        self.num_qk_points = num_qk_points
        self.num_v_points = num_v_points
        self.eps = eps

        hc = self.d_hid * self.num_heads
        self.linear_q = Linear(d_single, hc, bias=bias)
        self.separate_kv = separate_kv
        if self.separate_kv:
            self.linear_k = Linear(d_single, hc, bias=bias)
            self.linear_v = Linear(d_single, hc, bias=bias)
        else:
            self.linear_kv = Linear(d_single, 2 * hc, bias=bias)

        hpq = self.num_heads * self.num_qk_points * 3
        self.linear_q_points = Linear(d_single, hpq)
        hpk = self.num_heads * self.num_qk_points * 3
        hpv = self.num_heads * self.num_v_points * 3
        if self.separate_kv:
            self.linear_k_points = Linear(d_single, hpk)
            self.linear_v_points = Linear(d_single, hpv)
        else:
            hpkv = hpk + hpv
            self.linear_kv_points = Linear(d_single, hpkv)

        self.linear_b = Linear(d_pair, self.num_heads)

        self.head_weights = nn.Parameter(torch.zeros((num_heads)))
        ipa_point_weights_init_(self.head_weights)

        concat_out_dim = self.num_heads * (
            d_pair + self.d_hid + self.num_v_points * 4)
        self.linear_out = Linear(concat_out_dim, d_single, init='final')

        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
        f: Frame,
        square_mask: torch.Tensor,
    ) -> torch.Tensor:
        q = self.linear_q(s)

        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        if self.separate_kv:
            k = self.linear_k(s)
            v = self.linear_v(s)
            k = k.view(k.shape[:-1] + (self.num_heads, -1))
            v = v.view(v.shape[:-1] + (self.num_heads, -1))
        else:
            kv = self.linear_kv(s)
            kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
            k, v = torch.split(kv, self.d_hid, dim=-1)

        q_pts = self.linear_q_points(s)

        def process_points(pts, no_points):
            shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3)
            if self.separate_kv:
                # alphafold-multimer uses a different point layout
                pts = pts.view(pts.shape[:-1]
                               + (self.num_heads, no_points * 3))
            pts = torch.split(pts, pts.shape[-1] // 3, dim=-1)
            pts = torch.stack(pts, dim=-1).view(*shape)
            pts = f[..., None].apply(pts)

            pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3))
            return pts

        q_pts = process_points(q_pts, self.num_qk_points)

        if self.separate_kv:
            k_pts = self.linear_k_points(s)
            v_pts = self.linear_v_points(s)
            k_pts = process_points(k_pts, self.num_qk_points)
            v_pts = process_points(v_pts, self.num_v_points)
        else:
            kv_pts = self.linear_kv_points(s)

            kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
            kv_pts = torch.stack(kv_pts, dim=-1)
            kv_pts = f[..., None].apply(kv_pts)

            kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

            k_pts, v_pts = torch.split(
                kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        bias = self.linear_b(z)

        attn = torch.matmul(
            permute_final_dims(q, (1, 0, 2)),
            permute_final_dims(k, (1, 2, 0)),
        )

        if self.training:
            attn = attn * math.sqrt(1.0 / (3 * self.d_hid))
            attn = attn + (
                math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
            pt_att = pt_att.float()**2
        else:
            attn *= math.sqrt(1.0 / (3 * self.d_hid))
            attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
            pt_att *= pt_att

        pt_att = pt_att.sum(dim=-1)
        head_weights = self.softplus(self.head_weights).view(
            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
        pt_att *= head_weights * (-0.5)

        pt_att = torch.sum(pt_att, dim=-1)

        pt_att = permute_final_dims(pt_att, (2, 0, 1))
        attn += square_mask
        attn = softmax_dropout(
            attn, 0, self.training, bias=pt_att.type(attn.dtype))
        del pt_att, q_pts, k_pts, bias
        o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3)
        o = o.contiguous().view(*o.shape[:-2], -1)
        del q, k, v
        o_pts = torch.sum(
            (attn[..., None, :, :, None]
             * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        o_pts = permute_final_dims(o_pts, (2, 0, 3, 1))
        o_pts = f[..., None, None].invert_apply(o_pts)
        if self.training:
            o_pts_norm = torch.sqrt(
                torch.sum(o_pts.float()**2, dim=-1) + self.eps).type(
                    o_pts.dtype)
        else:
            o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1)
                                    + self.eps).type(o_pts.dtype)

        o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1)

        o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3)

        o_pair = torch.matmul(attn.transpose(-2, -3), z)

        o_pair = o_pair.view(*o_pair.shape[:-2], -1)

        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair),
                      dim=-1))

        return s
|
||||
|
||||
|
||||
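# Illustrative sketch (not part of the original file): the scalar half of the
# IPA logits above is a plain multi-head QK^T, scaled by sqrt(1/(3 * d_hid))
# because three terms (scalar, point, pair-bias) are summed into one softmax.
# Dimensions here are assumptions for the demo; `torch` and `math` are already
# imported by this module.
def _demo_ipa_scalar_logits():
    num_res, num_heads, d_hid = 8, 4, 16
    q = torch.randn(num_res, num_heads, d_hid)
    k = torch.randn(num_res, num_heads, d_hid)
    # [H, N, C] x [H, C, N] -> [H, N, N], as the permute_final_dims calls do
    attn = torch.matmul(q.transpose(0, 1),
                        k.transpose(0, 1).transpose(-1, -2))
    attn = attn * math.sqrt(1.0 / (3 * d_hid))
    assert attn.shape == (num_heads, num_res, num_res)

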
class BackboneUpdate(nn.Module):

    def __init__(self, d_single):
        super(BackboneUpdate, self).__init__()
        self.linear = Linear(d_single, 6, init='final')

    def forward(self, s: torch.Tensor):
        return self.linear(s)


class StructureModuleTransitionLayer(nn.Module):

    def __init__(self, c):
        super(StructureModuleTransitionLayer, self).__init__()

        self.linear_1 = Linear(c, c, init='relu')
        self.linear_2 = Linear(c, c, init='relu')
        self.act = nn.GELU()
        self.linear_3 = Linear(c, c, init='final')

    def forward(self, s):
        s_old = s
        s = self.linear_1(s)
        s = self.act(s)
        s = self.linear_2(s)
        s = self.act(s)
        s = self.linear_3(s)

        s = residual(s_old, s, self.training)

        return s


class StructureModuleTransition(nn.Module):

    def __init__(self, c, num_layers, dropout_rate):
        super(StructureModuleTransition, self).__init__()

        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.layers = SimpleModuleList()
        for _ in range(self.num_layers):
            self.layers.append(StructureModuleTransitionLayer(c))

        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = LayerNorm(c)

    def forward(self, s):
        for layer in self.layers:
            s = layer(s)

        s = self.dropout(s)
        s = self.layer_norm(s)

        return s


class StructureModule(nn.Module):

    def __init__(
        self,
        d_single,
        d_pair,
        d_ipa,
        d_angle,
        num_heads_ipa,
        num_qk_points,
        num_v_points,
        dropout_rate,
        num_blocks,
        no_transition_layers,
        num_resnet_blocks,
        num_angles,
        trans_scale_factor,
        separate_kv,
        ipa_bias,
        epsilon,
        inf,
        **kwargs,
    ):
        super(StructureModule, self).__init__()

        self.num_blocks = num_blocks
        self.trans_scale_factor = trans_scale_factor
        self.default_frames = None
        self.group_idx = None
        self.atom_mask = None
        self.lit_positions = None
        self.inf = inf

        self.layer_norm_s = LayerNorm(d_single)
        self.layer_norm_z = LayerNorm(d_pair)

        self.linear_in = Linear(d_single, d_single)

        self.ipa = InvariantPointAttention(
            d_single,
            d_pair,
            d_ipa,
            num_heads_ipa,
            num_qk_points,
            num_v_points,
            separate_kv=separate_kv,
            bias=ipa_bias,
            eps=epsilon,
        )

        self.ipa_dropout = nn.Dropout(dropout_rate)
        self.layer_norm_ipa = LayerNorm(d_single)

        self.transition = StructureModuleTransition(
            d_single,
            no_transition_layers,
            dropout_rate,
        )

        self.bb_update = BackboneUpdate(d_single)

        self.angle_resnet = SidechainAngleResnet(
            d_single,
            d_angle,
            num_resnet_blocks,
            num_angles,
        )

    def forward(
        self,
        s,
        z,
        aatype,
        mask=None,
    ):
        if mask is None:
            mask = s.new_ones(s.shape[:-1])

        # generate square mask
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3)
        s = self.layer_norm_s(s)
        z = self.layer_norm_z(z)
        initial_s = s
        s = self.linear_in(s)

        quat_encoder = Quaternion.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            requires_grad=False,
        )
        backb_to_global = Frame(
            Rotation(mat=quat_encoder.get_rot_mats(), ),
            quat_encoder.get_trans(),
        )
        outputs = []
        for i in range(self.num_blocks):
            s = residual(s, self.ipa(s, z, backb_to_global, square_mask),
                         self.training)
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # update quaternion encoder
            # use backb_to_global to avoid quat-to-rot conversion
            quat_encoder = quat_encoder.compose_update_vec(
                self.bb_update(s), pre_rot_mat=backb_to_global.get_rots())

            # initial_s is always used to update the backbone
            unnormalized_angles, angles = self.angle_resnet(s, initial_s)

            # convert quaternion to rotation matrix
            backb_to_global = Frame(
                Rotation(mat=quat_encoder.get_rot_mats(), ),
                quat_encoder.get_trans(),
            )
            if i == self.num_blocks - 1:
                all_frames_to_global = self.torsion_angles_to_frames(
                    backb_to_global.scale_translation(self.trans_scale_factor),
                    angles,
                    aatype,
                )

                pred_positions = self.frames_and_literature_positions_to_atom14_pos(
                    all_frames_to_global,
                    aatype,
                )

            preds = {
                'frames':
                backb_to_global.scale_translation(
                    self.trans_scale_factor).to_tensor_4x4(),
                'unnormalized_angles':
                unnormalized_angles,
                'angles':
                angles,
            }

            outputs.append(preds)
            if i < (self.num_blocks - 1):
                # stop gradient in iteration
                quat_encoder = quat_encoder.stop_rot_gradient()
                backb_to_global = backb_to_global.stop_rot_gradient()

        outputs = dict_multimap(torch.stack, outputs)
        outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4()
        outputs['positions'] = pred_positions
        outputs['single'] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        if self.default_frames is None:
            self.default_frames = torch.tensor(
                restype_rigid_group_default_frame,
                dtype=float_dtype,
                device=device,
                requires_grad=False,
            )
        if self.group_idx is None:
            self.group_idx = torch.tensor(
                restype_atom14_to_rigid_group,
                device=device,
                requires_grad=False,
            )
        if self.atom_mask is None:
            self.atom_mask = torch.tensor(
                restype_atom14_mask,
                dtype=float_dtype,
                device=device,
                requires_grad=False,
            )
        if self.lit_positions is None:
            self.lit_positions = torch.tensor(
                restype_atom14_rigid_group_positions,
                dtype=float_dtype,
                device=device,
                requires_grad=False,
            )

    def torsion_angles_to_frames(self, frame, alpha, aatype):
        self._init_residue_constants(alpha.dtype, alpha.device)
        return torsion_angles_to_frames(frame, alpha, aatype,
                                        self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(self, frame, aatype):
        self._init_residue_constants(frame.get_rots().dtype,
                                     frame.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            frame,
            aatype,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )


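# Illustrative sketch (not part of the original file): the "convert quaternion
# to rotation matrix" step in the loop above, written out for one unit
# quaternion q = (w, x, y, z). Whether Quaternion.get_rot_mats follows exactly
# this convention is an assumption.
def _demo_quat_to_rot():
    q = torch.tensor([1.0, 0.0, 0.0, 0.0])  # identity rotation (w, x, y, z)
    w, x, y, z = torch.unbind(q / q.norm(), dim=-1)
    rot = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ]).view(3, 3)
    assert torch.allclose(rot, torch.eye(3))
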

modelscope/models/science/unifold/modules/template.py (new file, 330 lines)
@@ -0,0 +1,330 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import math
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import (checkpoint_sequential, permute_final_dims,
                           tensor_tree_map)

from .attentions import (Attention, TriangleAttentionEnding,
                         TriangleAttentionStarting, gen_attn_mask)
from .common import (Linear, SimpleModuleList, Transition,
                     bias_dropout_residual, chunk_layer, residual,
                     tri_mul_residual)
from .featurization import build_template_pair_feat_v2
from .triangle_multiplication import (TriangleMultiplicationIncoming,
                                      TriangleMultiplicationOutgoing)


class TemplatePointwiseAttention(nn.Module):

    def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs):
        super(TemplatePointwiseAttention, self).__init__()

        self.inf = inf

        self.mha = Attention(
            d_pair,
            d_template,
            d_template,
            d_hid,
            num_heads,
            gating=False,
        )

    def _chunk(
        self,
        z: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        mha_inputs = {
            'q': z,
            'k': t,
            'v': t,
            'mask': mask,
        }
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            num_batch_dims=len(z.shape[:-2]),
        )

    def forward(
        self,
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        if template_mask is None:
            template_mask = t.new_ones(t.shape[:-3])

        mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None,
                                                       None, :]
        z = z.unsqueeze(-2)

        t = permute_final_dims(t, (1, 2, 0, 3))

        if chunk_size is not None:
            z = self._chunk(z, t, mask, chunk_size)
        else:
            z = self.mha(z, t, t, mask=mask)

        z = z.squeeze(-2)

        return z


class TemplateProjection(nn.Module):

    def __init__(self, d_template, d_pair, **kwargs):
        super(TemplateProjection, self).__init__()

        self.d_pair = d_pair
        self.act = nn.ReLU()
        self.output_linear = Linear(d_template, d_pair, init='relu')

    def forward(self, t, z) -> torch.Tensor:
        if t is None:
            # handle the non-template case; torch.Size is immutable, so copy
            # it into a list before editing the channel dimension
            shape = list(z.shape)
            shape[-1] = self.d_pair
            t = torch.zeros(shape, dtype=z.dtype, device=z.device)
        t = self.act(t)
        z_t = self.output_linear(t)
        return z_t


class TemplatePairStackBlock(nn.Module):

    def __init__(
        self,
        d_template: int,
        d_hid_tri_att: int,
        d_hid_tri_mul: int,
        num_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        tri_attn_first: bool,
        inf: float,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()

        self.tri_att_start = TriangleAttentionStarting(
            d_template,
            d_hid_tri_att,
            num_heads,
        )
        self.tri_att_end = TriangleAttentionEnding(
            d_template,
            d_hid_tri_att,
            num_heads,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            d_template,
            d_hid_tri_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            d_template,
            d_hid_tri_mul,
        )

        self.pair_transition = Transition(
            d_template,
            pair_transition_n,
        )
        self.tri_attn_first = tri_attn_first
        self.dropout = dropout_rate
        self.row_dropout_share_dim = -3
        self.col_dropout_share_dim = -2

    def forward(
        self,
        s: torch.Tensor,
        mask: torch.Tensor,
        tri_start_attn_mask: torch.Tensor,
        tri_end_attn_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        block_size: Optional[int] = None,
    ):
        if self.tri_attn_first:
            s = bias_dropout_residual(
                self.tri_att_start,
                s,
                self.tri_att_start(
                    s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
            )

            s = bias_dropout_residual(
                self.tri_att_end,
                s,
                self.tri_att_end(
                    s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
                self.col_dropout_share_dim,
                self.dropout,
                self.training,
            )
            s = tri_mul_residual(
                self.tri_mul_out,
                s,
                self.tri_mul_out(s, mask=mask, block_size=block_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
                block_size=block_size,
            )

            s = tri_mul_residual(
                self.tri_mul_in,
                s,
                self.tri_mul_in(s, mask=mask, block_size=block_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
                block_size=block_size,
            )
        else:
            s = tri_mul_residual(
                self.tri_mul_out,
                s,
                self.tri_mul_out(s, mask=mask, block_size=block_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
                block_size=block_size,
            )

            s = tri_mul_residual(
                self.tri_mul_in,
                s,
                self.tri_mul_in(s, mask=mask, block_size=block_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
                block_size=block_size,
            )

            s = bias_dropout_residual(
                self.tri_att_start,
                s,
                self.tri_att_start(
                    s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
                self.row_dropout_share_dim,
                self.dropout,
                self.training,
            )

            s = bias_dropout_residual(
                self.tri_att_end,
                s,
                self.tri_att_end(
                    s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
                self.col_dropout_share_dim,
                self.dropout,
                self.training,
            )
        s = residual(s, self.pair_transition(
            s,
            chunk_size=chunk_size,
        ), self.training)
        return s


class TemplatePairStack(nn.Module):

    def __init__(
        self,
        d_template,
        d_hid_tri_att,
        d_hid_tri_mul,
        num_blocks,
        num_heads,
        pair_transition_n,
        dropout_rate,
        tri_attn_first,
        inf=1e9,
        **kwargs,
    ):
        super(TemplatePairStack, self).__init__()

        self.blocks = SimpleModuleList()
        for _ in range(num_blocks):
            self.blocks.append(
                TemplatePairStackBlock(
                    d_template=d_template,
                    d_hid_tri_att=d_hid_tri_att,
                    d_hid_tri_mul=d_hid_tri_mul,
                    num_heads=num_heads,
                    pair_transition_n=pair_transition_n,
                    dropout_rate=dropout_rate,
                    inf=inf,
                    tri_attn_first=tri_attn_first,
                ))

        self.layer_norm = LayerNorm(d_template)

    def forward(
        self,
        single_templates: Tuple[torch.Tensor],
        mask: torch.Tensor,
        tri_start_attn_mask: torch.Tensor,
        tri_end_attn_mask: torch.Tensor,
        templ_dim: int,
        chunk_size: int,
        block_size: int,
        return_mean: bool,
    ):

        def one_template(i):
            (s, ) = checkpoint_sequential(
                functions=[
                    partial(
                        b,
                        mask=mask,
                        tri_start_attn_mask=tri_start_attn_mask,
                        tri_end_attn_mask=tri_end_attn_mask,
                        chunk_size=chunk_size,
                        block_size=block_size,
                    ) for b in self.blocks
                ],
                input=(single_templates[i], ),
            )
            return s

        n_templ = len(single_templates)
        if n_templ > 0:
            new_single_templates = [one_template(0)]
            if return_mean:
                t = self.layer_norm(new_single_templates[0])
            for i in range(1, n_templ):
                s = one_template(i)
                if return_mean:
                    t = residual(t, self.layer_norm(s), self.training)
                else:
                    new_single_templates.append(s)

        if return_mean:
            if n_templ > 0:
                t /= n_templ
            else:
                t = None
        else:
            t = torch.cat(
                [s.unsqueeze(templ_dim) for s in new_single_templates],
                dim=templ_dim)
            t = self.layer_norm(t)

        return t

modelscope/models/science/unifold/modules/triangle_multiplication.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partialmethod
from typing import List, Optional

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import permute_final_dims

from .common import Linear


class TriangleMultiplication(nn.Module):

    def __init__(self, d_pair, d_hid, outgoing=True):
        super(TriangleMultiplication, self).__init__()
        self.outgoing = outgoing

        self.linear_ab_p = Linear(d_pair, d_hid * 2)
        self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating')

        self.linear_g = Linear(d_pair, d_pair, init='gating')
        self.linear_z = Linear(d_hid, d_pair, init='final')

        self.layer_norm_in = LayerNorm(d_pair)
        self.layer_norm_out = LayerNorm(d_hid)

        self._alphafold_original_mode = False

    def _chunk_2d(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        block_size: int = None,
    ) -> torch.Tensor:

        # avoid too small chunk size
        # block_size = max(block_size, 256)
        new_z = z.new_zeros(z.shape)
        dim1 = z.shape[-3]

        def _slice_linear(z, linear: Linear, a=True):
            d_hid = linear.bias.shape[0] // 2
            index = 0 if a else d_hid
            p = (
                nn.functional.linear(z, linear.weight[index:index + d_hid])
                + linear.bias[index:index + d_hid])
            return p

        def _chunk_projection(z, mask, a=True):
            p = _slice_linear(z, self.linear_ab_p, a) * mask
            p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
            return p

        num_chunk = (dim1 + block_size - 1) // block_size
        for i in range(num_chunk):
            chunk_start = i * block_size
            chunk_end = min(chunk_start + block_size, dim1)
            if self.outgoing:
                a_chunk = _chunk_projection(
                    z[..., chunk_start:chunk_end, :, :],
                    mask[..., chunk_start:chunk_end, :, :],
                    a=True,
                )
                a_chunk = permute_final_dims(a_chunk, (2, 0, 1))
            else:
                a_chunk = _chunk_projection(
                    z[..., :, chunk_start:chunk_end, :],
                    mask[..., :, chunk_start:chunk_end, :],
                    a=True,
                )
                a_chunk = a_chunk.transpose(-1, -3)

            for j in range(num_chunk):
                j_chunk_start = j * block_size
                j_chunk_end = min(j_chunk_start + block_size, dim1)
                if self.outgoing:
                    b_chunk = _chunk_projection(
                        z[..., j_chunk_start:j_chunk_end, :, :],
                        mask[..., j_chunk_start:j_chunk_end, :, :],
                        a=False,
                    )
                    b_chunk = b_chunk.transpose(-1, -3)
                else:
                    b_chunk = _chunk_projection(
                        z[..., :, j_chunk_start:j_chunk_end, :],
                        mask[..., :, j_chunk_start:j_chunk_end, :],
                        a=False,
                    )
                    b_chunk = permute_final_dims(b_chunk, (2, 0, 1))
                x_chunk = torch.matmul(a_chunk, b_chunk)
                del b_chunk
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)
                x_chunk *= torch.sigmoid(
                    self.linear_g(z[..., chunk_start:chunk_end,
                                    j_chunk_start:j_chunk_end, :]))
                new_z[..., chunk_start:chunk_end,
                      j_chunk_start:j_chunk_end, :] = x_chunk
                del x_chunk
            del a_chunk
        return new_z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        block_size=None,
    ) -> torch.Tensor:

        mask = mask.unsqueeze(-1)
        if not self._alphafold_original_mode:
            # scale the mask by 1/sqrt(dim) for numerical stability
            mask = mask * (mask.shape[-2]**-0.5)

        z = self.layer_norm_in(z)
        if not self.training and block_size is not None:
            return self._chunk_2d(z, mask, block_size=block_size)

        g = nn.functional.linear(z, self.linear_g.weight)
        if self.training:
            ab = self.linear_ab_p(z) * mask * torch.sigmoid(
                self.linear_ab_g(z))
        else:
            ab = self.linear_ab_p(z)
            ab *= mask
            ab *= torch.sigmoid(self.linear_ab_g(z))
        a, b = torch.chunk(ab, 2, dim=-1)
        del z, ab

        if self.outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = b.transpose(-1, -3)
        else:
            b = permute_final_dims(b, (2, 0, 1))
            a = a.transpose(-1, -3)
        x = torch.matmul(a, b)
        del a, b

        x = permute_final_dims(x, (1, 2, 0))

        x = self.layer_norm_out(x)
        x = nn.functional.linear(x, self.linear_z.weight)
        return x, g

    def get_output_bias(self):
        return self.linear_z.bias, self.linear_g.bias


class TriangleMultiplicationOutgoing(TriangleMultiplication):
    __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplication):
    __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False)


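# Illustrative sketch (not part of the original file): the core contraction
# above. For the outgoing update, edge (i, j) pools over edges leaving i and j,
#     x[i, j, c] = sum_k a[i, k, c] * b[j, k, c],
# while the incoming update pools over edges arriving at i and j,
#     x[i, j, c] = sum_k a[k, i, c] * b[k, j, c].
# The permute/transpose pairs in forward() realize exactly these einsums;
# sizes below are assumptions for the demo.
def _demo_triangle_contractions():
    n, c = 5, 3
    a = torch.randn(n, n, c)
    b = torch.randn(n, n, c)
    out_going = torch.einsum('ikc,jkc->ijc', a, b)
    in_coming = torch.einsum('kic,kjc->ijc', a, b)
    assert out_going.shape == in_coming.shape == (n, n, c)
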

modelscope/models/science/unifold/msa/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
""" Scripts for MSA & template searching. """

modelscope/models/science/unifold/msa/mmcif.py (new file, 483 lines)
@@ -0,0 +1,483 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple

from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
from Bio.PDB.MMCIF2Dict import MMCIF2Dict

# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]


@dataclasses.dataclass(frozen=True)
class Monomer:
    id: str
    num: int


# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
    residue_name: str
    author_chain_id: str
    mmcif_chain_id: str
    author_seq_num: str
    mmcif_seq_num: int
    insertion_code: str
    hetatm_atom: str
    model_num: int


# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    chain_id: str
    residue_number: int
    insertion_code: str


@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    position: Optional[ResiduePosition]
    name: str
    is_missing: bool
    hetflag: str


@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure.
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
        {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    raw_string: Any
    mmcif_to_author_chain_id: Mapping[ChainId, ChainId]
    valid_chains: Mapping[ChainId, str]


@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be successfully
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]


class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""


def mmcif_loop_to_list(prefix: str,
                       parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
      http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
    """
    cols = []
    data = []
    for key, value in parsed_info.items():
        if key.startswith(prefix):
            cols.append(key)
            data.append(value)

    assert all([
        len(xs) == len(data[0]) for xs in data
    ]), ('mmCIF error: Not all loops are the same length: %s' % cols)

    return [dict(zip(cols, xs)) for xs in zip(*data)]


def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      index: Which item of loop data should serve as the key.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
      indexed by the index column.
    """
    entries = mmcif_loop_to_list(prefix, parsed_info)
    return {entry[index]: entry for entry in entries}


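# Illustrative sketch (not part of the original file): how the two loop helpers
# above behave on a hand-made parsed_info fragment (real dicts come from the
# Biopython parser).
def _demo_mmcif_loop_helpers():
    parsed_info = {
        '_chem_comp.id': ['ALA', 'GLY'],
        '_chem_comp.type': ['L-peptide linking', 'peptide linking'],
    }
    rows = mmcif_loop_to_list('_chem_comp.', parsed_info)
    assert rows[0] == {
        '_chem_comp.id': 'ALA',
        '_chem_comp.type': 'L-peptide linking'
    }
    by_id = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
    assert by_id['GLY']['_chem_comp.type'] == 'peptide linking'

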
@functools.lru_cache(16, typed=False)
def fast_parse(*,
               file_id: str,
               mmcif_string: str,
               catch_all_errors: bool = True) -> ParsingResult:
    """Entry point, parses an mmcif_string without building a structure.

    Only the mmCIF key/value dictionary is read; the (expensive) Biopython
    structure is never constructed, so the returned MmcifObject has
    `structure`, `chain_to_seqres` and `seqres_to_structure` set to None.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        # Read only the mmCIF key/value dictionary; MMCIF2Dict avoids building
        # the full Biopython structure, which is the slow part of parsing.
        parsed_info = MMCIF2Dict(io.StringIO(mmcif_string))

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ''): 'No protein chains found in this file.'})

        mmcif_to_author_chain_id = {}
        for atom in _get_atom_site_list(parsed_info):
            if atom.model_num != '1':
                # We only process the first model at the moment.
                continue
            mmcif_to_author_chain_id[
                atom.mmcif_chain_id] = atom.author_chain_id

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=None,
            chain_to_seqres=None,
            seqres_to_structure=None,
            raw_string=parsed_info,
            mmcif_to_author_chain_id=mmcif_to_author_chain_id,
            valid_chains=valid_chains,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, '')] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)


@functools.lru_cache(16, typed=False)
def parse(*,
          file_id: str,
          mmcif_string: str,
          catch_all_errors: bool = True) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure('', handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ''): 'No protein chains found in this file.'})
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        # by the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            if atom.model_num != '1':
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[
                atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = ' '
                if atom.hetatm_atom == 'HETATM':
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ('HOH', 'WAT'):
                        hetflag = 'W'
                    else:
                        hetflag = 'H_' + atom.residue_name
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = ' '
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                seq_idx = int(
                    atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {})
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=' ')

        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
                seq.append(code if len(code) == 1 else 'X')
            seq = ''.join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
            mmcif_to_author_chain_id=mmcif_to_author_chain_id,
            valid_chains=valid_chains,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, '')] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)


def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    return next(structure.get_models())


_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21


def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest revision date."""
    revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
    return min(revision_dates)


def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution."""
    header = {}

    experiments = mmcif_loop_to_list('_exptl.', parsed_info)
    header['structure_method'] = ','.join(
        [experiment['_exptl.method'].lower() for experiment in experiments])

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if '_pdbx_audit_revision_history.revision_date' in parsed_info:
        header['release_date'] = get_release_date(parsed_info)
    else:
        logging.warning('Could not determine release_date: %s',
                        parsed_info['_entry.id'])

    header['resolution'] = 0.00
    for res_key in (
            '_refine.ls_d_res_high',
            '_em_3d_reconstruction.resolution',
            '_reflns.d_resolution_high',
    ):
        if res_key in parsed_info:
            try:
                raw_resolution = parsed_info[res_key][0]
                header['resolution'] = float(raw_resolution)
            except ValueError:
                logging.debug('Invalid resolution format: %s',
                              parsed_info[res_key])

    return header


def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Returns list of atom sites; contains data not present in the structure."""
    return [
        AtomSite(*site) for site in zip(  # pylint:disable=g-complex-comprehension
            parsed_info['_atom_site.label_comp_id'],
            parsed_info['_atom_site.auth_asym_id'],
            parsed_info['_atom_site.label_asym_id'],
            parsed_info['_atom_site.auth_seq_id'],
            parsed_info['_atom_site.label_seq_id'],
            parsed_info['_atom_site.pdbx_PDB_ins_code'],
            parsed_info['_atom_site.group_PDB'],
            parsed_info['_atom_site.pdbx_PDB_model_num'],
        )
    ]


def _get_protein_chains(
        *, parsed_info: Mapping[str,
                                Any]) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)

    polymers = collections.defaultdict(list)
    for entity_poly_seq in entity_poly_seqs:
        polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
            Monomer(
                id=entity_poly_seq['_entity_poly_seq.mon_id'],
                num=int(entity_poly_seq['_entity_poly_seq.num']),
            ))

    # Get chemical compositions. Will allow us to identify which of these polymers
    # are proteins.
    chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id',
                                    parsed_info)

    # Get chains information for each entity. Necessary so that we can return a
    # dict keyed on chain id rather than entity.
    struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)

    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in struct_asyms:
        chain_id = struct_asym['_struct_asym.id']
        entity_id = struct_asym['_struct_asym.entity_id']
        entity_to_mmcif_chains[entity_id].append(chain_id)

    # Identify and return the valid protein chains.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        chain_ids = entity_to_mmcif_chains[entity_id]

        # Reject polymers without any peptide-like components, such as DNA/RNA.
        if any([
                'peptide' in chem_comps[monomer.id]['_chem_comp.type']
                for monomer in seq_info
        ]):
            for chain_id in chain_ids:
                valid_chains[chain_id] = seq_info
    return valid_chains


def _is_set(data: str) -> bool:
    """Returns False if data is a special mmCIF character indicating 'unset'."""
    return data not in ('.', '?')


modelscope/models/science/unifold/msa/msa_identifiers.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""

import dataclasses
import re
from typing import Optional

# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
    r"""
    ^
    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
    (?:tr|sp)
    \|
    # A primary accession number of the UniProtKB entry.
    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
    (?:_\d)?
    \|
    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
    # protein ID code.
    (?:[A-Za-z0-9]+)
    _
    # A mnemonic species identification code.
    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
    # Small BFD uses a final value after an underscore, which we ignore.
    (?:_\d+)?
    $
    """,
    re.VERBOSE,
)


@dataclasses.dataclass(frozen=True)
class Identifiers:
    species_id: str = ''


def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Gets accession id and species from an msa sequence identifier.

    The sequence identifier has the format specified by
    _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
    An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`

    Args:
      msa_sequence_identifier: a sequence identifier.

    Returns:
      An `Identifiers` instance with a species_id. These
      can be empty in the case where no identifier was found.
    """
    matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
    if matches:
        return Identifiers(species_id=matches.group('SpeciesIdentifier'))
    return Identifiers()


def _extract_sequence_identifier(description: str) -> Optional[str]:
    """Extracts sequence identifier from description. Returns None if no match."""
    split_description = description.split()
    if split_description:
        return split_description[0].partition('/')[0]
    else:
        return None


def get_identifiers(description: str) -> Identifiers:
    """Computes extra MSA features from the description."""
    sequence_identifier = _extract_sequence_identifier(description)
    if sequence_identifier is None:
        return Identifiers()
    else:
        return _parse_sequence_identifier(sequence_identifier)


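# Illustrative sketch (not part of the original file): typical inputs and the
# species codes the regex above extracts.
def _demo_get_identifiers():
    assert get_identifiers(
        'tr|A0A146SKV9|A0A146SKV9_FUNHE/4-78').species_id == 'FUNHE'
    assert get_identifiers('sp|P0C2L1|A3X1_LOXLA').species_id == 'LOXLA'
    # Non-UniProt descriptions fall through to an empty Identifiers().
    assert get_identifiers('some_unrelated_hit 42').species_id == ''
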

modelscope/models/science/unifold/msa/parsers.py (new file, 627 lines)
@@ -0,0 +1,627 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple

DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class Msa:
    """Class representing a parsed MSA file."""

    sequences: Sequence[str]
    deletion_matrix: DeletionMatrix
    descriptions: Sequence[str]

    def __post_init__(self):
        if not (len(self.sequences) == len(self.deletion_matrix) == len(
                self.descriptions)):
            raise ValueError(
                'All fields for an MSA must have the same length. '
                f'Got {len(self.sequences)} sequences, '
                f'{len(self.deletion_matrix)} rows in the deletion matrix and '
                f'{len(self.descriptions)} descriptions.')

    def __len__(self):
        return len(self.sequences)

    def truncate(self, max_seqs: int):
        return Msa(
            sequences=self.sequences[:max_seqs],
            deletion_matrix=self.deletion_matrix[:max_seqs],
            descriptions=self.descriptions[:max_seqs],
        )


@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit."""

    index: int
    name: str
    aligned_cols: int
    sum_probs: Optional[float]
    query: str
    hit_sequence: str
    indices_query: List[int]
    indices_hit: List[int]


def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists:
      * A list of sequences.
      * A list of sequence descriptions taken from the comment lines. In the
        same order as the sequences.
    """
    sequences = []
    descriptions = []
    index = -1
    for line in fasta_string.splitlines():
        line = line.strip()
        if line.startswith('>'):
            index += 1
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append('')
            continue
        elif not line:
            continue  # Skip blank lines.
        sequences[index] += line

    return sequences, descriptions


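# Illustrative sketch (not part of the original file): parse_fasta on a tiny
# two-record FASTA string; wrapped sequence lines are concatenated.
def _demo_parse_fasta():
    fasta = '>seq1 first example\nABC\nDEF\n>seq2\nGHI\n'
    seqs, descs = parse_fasta(fasta)
    assert seqs == ['ABCDEF', 'GHI']
    assert descs == ['seq1 first example', 'seq2']

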
def parse_stockholm(stockholm_string: str) -> Msa:
    """Parses sequences and deletion matrix from stockholm format alignment.

    Args:
      stockholm_string: The string contents of a stockholm file. The first
        sequence in the file should be the query sequence.

    Returns:
      An Msa object containing:
      * A list of sequences that have been aligned to the query. These
        might contain duplicates.
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
      * The names of the targets matched, including the jackhmmer subsequence
        suffix.
    """
    name_to_sequence = collections.OrderedDict()
    for line in stockholm_string.splitlines():
        line = line.strip()
        if not line or line.startswith(('#', '//')):
            continue
        name, sequence = line.split()
        if name not in name_to_sequence:
            name_to_sequence[name] = ''
        name_to_sequence[name] += sequence

    msa = []
    deletion_matrix = []

    query = ''
    keep_columns = []
    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # Gather the columns with gaps from the query
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != '-']

        # Remove the columns with gaps in the query from all sequences.
        aligned_sequence = ''.join([sequence[c] for c in keep_columns])

        msa.append(aligned_sequence)

        # Count the number of deletions w.r.t. query.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res != '-' or query_res != '-':
                if query_res == '-':
                    deletion_count += 1
                else:
                    deletion_vec.append(deletion_count)
                    deletion_count = 0
        deletion_matrix.append(deletion_vec)

    return Msa(
        sequences=msa,
        deletion_matrix=deletion_matrix,
        descriptions=list(name_to_sequence.keys()),
    )


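# Illustrative sketch (not part of the original file): parse_stockholm on a
# minimal alignment. The query has no gaps, so nothing is deleted and the
# deletion matrix is all zeros.
def _demo_parse_stockholm():
    sto = '# STOCKHOLM 1.0\nquery ACDE\nhit1  AC-E\n//\n'
    msa = parse_stockholm(sto)
    assert msa.sequences == ['ACDE', 'AC-E']
    assert msa.deletion_matrix == [[0, 0, 0, 0], [0, 0, 0, 0]]
    assert msa.descriptions == ['query', 'hit1']

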
def parse_a3m(a3m_string: str) -> Msa:
    """Parses sequences and deletion matrix from a3m format alignment.

    Args:
      a3m_string: The string contents of a a3m file. The first sequence in the
        file should be the query sequence.

    Returns:
      An Msa object containing:
      * A list of sequences that have been aligned to the query. These
        might contain duplicates.
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
      * A list of descriptions, one per sequence, from the a3m file.
    """
    sequences, descriptions = parse_fasta(a3m_string)
    deletion_matrix = []
    for msa_sequence in sequences:
        deletion_vec = []
        deletion_count = 0
        for j in msa_sequence:
            if j.islower():
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    # Make the MSA matrix out of aligned (deletion-free) sequences.
    deletion_table = str.maketrans('', '', string.ascii_lowercase)
    aligned_sequences = [s.translate(deletion_table) for s in sequences]
    return Msa(
        sequences=aligned_sequences,
        deletion_matrix=deletion_matrix,
        descriptions=descriptions,
    )


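# Illustrative sketch (not part of the original file): in a3m, lowercase
# letters are insertions relative to the query; they are stripped from the
# aligned sequences and counted in the deletion matrix instead.
def _demo_parse_a3m():
    msa = parse_a3m('>query\nACDE\n>hit\nACg-E\n')
    assert msa.sequences == ['ACDE', 'AC-E']
    assert msa.deletion_matrix == [[0, 0, 0, 0], [0, 0, 1, 0]]

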
def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool],
                            sto_seq: str) -> Iterable[str]:
    for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
        if is_query_res_non_gap:
            yield sequence_res
        elif sequence_res != '-':
            yield sequence_res.lower()


def convert_stockholm_to_a3m(
    stockholm_format: str,
    max_sequences: Optional[int] = None,
    remove_first_row_gaps: bool = True,
) -> str:
    """Converts MSA in Stockholm format to the A3M format."""
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    for line in stockholm_format.splitlines():
        reached_max_sequences = max_sequences and len(
            sequences) >= max_sequences
        if line.strip() and not line.startswith(('#', '//')):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ''
            sequences[seqname] += aligned_seq

    for line in stockholm_format.splitlines():
        if line[:4] == '#=GS':
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ''
            if feature != 'DE':
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            if len(descriptions) == len(sequences):
                break

    # Convert sto format to a3m line by line
    a3m_sequences = {}
    if remove_first_row_gaps:
        # query_sequence is assumed to be the first sequence
        query_sequence = next(iter(sequences.values()))
        query_non_gaps = [res != '-' for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        # Dots are optional in a3m format and are commonly removed.
        out_sequence = sto_sequence.replace('.', '')
        if remove_first_row_gaps:
            out_sequence = ''.join(
                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
        a3m_sequences[seqname] = out_sequence

    fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
                    for k in a3m_sequences)
    return '\n'.join(fasta_chunks) + '\n'  # Include terminating newline.


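# Illustrative sketch (not part of the original file): Stockholm-to-a3m
# conversion. Columns that are gaps in the query become lowercase insertion
# characters in the other rows.
def _demo_convert_stockholm_to_a3m():
    sto = '# STOCKHOLM 1.0\nquery AC-E\nhit1  ACDE\n//\n'
    a3m = convert_stockholm_to_a3m(sto)
    assert a3m == '>query \nACE\n>hit1 \nACdE\n'

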
def _keep_line(line: str, seqnames: Set[str]) -> bool:
|
||||
"""Function to decide which lines to keep."""
|
||||
if not line.strip():
|
||||
return True
|
||||
if line.strip() == '//': # End tag
|
||||
return True
|
||||
if line.startswith('# STOCKHOLM'): # Start tag
|
||||
return True
|
||||
if line.startswith('#=GC RF'): # Reference Annotation Line
|
||||
return True
|
||||
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
|
||||
_, seqname, _ = line.split(maxsplit=2)
|
||||
return seqname in seqnames
|
||||
elif line.startswith('#'): # Other markup - filter out
|
||||
return False
|
||||
else: # Alignment data - keep if sequence in list.
|
||||
seqname = line.partition(' ')[0]
|
||||
return seqname in seqnames
|
||||
|
||||
|
||||
def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
|
||||
"""Truncates a stockholm file to a maximum number of sequences."""
|
||||
seqnames = set()
|
||||
filtered_lines = []
|
||||
for line in stockholm_msa.splitlines():
|
||||
if line.strip() and not line.startswith(('#', '//')):
|
||||
# Ignore blank lines, markup and end symbols - remainder are alignment
|
||||
# sequence parts.
|
||||
seqname = line.partition(' ')[0]
|
||||
seqnames.add(seqname)
|
||||
if len(seqnames) >= max_sequences:
|
||||
break
|
||||
|
||||
for line in stockholm_msa.splitlines():
|
||||
if _keep_line(line, seqnames):
|
||||
filtered_lines.append(line)
|
||||
|
||||
return '\n'.join(filtered_lines) + '\n'
|
||||
|
||||
|
||||
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
|
||||
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
|
||||
processed_lines = {}
|
||||
unprocessed_lines = {}
|
||||
for i, line in enumerate(stockholm_msa.splitlines()):
|
||||
if line.startswith('#=GC RF'):
|
||||
reference_annotation_i = i
|
||||
reference_annotation_line = line
|
||||
# Reached the end of this chunk of the alignment. Process chunk.
|
||||
_, _, first_alignment = line.rpartition(' ')
|
||||
mask = []
|
||||
for j in range(len(first_alignment)):
|
||||
for _, unprocessed_line in unprocessed_lines.items():
|
||||
prefix, _, alignment = unprocessed_line.rpartition(' ')
|
||||
if alignment[j] != '-':
|
||||
mask.append(True)
|
||||
break
|
||||
else: # Every row contained a hyphen - empty column.
|
||||
mask.append(False)
|
||||
# Add reference annotation for processing with mask.
|
||||
unprocessed_lines[
|
||||
reference_annotation_i] = reference_annotation_line
|
||||
|
||||
if not any(
|
||||
mask
|
||||
): # All columns were empty. Output empty lines for chunk.
|
||||
for line_index in unprocessed_lines:
|
||||
processed_lines[line_index] = ''
|
||||
else:
|
||||
for line_index, unprocessed_line in unprocessed_lines.items():
|
||||
prefix, _, alignment = unprocessed_line.rpartition(' ')
|
||||
masked_alignment = ''.join(
|
||||
itertools.compress(alignment, mask))
|
||||
processed_lines[
|
||||
line_index] = f'{prefix} {masked_alignment}'
|
||||
|
||||
# Clear raw_alignments.
|
||||
unprocessed_lines = {}
|
||||
elif line.strip() and not line.startswith(('#', '//')):
|
||||
unprocessed_lines[i] = line
|
||||
else:
|
||||
processed_lines[i] = line
|
||||
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
|
||||
|
||||
|
||||
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
|
||||
"""Remove duplicate sequences (ignoring insertions wrt query)."""
|
||||
sequence_dict = collections.defaultdict(str)
|
||||
|
||||
# First we must extract all sequences from the MSA.
|
||||
for line in stockholm_msa.splitlines():
|
||||
# Only consider the alignments - ignore reference annotation, empty lines,
|
||||
# descriptions or markup.
|
||||
if line.strip() and not line.startswith(('#', '//')):
|
||||
line = line.strip()
|
||||
seqname, alignment = line.split()
|
||||
sequence_dict[seqname] += alignment
|
||||
|
||||
seen_sequences = set()
|
||||
seqnames = set()
|
||||
# First alignment is the query.
|
||||
query_align = next(iter(sequence_dict.values()))
|
||||
mask = [c != '-' for c in query_align] # Mask is False for insertions.
|
||||
for seqname, alignment in sequence_dict.items():
|
||||
# Apply mask to remove all insertions from the string.
|
||||
masked_alignment = ''.join(itertools.compress(alignment, mask))
|
||||
if masked_alignment in seen_sequences:
|
||||
continue
|
||||
else:
|
||||
seen_sequences.add(masked_alignment)
|
||||
seqnames.add(seqname)
|
||||
|
||||
filtered_lines = []
|
||||
for line in stockholm_msa.splitlines():
|
||||
if _keep_line(line, seqnames):
|
||||
filtered_lines.append(line)
|
||||
|
||||
return '\n'.join(filtered_lines) + '\n'
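Used together, these three helpers form the Stockholm post-processing chain that pipeline.py (later in this diff) applies to the UniRef90 hits before template search; a minimal sketch, assuming stockholm_msa holds a .sto string and the 10000 cap mirrors uniref_max_hits:

msa = truncate_stockholm_msa(stockholm_msa, max_sequences=10000)
msa = deduplicate_stockholm_msa(msa)
msa = remove_empty_columns_from_stockholm_msa(msa)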

def _get_hhr_line_regex_groups(regex_pattern: str,
                               line: str) -> Sequence[Optional[str]]:
    match = re.match(regex_pattern, line)
    if match is None:
        raise RuntimeError(f'Could not parse query line {line}')
    return match.groups()


def _update_hhr_residue_indices_list(sequence: str, start_index: int,
                                     indices_list: List[int]):
    """Computes the relative indices for each residue with respect to the original sequence."""
    counter = start_index
    for symbol in sequence:
        if symbol == '-':
            indices_list.append(-1)
        else:
            indices_list.append(counter)
            counter += 1


def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses the detailed HMM-HMM comparison section for a single Hit.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
        detailed_lines: A list of lines from a single comparison section between 2
            sequences (which each have their own HMMs)

    Returns:
        A dictionary with the information from that detailed comparison section

    Raises:
        RuntimeError: If a certain line cannot be processed
    """
    # Parse first 2 lines.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # Parse the summary line.
    pattern = (
        'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
        ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
        ']*Template_Neff=(.*)')
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            'Could not parse section: %s. Expected this: \n%s to contain summary.'
            % (detailed_lines, detailed_lines[2]))
    (_, _, _, aligned_cols, _, _, sum_probs,
     _) = [float(x) for x in match.groups()]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ''
    hit_sequence = ''
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line
        if (line.startswith('Q ') and not line.startswith('Q ss_dssp')
                and not line.startswith('Q ss_pred')
                and not line.startswith('Q Consensus')):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
            # everything after that.
            #             start    sequence    end    total_sequence_length
            patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Get the length of the parsed block using the start and finish indices,
            # and ensure it is the same as the actual block length.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            num_insertions = len([x for x in delta_query if x == '-'])
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith('T '):
            # Parse the hit sequence.
            if (not line.startswith('T ss_dssp')
                    and not line.startswith('T ss_pred')
                    and not line.startswith('T Consensus')):
                # Thus the first 17 characters must be 'T <hit_name> ', and we can
                # parse everything after that.
                #             start    sequence    end    total_sequence_length
                patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(delta_hit_sequence, start,
                                                 indices_hit)

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )


def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file."""
    lines = hhr_string.splitlines()

    # Each .hhr file starts with a results table, then has a sequence of hit
    # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
    # iterate through each paragraph to parse each hit.

    block_starts = [
        i for i, line in enumerate(lines) if line.startswith('No ')
    ]

    hits = []
    if block_starts:
        block_starts.append(len(lines))  # Add the end of the final block.
        for i in range(len(block_starts) - 1):
            hits.append(
                _parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
    return hits


def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
    e_values = {'query': 0}
    lines = [line for line in tblout.splitlines() if line[0] != '#']
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name: and
    # (5) E-value (full sequence) (numbering from 1).
    for line in lines:
        fields = line.split()
        e_value = fields[4]
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values
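A sketch of the tblout layout this parser expects (the target name and values below are made up); comment lines start with '#' and the full-sequence E-value is the fifth whitespace-delimited field:

tblout = '\n'.join([
    '# target name  accession  query name  accession  E-value  score',
    'UniRef90_ABC1  -          query       -          1.2e-50  167.0',
])
assert parse_e_values_from_tblout(tblout) == {'query': 0, 'UniRef90_ABC1': 1.2e-50}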

def _get_indices(sequence: str, start: int) -> List[int]:
    """Returns indices for non-gap/insert residues starting at the given index."""
    indices = []
    counter = start
    for symbol in sequence:
        # Skip gaps but add a placeholder so that the alignment is preserved.
        if symbol == '-':
            indices.append(-1)
        # Skip deleted residues, but increase the counter.
        elif symbol.islower():
            counter += 1
        # Normal aligned residue. Increase the counter and append to indices.
        else:
            indices.append(counter)
            counter += 1
    return indices


@dataclasses.dataclass(frozen=True)
class HitMetadata:
    pdb_id: str
    chain: str
    start: int
    end: int
    length: int
    text: str


def _parse_hmmsearch_description(description: str) -> HitMetadata:
    """Parses the hmmsearch A3M sequence description line."""
    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
    match = re.match(
        r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
        description.strip(),
    )

    if not match:
        raise ValueError(f'Could not parse description: "{description}".')

    return HitMetadata(
        pdb_id=match[1],
        chain=match[2],
        start=int(match[3]),
        end=int(match[4]),
        length=int(match[5]),
        text=match[6],
    )


def parse_hmmsearch_a3m(query_sequence: str,
                        a3m_string: str,
                        skip_first: bool = True) -> Sequence[TemplateHit]:
    """Parses an a3m string produced by hmmsearch.

    Args:
        query_sequence: The query sequence.
        a3m_string: The a3m string produced by hmmsearch.
        skip_first: Whether to skip the first sequence in the a3m string.

    Returns:
        A sequence of `TemplateHit` results.
    """
    # Zip the descriptions and MSAs together, skip the first query sequence.
    parsed_a3m = list(zip(*parse_fasta(a3m_string)))
    if skip_first:
        parsed_a3m = parsed_a3m[1:]

    indices_query = _get_indices(query_sequence, start=0)

    hits = []
    for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
        if 'mol:protein' not in hit_description:
            continue  # Skip non-protein chains.
        metadata = _parse_hmmsearch_description(hit_description)
        # Aligned columns are only the match states.
        aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)

        hit = TemplateHit(
            index=i,
            name=f'{metadata.pdb_id}_{metadata.chain}',
            aligned_cols=aligned_cols,
            sum_probs=None,
            query=query_sequence,
            hit_sequence=hit_sequence.upper(),
            indices_query=indices_query,
            indices_hit=indices_hit,
        )
        hits.append(hit)

    return hits
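A quick sanity check of the indexing convention used above (gaps emit -1, lowercase insertions advance the counter without emitting an index):

assert _get_indices('AB-cD', start=0) == [0, 1, -1, 3]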
282 modelscope/models/science/unifold/msa/pipeline.py Normal file
@@ -0,0 +1,282 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the input features for the unifold model."""

import os
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union

import numpy as np
from absl import logging

from modelscope.models.science.unifold.data import residue_constants
from modelscope.models.science.unifold.msa import (msa_identifiers, parsers,
                                                   templates)
from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch,
                                                         hmmsearch, jackhmmer)

FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]


def make_sequence_features(sequence: str, description: str,
                           num_res: int) -> FeatureDict:
    """Constructs a feature dict of sequence features."""
    features = {}
    features['aatype'] = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    features['between_segment_residues'] = np.zeros((num_res, ),
                                                    dtype=np.int32)
    features['domain_name'] = np.array([description.encode('utf-8')],
                                       dtype=np.object_)
    features['residue_index'] = np.array(range(num_res), dtype=np.int32)
    features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
    features['sequence'] = np.array([sequence.encode('utf-8')],
                                    dtype=np.object_)
    return features


def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
    """Constructs a feature dict of MSA features."""
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    int_msa = []
    deletion_matrix = []
    species_ids = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f'MSA {msa_index} must contain at least one sequence.')
        for sequence_index, sequence in enumerate(msa.sequences):
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
            deletion_matrix.append(msa.deletion_matrix[sequence_index])
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[sequence_index])
            species_ids.append(identifiers.species_id.encode('utf-8'))

    num_res = len(msas[0].sequences[0])
    num_alignments = len(int_msa)
    features = {}
    features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
    features['msa'] = np.array(int_msa, dtype=np.int32)
    features['num_alignments'] = np.array(
        [num_alignments] * num_res, dtype=np.int32)
    features['msa_species_identifiers'] = np.array(
        species_ids, dtype=np.object_)
    return features


def run_msa_tool(
    msa_runner,
    input_fasta_path: str,
    msa_out_path: str,
    msa_format: str,
    use_precomputed_msas: bool,
) -> Mapping[str, Any]:
    """Runs an MSA tool, checking if output already exists first."""
    if not use_precomputed_msas or not os.path.exists(msa_out_path):
        result = msa_runner.query(input_fasta_path)[0]
        with open(msa_out_path, 'w') as f:
            f.write(result[msa_format])
    else:
        logging.warning('Reading MSA from file %s', msa_out_path)
        with open(msa_out_path, 'r') as f:
            result = {msa_format: f.read()}
    return result
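A minimal sketch of the precomputed-MSA short-circuit (the binary and database paths below are hypothetical placeholders):

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',      # hypothetical path
    database_path='/data/uniref90.fasta')  # hypothetical path
result = run_msa_tool(
    runner, '/tmp/query.fasta', '/tmp/msas/uniref90_hits.sto', 'sto',
    use_precomputed_msas=True)  # reads the existing .sto instead of re-querying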

class DataPipeline:
    """Runs the alignment tools and assembles the input features."""

    def __init__(
        self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
        uniref90_database_path: str,
        mgnify_database_path: str,
        bfd_database_path: Optional[str],
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        uniprot_database_path: Optional[str],
        template_searcher: TemplateSearcher,
        template_featurizer: templates.TemplateHitFeaturizer,
        use_small_bfd: bool,
        mgnify_max_hits: int = 501,
        uniref_max_hits: int = 10000,
        use_precomputed_msas: bool = False,
    ):
        """Initializes the data pipeline."""
        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path)
        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path)
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
            )
        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path)
        self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniprot_database_path)
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer
        self.mgnify_max_hits = mgnify_max_hits
        self.uniref_max_hits = uniref_max_hits
        self.use_precomputed_msas = use_precomputed_msas

    def process(self, input_fasta_path: str,
                msa_output_dir: str) -> FeatureDict:
        """Runs alignment tools on the input sequence and creates features."""
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f'More than one input sequence found in {input_fasta_path}.')
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
        jackhmmer_uniref90_result = run_msa_tool(
            self.jackhmmer_uniref90_runner,
            input_fasta_path,
            uniref90_out_path,
            'sto',
            self.use_precomputed_msas,
        )
        mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
        jackhmmer_mgnify_result = run_msa_tool(
            self.jackhmmer_mgnify_runner,
            input_fasta_path,
            mgnify_out_path,
            'sto',
            self.use_precomputed_msas,
        )

        msa_for_templates = jackhmmer_uniref90_result['sto']
        msa_for_templates = parsers.truncate_stockholm_msa(
            msa_for_templates, max_sequences=self.uniref_max_hits)
        msa_for_templates = parsers.deduplicate_stockholm_msa(
            msa_for_templates)
        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
            msa_for_templates)

        if self.template_searcher.input_format == 'sto':
            pdb_templates_result = self.template_searcher.query(
                msa_for_templates)
        elif self.template_searcher.input_format == 'a3m':
            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
                msa_for_templates)
            pdb_templates_result = self.template_searcher.query(
                uniref90_msa_as_a3m)
        else:
            raise ValueError('Unrecognized template input format: '
                             f'{self.template_searcher.input_format}')

        pdb_hits_out_path = os.path.join(
            msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
        with open(pdb_hits_out_path, 'w') as f:
            f.write(pdb_templates_result)

        uniref90_msa = parsers.parse_stockholm(
            jackhmmer_uniref90_result['sto'])
        uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
        mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
        mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)

        pdb_template_hits = self.template_searcher.get_template_hits(
            output_string=pdb_templates_result, input_sequence=input_sequence)

        if self._use_small_bfd:
            bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
            jackhmmer_small_bfd_result = run_msa_tool(
                self.jackhmmer_small_bfd_runner,
                input_fasta_path,
                bfd_out_path,
                'sto',
                self.use_precomputed_msas,
            )
            bfd_msa = parsers.parse_stockholm(
                jackhmmer_small_bfd_result['sto'])
        else:
            bfd_out_path = os.path.join(msa_output_dir,
                                        'bfd_uniclust_hits.a3m')
            hhblits_bfd_uniclust_result = run_msa_tool(
                self.hhblits_bfd_uniclust_runner,
                input_fasta_path,
                bfd_out_path,
                'a3m',
                self.use_precomputed_msas,
            )
            bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence, hits=pdb_template_hits)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res)

        msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))

        logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
        logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
        logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
        logging.info(
            'Final (deduplicated) MSA size: %d sequences.',
            msa_features['num_alignments'][0],
        )
        logging.info(
            'Total number of templates (NB: this can include bad '
            'templates and is later filtered to top 4): %d.',
            templates_result.features['template_domain_names'].shape[0],
        )

        return {
            **sequence_features,
            **msa_features,
            **templates_result.features
        }

    def process_uniprot(self, input_fasta_path: str,
                        msa_output_dir: str) -> FeatureDict:
        uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
        uniprot_result = run_msa_tool(
            self.jackhmmer_uniprot_runner,
            input_fasta_path,
            uniprot_path,
            'sto',
            self.use_precomputed_msas,
        )
        msa = parsers.parse_stockholm(uniprot_result['sto'])
        msa = msa.truncate(max_seqs=50000)
        all_seq_dict = make_msa_features([msa])
        return all_seq_dict
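A hedged sketch of wiring the pipeline end to end; every path below is a hypothetical placeholder, and template_featurizer is assumed to be a templates.TemplateHitFeaturizer built elsewhere:

searcher = hmmsearch.Hmmsearch(
    binary_path='/usr/bin/hmmsearch',          # hypothetical
    hmmbuild_binary_path='/usr/bin/hmmbuild',  # hypothetical
    database_path='/data/pdb_seqres.txt')      # hypothetical
pipeline = DataPipeline(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    hhblits_binary_path='/usr/bin/hhblits',
    uniref90_database_path='/data/uniref90.fasta',
    mgnify_database_path='/data/mgy_clusters.fa',
    bfd_database_path=None,
    uniclust30_database_path=None,
    small_bfd_database_path='/data/small_bfd.fasta',
    uniprot_database_path='/data/uniprot.fasta',
    template_searcher=searcher,
    template_featurizer=template_featurizer,  # assumed built elsewhere
    use_small_bfd=True)
features = pipeline.process('/tmp/query.fasta', '/tmp/msas')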
1110 modelscope/models/science/unifold/msa/templates.py Normal file
File diff suppressed because it is too large.
14 modelscope/models/science/unifold/msa/tools/__init__.py Normal file
@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers for third party tools."""
170 modelscope/models/science/unifold/msa/tools/hhblits.py Normal file
@@ -0,0 +1,170 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""

import glob
import os
import subprocess
from typing import Any, List, Mapping, Optional, Sequence

from absl import logging

from . import utils

_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500


class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 4,
        n_iter: int = 3,
        e_value: float = 0.001,
        maxseq: int = 1_000_000,
        realign_max: int = 100_000,
        maxfilt: int = 100_000,
        min_prefilter_hits: int = 1000,
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
            binary_path: The path to the HHblits executable.
            databases: A sequence of HHblits database paths. This should be the
                common prefix for the database files (i.e. up to but not including
                _hhm.ffindex etc.)
            n_cpu: The number of CPUs to give HHblits.
            n_iter: The number of HHblits iterations.
            e_value: The E-value, see HHblits docs for more details.
            maxseq: The maximum number of rows in an input alignment. Note that this
                parameter is only supported in HHBlits version 3.1 and higher.
            realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
            maxfilt: Max number of hits allowed to pass the 2nd prefilter.
                HHblits default: 20000.
            min_prefilter_hits: Min number of hits to pass prefilter.
                HHblits default: 100.
            all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
                HHblits default: False.
            alt: Show up to this many alternative alignments.
            p: Minimum Prob for a hit to be included in the output hhr file.
                HHblits default: 20.
            z: Hard cap on number of hits reported in the hhr file.
                HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

        Raises:
            RuntimeError: If HHblits binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases

        for database_path in self.databases:
            if not glob.glob(database_path + '_*'):
                logging.error('Could not find HHBlits database %s',
                              database_path)
                raise ValueError(
                    f'Could not find HHBlits database {database_path}')

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
        """Queries the database using HHblits."""
        with utils.tmpdir_manager() as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append('-d')
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                '-i',
                input_fasta_path,
                '-cpu',
                str(self.n_cpu),
                '-oa3m',
                a3m_path,
                '-o',
                '/dev/null',
                '-n',
                str(self.n_iter),
                '-e',
                str(self.e_value),
                '-maxseq',
                str(self.maxseq),
                '-realign_max',
                str(self.realign_max),
                '-maxfilt',
                str(self.maxfilt),
                '-min_prefilter_hits',
                str(self.min_prefilter_hits),
            ]
            if self.all_seqs:
                cmd += ['-all']
            if self.alt:
                cmd += ['-alt', str(self.alt)]
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ['-p', str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ['-Z', str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('HHblits query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error('HHblits failed. HHblits stderr begin:')
                for error_line in stderr.decode('utf-8').splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error('HHblits stderr end')
                raise RuntimeError(
                    'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' %
                    (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))

            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )
            return [raw_output]
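Usage is a single call; a sketch with hypothetical paths (databases are given as the common file prefix, per the docstring above):

hhblits_runner = HHBlits(
    binary_path='/usr/bin/hhblits',  # hypothetical path
    databases=['/data/bfd/bfd', '/data/uniclust30/uniclust30'])  # db prefixes
a3m_string = hhblits_runner.query('/tmp/query.fasta')[0]['a3m']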
111 modelscope/models/science/unifold/msa/tools/hhsearch.py Normal file
@@ -0,0 +1,111 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""

import glob
import os
import subprocess
from typing import Sequence

from absl import logging

from modelscope.models.science.unifold.msa import parsers
from . import utils


class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 databases: Sequence[str],
                 maxseq: int = 1_000_000):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be the
                common prefix for the database files (i.e. up to but not including
                _hhm.ffindex etc.)
            maxseq: The maximum number of rows in an input alignment. Note that this
                parameter is only supported in HHBlits version 3.1 and higher.

        Raises:
            RuntimeError: If HHsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.maxseq = maxseq

        for database_path in self.databases:
            if not glob.glob(database_path + '_*'):
                logging.error('Could not find HHsearch database %s',
                              database_path)
                raise ValueError(
                    f'Could not find HHsearch database {database_path}')

    @property
    def output_format(self) -> str:
        return 'hhr'

    @property
    def input_format(self) -> str:
        return 'a3m'

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m."""
        with utils.tmpdir_manager() as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, 'query.a3m')
            hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
            with open(input_path, 'w') as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append('-d')
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                '-i',
                input_path,
                '-o',
                hhr_path,
                '-maxseq',
                str(self.maxseq),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing('HHsearch query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' %
                    (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))

            with open(hhr_path) as f:
                hhr = f.read()
        return hhr

    def get_template_hits(
            self, output_string: str,
            input_sequence: str) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        del input_sequence  # Used by hmmsearch but not needed for hhsearch.
        return parsers.parse_hhr(output_string)
143 modelscope/models/science/unifold/msa/tools/hmmbuild.py Normal file
@@ -0,0 +1,143 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""

import os
import re
import subprocess

from absl import logging

from . import utils


class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self, *, binary_path: str, singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
                just use a common substitution score matrix.

        Raises:
            RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self,
                               sto: str,
                               model_construction='fast') -> str:
        """Builds an HMM for the aligned sequences given as a Stockholm string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa to
                determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds an HMM for the aligned sequences given as an A3M string.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self,
                       msa: str,
                       model_construction: str = 'fast') -> str:
        """Builds an HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa to
                determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If an invalid model_construction is provided.
        """
        if model_construction not in {'hand', 'fast'}:
            raise ValueError(
                f'Invalid model_construction {model_construction} - only '
                'hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:

            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    'hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                    stdout.decode('utf-8'),
                    stderr.decode('utf-8'),
                )

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' %
                    (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

            return hmm
146 modelscope/models/science/unifold/msa/tools/hmmsearch.py Normal file
@@ -0,0 +1,146 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""

import os
import subprocess
from typing import Optional, Sequence

from absl import logging

from modelscope.models.science.unifold.msa import parsers
from . import hmmbuild, utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
                an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(
            binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                '--F1',
                '0.1',
                '--F2',
                '0.1',
                '--F3',
                '0.1',
                '--incE',
                '100',
                '-E',
                '100',
                '--domE',
                '100',
                '--incdomE',
                '100',
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s',
                          database_path)
            raise ValueError(
                f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def query(self, msa_sto: str) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand')
        return self.query_with_hmm(hmm)

    def query_with_hmm(self, hmm: str) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            out_path = os.path.join(query_tmp_dir, 'output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu',
                '8',
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A',
                out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' %
                    (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    def get_template_hits(
            self, output_string: str,
            input_sequence: str) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        a3m_string = parsers.convert_stockholm_to_a3m(
            output_string, remove_first_row_gaps=False)
        template_hits = parsers.parse_hmmsearch_a3m(
            query_sequence=input_sequence,
            a3m_string=a3m_string,
            skip_first=False)
        return template_hits
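The sto-in/sto-out flow above pairs with the parsers added earlier in this diff; a sketch assuming uniref90_msa_sto and query_sequence are already in hand (paths hypothetical):

searcher = Hmmsearch(
    binary_path='/usr/bin/hmmsearch',          # hypothetical
    hmmbuild_binary_path='/usr/bin/hmmbuild',  # hypothetical
    database_path='/data/pdb_seqres.txt')      # hypothetical
pdb_hits_sto = searcher.query(uniref90_msa_sto)
template_hits = searcher.get_template_hits(pdb_hits_sto, query_sequence)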
224 modelscope/models/science/unifold/msa/tools/jackhmmer.py Normal file
@@ -0,0 +1,224 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run Jackhmmer from Python."""

import glob
import os
import subprocess
from concurrent import futures
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from absl import logging

from . import utils


class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
            binary_path: The path to the jackhmmer executable.
            database_path: The path to the jackhmmer database (FASTA format).
            n_cpu: The number of CPUs to give Jackhmmer.
            n_iter: The number of Jackhmmer iterations.
            e_value: The E-value, see Jackhmmer docs for more details.
            z_value: The Z-value, see Jackhmmer docs for more details.
            get_tblout: Whether to save tblout string.
            filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
            filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
            filter_f3: Forward pre-filter, set to >1.0 to turn off.
            incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
                round.
            dom_e: Domain e-value criteria for inclusion in tblout.
            num_streamed_chunks: Number of database chunks to stream over.
            streaming_callback: Callback function run after each chunk iteration with
                the iteration number as argument.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        if not os.path.exists(
                self.database_path) and num_streamed_chunks is None:
            logging.error('Could not find Jackhmmer database %s',
                          database_path)
            raise ValueError(
                f'Could not find Jackhmmer database {database_path}')

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(self, input_fasta_path: str,
                     database_path: str) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer."""
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, 'output.sto')

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expense of sensitivity. They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                '-o',
                '/dev/null',
                '-A',
                sto_path,
                '--noali',
                '--F1',
                str(self.filter_f1),
                '--F2',
                str(self.filter_f2),
                '--F3',
                str(self.filter_f3),
                '--incE',
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                '-E',
                str(self.e_value),
                '--cpu',
                str(self.n_cpu),
                '-N',
                str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
                cmd_flags.extend(['--tblout', tblout_path])

            if self.z_value:
                cmd_flags.extend(['-Z', str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(['--domE', str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(['--incdomE', str(self.incdom_e)])

            cmd = [self.binary_path
                   ] + cmd_flags + [input_fasta_path, database_path]

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'Jackhmmer ({os.path.basename(database_path)}) query'):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n'
                                   % stderr.decode('utf-8'))

            # Get e-values for each target name
            tbl = ''
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            with open(sto_path) as f:
                sto = f.read()

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value)

            return raw_output

    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
        """Queries the database using Jackhmmer."""
        if self.num_streamed_chunks is None:
            return [self._query_chunk(input_fasta_path, self.database_path)]

        db_basename = os.path.basename(self.database_path)

        def db_remote_chunk(db_idx):
            return f'{self.database_path}.{db_idx}'

        def db_local_chunk(db_idx):
            return f'/tmp/ramdisk/{db_basename}.{db_idx}'

        # db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
        # db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk('[0-9]*')):
            try:
                os.remove(f)
            except OSError:
                print(f'OSError while deleting {f}')

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_output = []
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(request.urlretrieve,
                                             db_remote_chunk(i),
                                             db_local_chunk(i))
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                future.result()
                chunked_output.append(
                    self._query_chunk(input_fasta_path, db_local_chunk(i)))

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works even for
                # databases with only 1 chunk.
                if i < self.num_streamed_chunks:
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_output
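A sketch of a single-chunk query with tblout capture (paths hypothetical; parse_e_values_from_tblout lives in the parsers module added earlier in this diff, so the import is assumed):

runner = Jackhmmer(
    binary_path='/usr/bin/jackhmmer',       # hypothetical
    database_path='/data/mgy_clusters.fa',  # hypothetical
    get_tblout=True)
result = runner.query('/tmp/query.fasta')[0]
e_values = parsers.parse_e_values_from_tblout(result['tbl'])  # import assumed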
110 modelscope/models/science/unifold/msa/tools/kalign.py Normal file
@@ -0,0 +1,110 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence

from absl import logging

from . import utils


def _to_a3m(sequences: Sequence[str]) -> str:
    """Converts sequences to an a3m file."""
    names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
    a3m = []
    for sequence, name in zip(sequences, names):
        a3m.append('>' + name + '\n')
        a3m.append(sequence + '\n')
    return ''.join(a3m)


class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary.

        Raises:
            RuntimeError: If Kalign binary not found within the path.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment as an A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have to be at
                least 6 residues long (Kalign requires this). Note that the order in
                which you give the sequences might alter the output slightly as
                different alignment tree might get constructed.

        Returns:
            A string with the alignment in a3m format.

        Raises:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info('Aligning %d sequences', len(sequences))

        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    'Kalign requires all sequences to be at least 6 '
                    'residues long. Got %s (%d residues).' % (s, len(s)))

        with utils.tmpdir_manager() as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
            output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

            with open(input_fasta_path, 'w') as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                '-i',
                input_fasta_path,
                '-o',
                output_a3m_path,
                '-format',
                'fasta',
            ]

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('Kalign query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    'Kalign stdout:\n%s\n\nstderr:\n%s\n',
                    stdout.decode('utf-8'),
                    stderr.decode('utf-8'),
                )

            if retcode:
                raise RuntimeError(
                    'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' %
                    (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_a3m_path) as f:
                a3m = f.read()

            return a3m
|
||||
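A minimal usage sketch of the wrapper above (hypothetical binary path and toy sequences, not part of the diff; it only assumes a local Kalign install):

# Sketch: align two toy sequences (each >= 6 residues, as required above).
kalign_runner = Kalign(binary_path='/usr/bin/kalign')
a3m = kalign_runner.align(['MKTAYIAKQR', 'MKTAYIAKQL'])
print(a3m)  # alignment in a3m/FASTA format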
modelscope/models/science/unifold/msa/tools/utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional

from absl import logging


@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit."""
    tmpdir = tempfile.mkdtemp(dir=base_dir)
    try:
        yield tmpdir
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
    logging.info('Started %s', msg)
    tic = time.time()
    yield
    toc = time.time()
    logging.info('Finished %s in %.3f seconds', msg, toc - tic)
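A quick sketch of how these two helpers compose (illustrative call, not from the commit):

# Sketch: the temporary directory lives only inside the block and is
# removed on exit, even if the timed body raises.
import os
from modelscope.models.science.unifold.msa.tools import utils

with utils.tmpdir_manager(base_dir='/tmp') as tmp_dir:
    with utils.timing('toy workload'):
        with open(os.path.join(tmp_dir, 'scratch.txt'), 'w') as f:
            f.write('data')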
modelscope/models/science/unifold/msa/utils.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import os
from typing import Mapping, Sequence

import json
from absl import logging

from modelscope.models.science.unifold.data import protein


def get_chain_id_map(
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Makes a mapping from PDB-format chain ID to sequence and description,
    and parses the order of multiple chains.
    """
    unique_seqs = []
    for seq in sequences:
        if seq not in unique_seqs:
            unique_seqs.append(seq)

    chain_id_map = {
        chain_id: {
            'descriptions': [],
            'sequence': seq
        }
        for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs)
    }
    chain_order = []

    for seq, des in zip(sequences, descriptions):
        chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)]
        chain_id_map[chain_id]['descriptions'].append(des)
        chain_order.append(chain_id)

    return chain_id_map, chain_order


def divide_multi_chains(
    fasta_name: str,
    output_dir_base: str,
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Divides a multi-chain fasta into several single-chain fasta files and
    records the multi-chain mapping information.
    """
    if len(sequences) != len(descriptions):
        raise ValueError('sequences and descriptions must have equal length. '
                         f'Got {len(sequences)} != {len(descriptions)}.')
    if len(sequences) > protein.PDB_MAX_CHAINS:
        raise ValueError(
            'Cannot process more chains than the PDB format supports. '
            f'Got {len(sequences)} chains.')

    chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)

    output_dir = os.path.join(output_dir_base, fasta_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json')
    with open(chain_id_map_path, 'w') as f:
        json.dump(chain_id_map, f, indent=4, sort_keys=True)

    chain_order_path = os.path.join(output_dir, 'chains.txt')
    with open(chain_order_path, 'w') as f:
        f.write(' '.join(chain_order))

    logging.info('Mapping multi-chain fasta with chain order: %s',
                 ' '.join(chain_order))

    temp_names = []
    temp_paths = []
    for chain_id in chain_id_map.keys():
        temp_name = fasta_name + '_{}'.format(chain_id)
        temp_path = os.path.join(output_dir, temp_name + '.fasta')
        des = 'chain_{}'.format(chain_id)
        seq = chain_id_map[chain_id]['sequence']
        with open(temp_path, 'w') as f:
            f.write('>' + des + '\n' + seq)
        temp_names.append(temp_name)
        temp_paths.append(temp_path)
    return temp_names, temp_paths
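A sketch of the expected behaviour (hypothetical inputs; chain IDs come from protein.PDB_CHAIN_IDS, so the two unique sequences map to 'A' and 'B' and the repeated one reuses 'A'):

# Sketch: three chains, two unique sequences.
names, paths = divide_multi_chains(
    fasta_name='demo',
    output_dir_base='./unifold-predictions',
    sequences=['MKTAYIAKQR', 'GSHMKTAYIA', 'MKTAYIAKQR'],
    descriptions=['chain 1', 'chain 2', 'chain 1 copy'])
# Writes ./unifold-predictions/demo/chain_id_map.json, chains.txt ('A B A'),
# and per-chain files demo_A.fasta and demo_B.fasta.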
@@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset):
                                                  ['title', 'text'])
        self.qid_field = self.dataset_config.get('qid_field', 'query_id')
        if mode == ModeKeys.TRAIN:
            train_config = kwargs.get('train', {})
            self.neg_samples = train_config.get('neg_samples', 4)
            self.neg_samples = self.dataset_config.get('neg_sample', 4)

        super().__init__(datasets, mode, preprocessor, **kwargs)

@@ -762,12 +762,13 @@ TASK_OUTPUTS = {
    # }
    Tasks.hand_static: [OutputKeys.OUTPUT],

    # 'output': [
    #     [2, 75, 287, 240, 510, 0.8335018754005432],
    #     [1, 127, 83, 332, 366, 0.9175254702568054],
    #     [0, 0, 0, 367, 639, 0.9693422317504883]]
    # { 'labels': [2, 1, 0],
    #   'boxes': [[78, 282, 240, 504], [127, 87, 332, 370], [0, 0, 367, 639]],
    #   'scores': [0.8202137351036072, 0.8987470269203186, 0.9679114818572998]
    # }
    Tasks.face_human_hand_detection: [OutputKeys.OUTPUT],
    Tasks.face_human_hand_detection: [
        OutputKeys.LABELS, OutputKeys.BOXES, OutputKeys.SCORES
    ],

    # {
    #     {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}

@@ -1,11 +1,14 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_emotion import emotion_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@@ -28,10 +31,11 @@ class FaceEmotionPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        img = LoadImage.convert_to_ndarray(input['img_path'])
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result, bbox = emotion_infer.inference(input['img_path'], self.model,
        result, bbox = emotion_infer.inference(input, self.model,
                                               self.face_model)
        return {OutputKeys.OUTPUT: result, OutputKeys.BOXES: bbox}

@@ -2,11 +2,14 @@

from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_human_hand_detection import det_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -29,14 +32,19 @@ class NanoDettForFaceHumanHandDetectionPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        img = LoadImage.convert_to_ndarray(input['input_path'])
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        result = det_infer.inference(self.model, self.device,
                                     input['input_path'])
        logger.info(result)
        return {OutputKeys.OUTPUT: result}
        cls_list, bbox_list, score_list = det_infer.inference(
            self.model, self.device, input)
        logger.info('%s %s %s', cls_list, bbox_list, score_list)
        return {
            OutputKeys.LABELS: cls_list,
            OutputKeys.BOXES: bbox_list,
            OutputKeys.SCORES: score_list
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

@@ -1,11 +1,14 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.hand_static import hand_model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -27,10 +30,11 @@ class HandStaticPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        img = LoadImage.convert_to_ndarray(input['img_path'])
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = hand_model.infer(input['img_path'], self.model, self.device)
        result = hand_model.infer(input, self.model, self.device)
        return {OutputKeys.OUTPUT: result}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

@@ -2,11 +2,14 @@

from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.product_segmentation import seg_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -28,12 +31,13 @@ class F3NetForProductSegmentationPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input
        img = LoadImage.convert_to_ndarray(input['input_path'])
        img = img.astype(np.float32)
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        mask = seg_infer.inference(self.model, self.device,
                                   input['input_path'])
        mask = seg_infer.inference(self.model, self.device, input)
        return {OutputKeys.MASKS: mask}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

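All four pipeline diffs above share one refactor: image decoding moves out of forward and into preprocess via LoadImage.convert_to_ndarray, so forward now receives a decoded ndarray rather than a path. A hedged sketch of the resulting call shape (the model id is a placeholder; the input key mirrors the code above):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

hand_pipeline = pipeline(task=Tasks.hand_static, model='<hand-static-model-id>')
result = hand_pipeline({'img_path': 'demo.jpg'})  # image decoded in preprocess()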
modelscope/pipelines/science/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .protein_structure_pipeline import ProteinStructurePipeline

else:
    _import_structure = {
        'protein_structure_pipeline': ['ProteinStructurePipeline']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
modelscope/pipelines/science/protein_structure_pipeline.py (new file, 215 lines)
@@ -0,0 +1,215 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import time
from typing import Any, Dict, List, Optional, Union

import json
import numpy as np
import torch
from unicore.utils import tensor_tree_map

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.science.unifold.config import model_config
from modelscope.models.science.unifold.data import protein, residue_constants
from modelscope.models.science.unifold.dataset import (UnifoldDataset,
                                                       load_and_process)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.constant import Fields, Frameworks, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.hub import read_config
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['ProteinStructurePipeline']


def automatic_chunk_size(seq_len):
    if seq_len < 512:
        chunk_size = 256
    elif seq_len < 1024:
        chunk_size = 128
    elif seq_len < 2048:
        chunk_size = 32
    elif seq_len < 3072:
        chunk_size = 16
    else:
        chunk_size = 1
    return chunk_size


def load_feature_for_one_target(
    config,
    data_folder,
    seed=0,
    is_multimer=False,
    use_uniprot=False,
    symmetry_group=None,
):
    if not is_multimer:
        uniprot_msa_dir = None
        sequence_ids = ['A']
        if use_uniprot:
            uniprot_msa_dir = data_folder

    else:
        uniprot_msa_dir = data_folder
        sequence_ids = open(os.path.join(data_folder,
                                         'chains.txt')).readline().split()

    if symmetry_group is None:
        batch, _ = load_and_process(
            config=config.data,
            mode='predict',
            seed=seed,
            batch_idx=None,
            data_idx=0,
            is_distillation=False,
            sequence_ids=sequence_ids,
            monomer_feature_dir=data_folder,
            uniprot_msa_dir=uniprot_msa_dir,
        )
    else:
        raise NotImplementedError
    batch = UnifoldDataset.collater([batch])
    return batch


@PIPELINES.register_module(
    Tasks.protein_structure, module_name=Pipelines.protein_structure)
class ProteinStructurePipeline(Pipeline):

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a protein structure pipeline for prediction.

        Args:
            model (str or Model): Supply either a local model dir which supports the protein structure task,
                or a model id from the model hub, or a torch model instance.
            preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor
                fits the model if supplied.

        Example:
            >>> from modelscope.pipelines import pipeline
            >>> pipeline_ins = pipeline(task='protein-structure',
            >>>                         model='DPTech/uni-fold-monomer')
            >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC'
            >>> print(pipeline_ins(protein))
        """
        import copy
        model_path = copy.deepcopy(model) if isinstance(model, str) else None
        cfg = read_config(model_path)  # only when model is a str
        self.cfg = cfg
        self.config = model_config(
            cfg['pipeline']['model_name'])  # alphafold config
        model = model if isinstance(
            model, Model) else Model.from_pretrained(model_path)
        self.postprocessor = cfg.pop('postprocessor', None)
        if preprocessor is None:
            preprocessor_cfg = cfg.preprocessor
            preprocessor = build_preprocessor(preprocessor_cfg, Fields.science)
        model.eval()
        model.model.inference_mode()
        model.model_dir = model_path
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        return pipeline_parameters, pipeline_parameters, pipeline_parameters

    def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]:
        preprocess_params = kwargs.get('preprocess_params', {})
        forward_params = kwargs.get('forward_params', {})
        postprocess_params = kwargs.get('postprocess_params', {})
        out = self.preprocess(input, **preprocess_params)
        with device_placement(self.framework, self.device_name):
            with torch.no_grad():
                out = self.forward(out, **forward_params)

        out = self.postprocess(out, **postprocess_params)
        return out

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        plddts = {}
        ptms = {}

        output_dir = os.path.join(self.preprocessor.output_dir_base,
                                  inputs['target_id'])

        pdbs = []
        for seed in range(self.cfg['pipeline']['times']):
            cur_seed = hash((42, seed)) % 100000
            batch = load_feature_for_one_target(
                self.config,
                output_dir,
                cur_seed,
                is_multimer=inputs['is_multimer'],
                use_uniprot=inputs['is_multimer'],
                symmetry_group=self.preprocessor.symmetry_group,
            )
            seq_len = batch['aatype'].shape[-1]
            self.model.model.globals.chunk_size = automatic_chunk_size(seq_len)

            with torch.no_grad():
                batch = {
                    k: torch.as_tensor(v, device='cuda:0')
                    for k, v in batch.items()
                }
                out = self.model(batch)

            def to_float(x):
                if x.dtype == torch.bfloat16 or x.dtype == torch.half:
                    return x.float()
                else:
                    return x

            # Toss out the recycling dimensions --- we don't need them anymore
            batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
            batch = tensor_tree_map(to_float, batch)
            out = tensor_tree_map(lambda t: t[0, ...], out[0])
            out = tensor_tree_map(to_float, out)
            batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            plddt = out['plddt']
            mean_plddt = np.mean(plddt)
            plddt_b_factors = np.repeat(
                plddt[..., None], residue_constants.atom_type_num, axis=-1)
            # TODO: may need to reorder chains, based on entity_ids
            cur_protein = protein.from_prediction(
                features=batch, result=out, b_factors=plddt_b_factors)
            cur_save_name = f'{cur_seed}'
            plddts[cur_save_name] = str(mean_plddt)
            if inputs[
                    'is_multimer'] and self.preprocessor.symmetry_group is None:
                ptms[cur_save_name] = str(np.mean(out['iptm+ptm']))
            with open(os.path.join(output_dir, cur_save_name + '.pdb'),
                      'w') as f:
                f.write(protein.to_pdb(cur_protein))
            pdbs.append(protein.to_pdb(cur_protein))

        logger.info('plddts:' + str(plddts))
        model_name = self.cfg['pipeline']['model_name']
        score_name = f'{model_name}'
        plddt_fname = score_name + '_plddt.json'

        with open(os.path.join(output_dir, plddt_fname), 'w') as f:
            json.dump(plddts, f, indent=4)
        if ptms:
            logger.info('ptms' + str(ptms))
            ptm_fname = score_name + '_ptm.json'
            with open(os.path.join(output_dir, ptm_fname), 'w') as f:
                json.dump(ptms, f, indent=4)

        return pdbs

    def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params):
        return inputs
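A usage sketch for this pipeline (model id and sequence taken from the docstring and tests; per the forward method above, the result is a list of PDB strings, one per sampled seed):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipeline_ins = pipeline(task=Tasks.protein_structure,
                        model='DPTech/uni-fold-monomer')
pdbs = pipeline_ins('LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC')
with open('prediction_0.pdb', 'w') as f:
    f.write(pdbs[0])  # persist the first sampled structure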
modelscope/preprocessors/science/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .unifold import UniFoldPreprocessor

else:
    _import_structure = {'unifold': ['UniFoldPreprocessor']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
modelscope/preprocessors/science/uni_fold.py (new file, 569 lines)
@@ -0,0 +1,569 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import gzip
import hashlib
import logging
import os
import pickle
import random
import re
import tarfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import json
import numpy as np
import requests
import torch
from tqdm import tqdm

from modelscope.metainfo import Preprocessors
from modelscope.models.science.unifold.data import protein, residue_constants
from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS
from modelscope.models.science.unifold.data.utils import compress_features
from modelscope.models.science.unifold.msa import parsers, pipeline, templates
from modelscope.models.science.unifold.msa.tools import hhsearch
from modelscope.models.science.unifold.msa.utils import divide_multi_chains
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields

__all__ = [
    'UniFoldPreprocessor',
]

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
DEFAULT_API_SERVER = 'https://api.colabfold.com'


def run_mmseqs2(
        x,
        prefix,
        use_env=True,
        use_templates=False,
        use_pairing=False,
        host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]:
    submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa'

    def submit(seqs, mode, N=101):
        n, query = N, ''
        for seq in seqs:
            query += f'>{n}\n{seq}\n'
            n += 1

        res = requests.post(
            f'{host_url}/{submission_endpoint}',
            data={
                'q': query,
                'mode': mode
            })
        try:
            out = res.json()
        except ValueError:
            out = {'status': 'ERROR'}
        return out

    def status(ID):
        res = requests.get(f'{host_url}/ticket/{ID}')
        try:
            out = res.json()
        except ValueError:
            out = {'status': 'ERROR'}
        return out

    def download(ID, path):
        res = requests.get(f'{host_url}/result/download/{ID}')
        with open(path, 'wb') as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    mode = 'env'
    if use_pairing:
        mode = ''
        use_templates = False
        use_env = False

    # define path
    path = f'{prefix}'
    if not os.path.isdir(path):
        os.mkdir(path)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out_{mode}.tar.gz'
    N, REDO = 101, True

    # deduplicate and keep track of order
    seqs_unique = []
    # TODO this might be slow for large sets
    [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
    Ms = [N + seqs_unique.index(seq) for seq in seqs]
    # let's do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description('SUBMIT')

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out['status'] in ['UNKNOWN', 'RATELIMIT']:
                    sleep_time = 5 + random.randint(0, 5)
                    # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out['status'] == 'ERROR':
                    error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. '
                    error = error + 'If the error persists, please try again an hour later.'
                    raise Exception(error)

                if out['status'] == 'MAINTENANCE':
                    raise Exception(
                        'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.'
                    )

                # wait for job to finish
                ID, TIME = out['id'], 0
                pbar.set_description(out['status'])
                while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']:
                    t = 5 + random.randint(0, 5)
                    # logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out['status'])
                    if out['status'] == 'RUNNING':
                        TIME += t
                        pbar.update(n=t)

                if out['status'] == 'COMPLETE':
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out['status'] == 'ERROR':
                    REDO = False
                    error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. '
                    error = error + 'If the error persists, please try again an hour later.'
                    raise Exception(error)

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f'{path}/pair.a3m']
    else:
        a3m_files = [f'{path}/uniref.a3m']
        if use_env:
            a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m')

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # templates
    if use_templates:
        templates = {}

        with open(f'{path}/pdb70.m8', 'r') as f:
            lines = f.readlines()
            for line in lines:
                p = line.rstrip().split()
                M, pdb, _, _ = p[0], p[1], p[2], p[10]  # qid, e_value
                M = int(M)
                if M not in templates:
                    templates[M] = []
                templates[M].append(pdb)

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f'{prefix}/templates_{k}'
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ','.join(TMPL[:20])
                os.system(
                    f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/'
                )
                os.system(
                    f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex'
                )
                os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata')
            template_paths[k] = TMPL_PATH

    # gather a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                if len(line) > 0:
                    if '\x00' in line:
                        line = line.replace('\x00', '')
                        update_M = True
                    if line.startswith('>') and update_M:
                        M = int(line[1:].rstrip())
                        update_M = False
                        if M not in a3m_lines:
                            a3m_lines[M] = []
                    a3m_lines[M].append(line)

    # return results

    a3m_lines = [''.join(a3m_lines[n]) for n in Ms]

    if use_templates:
        template_paths_ = []
        for n in Ms:
            if n not in template_paths:
                template_paths_.append(None)
                # print(f"{n-N}\tno_templates_found")
            else:
                template_paths_.append(template_paths[n])
        template_paths = template_paths_

    return (a3m_lines, template_paths) if use_templates else a3m_lines


def get_null_template(query_sequence: Union[List[str], str],
                      num_temp: int = 1) -> Dict[str, Any]:
    ln = (
        len(query_sequence) if isinstance(query_sequence, str) else sum(
            len(s) for s in query_sequence))
    output_templates_sequence = 'A' * ln
    # output_confidence_scores = np.full(ln, 1.0)

    templates_all_atom_positions = np.zeros(
        (ln, templates.residue_constants.atom_type_num, 3))
    templates_all_atom_masks = np.zeros(
        (ln, templates.residue_constants.atom_type_num))
    templates_aatype = templates.residue_constants.sequence_to_onehot(
        output_templates_sequence,
        templates.residue_constants.HHBLITS_AA_TO_ID)
    template_features = {
        'template_all_atom_positions':
        np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]),
        'template_all_atom_masks':
        np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]),
        'template_sequence': ['none'.encode()] * num_temp,
        'template_aatype':
        np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),
        'template_domain_names': ['none'.encode()] * num_temp,
        'template_sum_probs':
        np.zeros([num_temp], dtype=np.float32),
    }
    return template_features


def get_template(a3m_lines: str, template_path: str,
                 query_sequence: str) -> Dict[str, Any]:
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=template_path,
        max_template_date='2100-01-01',
        max_hits=20,
        kalign_binary_path='kalign',
        release_dates_path=None,
        obsolete_pdbs_path=None,
    )

    hhsearch_pdb70_runner = hhsearch.HHSearch(
        binary_path='hhsearch', databases=[f'{template_path}/pdb70'])

    hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)
    hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)
    templates_result = template_featurizer.get_templates(
        query_sequence=query_sequence, hits=hhsearch_hits)
    return dict(templates_result.features)


@PREPROCESSORS.register_module(
    Fields.science, module_name=Preprocessors.unifold_preprocessor)
class UniFoldPreprocessor(Preprocessor):

    def __init__(self, **cfg):
        # use .get() so the key may be omitted, e.g. when constructed directly
        self.symmetry_group = cfg.get('symmetry_group')  # e.g. 'C1'
        if not self.symmetry_group:
            self.symmetry_group = None
        self.MIN_SINGLE_SEQUENCE_LENGTH = 16  # TODO: change to cfg
        self.MAX_SINGLE_SEQUENCE_LENGTH = 1000
        self.MAX_MULTIMER_LENGTH = 1000
        self.jobname = 'unifold'
        self.output_dir_base = './unifold-predictions'
        os.makedirs(self.output_dir_base, exist_ok=True)

    def clean_and_validate_sequence(self, input_sequence: str,
                                    min_length: int, max_length: int) -> str:
        clean_sequence = input_sequence.translate(
            str.maketrans('', '', ' \n\t')).upper()
        aatypes = set(residue_constants.restypes)  # 20 standard aatypes.
        if not set(clean_sequence).issubset(aatypes):
            raise ValueError(
                f'Input sequence contains non-amino acid letters: '
                f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
                'amino acids as inputs.')
        if len(clean_sequence) < min_length:
            raise ValueError(
                f'Input sequence is too short: {len(clean_sequence)} amino acids, '
                f'while the minimum is {min_length}')
        if len(clean_sequence) > max_length:
            raise ValueError(
                f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
                f'the maximum is {max_length}. You may be able to run it with the full '
                f'Uni-Fold system depending on your resources (system memory, '
                f'GPU memory).')
        return clean_sequence

    def validate_input(
            self, input_sequences: Sequence[str], symmetry_group: str,
            min_length: int, max_length: int,
            max_multimer_length: int) -> Tuple[Sequence[str], bool, Optional[str]]:
        """Validates and cleans input sequences and determines which model to use."""
        sequences = []

        for input_sequence in input_sequences:
            if input_sequence.strip():
                input_sequence = self.clean_and_validate_sequence(
                    input_sequence=input_sequence,
                    min_length=min_length,
                    max_length=max_length)
                sequences.append(input_sequence)

        if symmetry_group is not None and symmetry_group != 'C1':
            if symmetry_group.startswith(
                    'C') and symmetry_group[1:].isnumeric():
                print(
                    f'Using UF-Symmetry with group {symmetry_group}. If you do not '
                    f'want to use UF-Symmetry, please use `C1` and copy the AU '
                    f'sequences to the count in the assembly.')
                is_multimer = (len(sequences) > 1)
                return sequences, is_multimer, symmetry_group
            else:
                raise ValueError(
                    f'UF-Symmetry does not support symmetry group '
                    f'{symmetry_group} currently. Cyclic groups (Cx) are '
                    f'supported only.')

        elif len(sequences) == 1:
            print('Using the single-chain model.')
            return sequences, False, None

        elif len(sequences) > 1:
            total_multimer_length = sum([len(seq) for seq in sequences])
            if total_multimer_length > max_multimer_length:
                raise ValueError(
                    f'The total length of multimer sequences is too long: '
                    f'{total_multimer_length}, while the maximum is '
                    f'{max_multimer_length}. Please use the full AlphaFold '
                    f'system for long multimers.')
            print(f'Using the multimer model with {len(sequences)} sequences.')
            return sequences, True, None

        else:
            raise ValueError(
                'No input amino acid sequence provided, please provide at '
                'least one sequence.')

    def add_hash(self, x, y):
        return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5]

    def get_msa_and_templates(
        self,
        jobname: str,
        query_seqs_unique: Union[str, List[str]],
        result_dir: Path,
        msa_mode: str,
        use_templates: bool,
        homooligomers_num: int = 1,
        host_url: str = DEFAULT_API_SERVER,
    ) -> Tuple[Optional[List[str]], Optional[List[str]], List[Dict[str, Any]]]:

        use_env = msa_mode == 'MMseqs2'

        template_features = []
        if use_templates:
            a3m_lines_mmseqs2, template_paths = run_mmseqs2(
                query_seqs_unique,
                str(result_dir.joinpath(jobname)),
                use_env,
                use_templates=True,
                host_url=host_url,
            )
            if template_paths is None:
                for index in range(0, len(query_seqs_unique)):
                    template_feature = get_null_template(
                        query_seqs_unique[index])
                    template_features.append(template_feature)
            else:
                for index in range(0, len(query_seqs_unique)):
                    if template_paths[index] is not None:
                        template_feature = get_template(
                            a3m_lines_mmseqs2[index],
                            template_paths[index],
                            query_seqs_unique[index],
                        )
                        if len(template_feature['template_domain_names']) == 0:
                            template_feature = get_null_template(
                                query_seqs_unique[index])
                    else:
                        template_feature = get_null_template(
                            query_seqs_unique[index])
                    template_features.append(template_feature)
        else:
            for index in range(0, len(query_seqs_unique)):
                template_feature = get_null_template(query_seqs_unique[index])
                template_features.append(template_feature)

        if msa_mode == 'single_sequence':
            a3m_lines = []
            num = 101
            for i, seq in enumerate(query_seqs_unique):
                a3m_lines.append('>' + str(num + i) + '\n' + seq)
        else:
            # find normal a3ms
            a3m_lines = run_mmseqs2(
                query_seqs_unique,
                str(result_dir.joinpath(jobname)),
                use_env,
                use_pairing=False,
                host_url=host_url,
            )

        if len(query_seqs_unique) > 1:
            # find paired a3m if not a homooligomer
            paired_a3m_lines = run_mmseqs2(
                query_seqs_unique,
                str(result_dir.joinpath(jobname)),
                use_env,
                use_pairing=True,
                host_url=host_url,
            )
        else:
            num = 101
            paired_a3m_lines = []
            for i in range(0, homooligomers_num):
                paired_a3m_lines.append('>' + str(num + i) + '\n'
                                        + query_seqs_unique[0] + '\n')

        return (
            a3m_lines,
            paired_a3m_lines,
            template_features,
        )

    def __call__(self, data: Union[str, Tuple]):
        if isinstance(data, str):
            data = [data, '', '', '']
        basejobname = ''.join(data)
        basejobname = re.sub(r'\W+', '', basejobname)
        target_id = self.add_hash(self.jobname, basejobname)

        sequences, is_multimer, _ = self.validate_input(
            input_sequences=data,
            symmetry_group=self.symmetry_group,
            min_length=self.MIN_SINGLE_SEQUENCE_LENGTH,
            max_length=self.MAX_SINGLE_SEQUENCE_LENGTH,
            max_multimer_length=self.MAX_MULTIMER_LENGTH)

        descriptions = [
            '> ' + target_id + ' seq' + str(ii)
            for ii in range(len(sequences))
        ]

        if is_multimer:
            divide_multi_chains(target_id, self.output_dir_base, sequences,
                                descriptions)

        s = []
        for des, seq in zip(descriptions, sequences):
            s += [des, seq]

        unique_sequences = []
        [
            unique_sequences.append(x) for x in sequences
            if x not in unique_sequences
        ]

        if len(unique_sequences) == 1:
            homooligomers_num = len(sequences)
        else:
            homooligomers_num = 1

        with open(f'{self.jobname}.fasta', 'w') as f:
            f.write('\n'.join(s))

        result_dir = Path(self.output_dir_base)
        output_dir = os.path.join(self.output_dir_base, target_id)

        # msa_mode = 'single_sequence'
        msa_mode = 'MMseqs2'
        use_templates = True

        unpaired_msa, paired_msa, template_results = self.get_msa_and_templates(
            target_id,
            unique_sequences,
            result_dir=result_dir,
            msa_mode=msa_mode,
            use_templates=use_templates,
            homooligomers_num=homooligomers_num)

        features = []
        pair_features = []

        for idx, seq in enumerate(unique_sequences):
            chain_id = PDB_CHAIN_IDS[idx]
            sequence_features = pipeline.make_sequence_features(
                sequence=seq,
                description=f'> {self.jobname} seq {chain_id}',
                num_res=len(seq))
            monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
            msa_features = pipeline.make_msa_features([monomer_msa])
            template_features = template_results[idx]
            feature_dict = {
                **sequence_features,
                **msa_features,
                **template_features
            }
            feature_dict = compress_features(feature_dict)
            features_output_path = os.path.join(
                output_dir, '{}.feature.pkl.gz'.format(chain_id))
            pickle.dump(
                feature_dict,
                gzip.GzipFile(features_output_path, 'wb'),
                protocol=4)
            features.append(feature_dict)

            if is_multimer:
                multimer_msa = parsers.parse_a3m(paired_msa[idx])
                # use a distinct local name so the `pair_features` list
                # defined above is not shadowed before the append below
                pair_feature = pipeline.make_msa_features([multimer_msa])
                pair_feature_dict = compress_features(pair_feature)
                uniprot_output_path = os.path.join(
                    output_dir, '{}.uniprot.pkl.gz'.format(chain_id))
                pickle.dump(
                    pair_feature_dict,
                    gzip.GzipFile(uniprot_output_path, 'wb'),
                    protocol=4,
                )
                pair_features.append(pair_feature_dict)

        # return features, pair_features, target_id
        return {
            'features': features,
            'pair_features': pair_features,
            'target_id': target_id,
            'is_multimer': is_multimer,
        }

if __name__ == '__main__':
    proc = UniFoldPreprocessor(symmetry_group=None)
    protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \
        'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'
    result = proc(protein_example)  # returns a dict of features (see __call__)
    import ipdb
    ipdb.set_trace()
@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
            self.rng_state = meta.get('rng_state')
            self.need_load_rng_state = True

    def before_train_epoch(self, trainer):
    def before_train_iter(self, trainer):
        if self.need_load_rng_state:
            if self.rng_state is not None:
                random.setstate(self.rng_state['random'])
@@ -84,13 +84,6 @@ class CheckpointHook(Hook):
                    'this may cause a random data order or model initialization.'
                )

        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }

    def after_train_epoch(self, trainer):
        if not self.by_epoch:
            return
@@ -142,6 +135,12 @@ class CheckpointHook(Hook):
        cur_save_name = os.path.join(
            self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }
        meta = {
            'epoch': trainer.epoch,
            'iter': trainer.iter + 1,
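The hunks above move the RNG snapshot from epoch start to checkpoint-save time, so a resumed run restores exactly the state that was persisted. A standalone sketch of the snapshot/restore pair, mirroring the dict layout used above:

import random
import numpy as np
import torch

def snapshot_rng():
    # Same layout as CheckpointHook.rng_state above.
    return {
        'random': random.getstate(),
        'numpy': np.random.get_state(),
        'cpu': torch.random.get_rng_state(),
        'cuda': torch.cuda.get_rng_state_all(),
    }

def restore_rng(state):
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    torch.random.set_rng_state(state['cpu'])
    torch.cuda.set_rng_state_all(state['cuda'])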
@@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer):
            task_dataset.trainer = self
            return task_dataset
        else:
            if task_data_config is None:
                # adapt to some special models
                task_data_config = {}
            # avoid adding non-string dataset/preprocessor values to the cfg
            task_data_build_config = ConfigDict(
                type=self.cfg.model.type,
@@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer):
        return metrics

    def set_checkpoint_file_to_hook(self, checkpoint_path):
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            checkpoint_hooks = list(
                filter(lambda hook: isinstance(hook, CheckpointHook),
                       self.hooks))
            for hook in checkpoint_hooks:
                hook.checkpoint_file = checkpoint_path
        if checkpoint_path is not None:
            if os.path.isfile(checkpoint_path):
                from modelscope.trainers.hooks import CheckpointHook
                checkpoint_hooks = list(
                    filter(lambda hook: isinstance(hook, CheckpointHook),
                           self.hooks))
                for hook in checkpoint_hooks:
                    hook.checkpoint_file = checkpoint_path
            else:
                self.logger.error(
                    f'No {checkpoint_path} found in local file system.')

    def train(self, checkpoint_path=None, *args, **kwargs):
        self._mode = ModeKeys.TRAIN

@@ -9,6 +9,7 @@ class Fields(object):
    nlp = 'nlp'
    audio = 'audio'
    multi_modal = 'multi-modal'
    science = 'science'


class CVTasks(object):
@@ -151,6 +152,10 @@ class MultiModalTasks(object):
    image_text_retrieval = 'image-text-retrieval'


class ScienceTasks(object):
    protein_structure = 'protein-structure'


class TasksIODescriptions(object):
    image_to_image = 'image_to_image',
    images_to_image = 'images_to_image',
@@ -167,7 +172,7 @@ class TasksIODescriptions(object):
    generative_multi_modal_embedding = 'generative_multi_modal_embedding'


class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks):
    """ Names for tasks supported by modelscope.

    Holds the standard task name to use for identifying different tasks.
@@ -196,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
        getattr(Tasks, attr) for attr in dir(MultiModalTasks)
        if not attr.startswith('__')
    ],
    Fields.science: [
        getattr(Tasks, attr) for attr in dir(ScienceTasks)
        if not attr.startswith('__')
    ],
}

for field, tasks in field_dict.items():
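With ScienceTasks mixed into Tasks and grouped under Fields.science, the new task name resolves like any other constant (a quick illustration, not part of the diff):

from modelscope.utils.constant import Tasks

assert Tasks.protein_structure == 'protein-structure'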
requirements/science.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
iopath
lmdb
ml_collections
scipy
tensorboardX
tokenizers
@@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

        # bert
        language = 'zh'
        model_dir = snapshot_download(self.model_id_bert, revision='beta')
        model_dir = snapshot_download(self.model_id_bert)
        preprocessor = NLPPreprocessor(
            model_dir, first_sequence='sentence', second_sequence=None)
        model = Model.from_pretrained(model_dir)
@@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

        # Bert
        language = 'zh'
        pipeline_ins = pipeline(
            task=Tasks.fill_mask,
            model=self.model_id_bert,
            model_revision='beta')
        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
        print(
            f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
            f'{pipeline_ins(self.test_inputs[language])}\n')

@@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_direct_file_download(self):
        cache_path = snapshot_download(self.model_id, revision='beta')
        cache_path = snapshot_download(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(cache_path)
        model = SequenceClassificationModel.from_pretrained(
            self.model_id, num_labels=2, revision='beta')
            self.model_id, num_labels=2)
        pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.text_classification, model=model, preprocessor=tokenizer)
@@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id, revision='beta')
        model = Model.from_pretrained(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.text_classification,
@@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.text_classification,
            model=self.model_id,
            model_revision='beta')
            task=Tasks.text_classification, model=self.model_id)
        print(pipeline_ins(input=self.sentence1))
        self.assertTrue(
            isinstance(pipeline_ins.model, SequenceClassificationModel))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(
            task=Tasks.text_classification, model_revision='beta')
        pipeline_ins = pipeline(task=Tasks.text_classification)
        print(pipeline_ins(input=self.sentence1))
        self.assertTrue(
            isinstance(pipeline_ins.model, SequenceClassificationModel))

tests/pipelines/test_unifold.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.protein_structure
        self.model_id = 'DPTech/uni-fold-monomer'
        self.model_id_multimer = 'DPTech/uni-fold-multimer'

        self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD'
        self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
            'NIAALKNHIDKIKPIAMQIYKKYSKNIP'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        model_dir = snapshot_download(self.model_id)
        mono_pipeline_ins = pipeline(task=self.task, model=model_dir)
        _ = mono_pipeline_ins(self.protein)

        model_dir1 = snapshot_download(self.model_id_multimer)
        multi_pipeline_ins = pipeline(task=self.task, model=model_dir1)
        _ = multi_pipeline_ins(self.protein_multimer)


if __name__ == '__main__':
    unittest.main()
@@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
    def test_finetune_msmarco(self):

        def cfg_modify_fn(cfg):
            neg_sample = 4
            cfg.task = 'text-ranking'
            cfg['preprocessor'] = {'type': 'text-ranking'}
            cfg.train.optimizer.lr = 2e-5
@@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                    'pos_sequence': 'positive_passages',
                    'neg_sequence': 'negative_passages',
                    'text_fileds': ['title', 'text'],
                    'qid_field': 'query_id'
                    'qid_field': 'query_id',
                    'neg_sample': neg_sample
                },
                'val': {
                    'type': 'bert',
@@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                    'qid_field': 'query_id'
                },
            }
            cfg['train']['neg_samples'] = 4
            cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
            cfg.train.max_epochs = 1
            cfg.train.train_batch_size = 4
@@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                    'by_epoch': False
                }
            }
            cfg.model['neg_sample'] = 4
            cfg.train.hooks = [{
                'type': 'CheckpointHook',
                'interval': 1
@@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                    'qid_field': 'query_id'
                },
            }
            cfg['train']['neg_samples'] = 4
            cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
            cfg.train.max_epochs = 1
            cfg.train.train_batch_size = 4
@@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):

        # load dataset
        ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
        train_ds = ds['train'].to_hf_dataset()
        train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
        dev_ds = ds['dev'].to_hf_dataset()

        model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
        self.finetune(
            model_id=model_id,

@@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        kwargs = dict(
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            work_dir=self.tmp_dir,
            model_revision='beta')
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase):
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            work_dir=self.tmp_dir,
            model_revision='beta')
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_user_defined_config(self):
        model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
        cfg = read_config(model_id, revision='beta')
        cfg = read_config(model_id)
        cfg.train.max_epochs = 20
        cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
        cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
@@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase):
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            cfg_file=cfg_file,
            model_revision='beta')
            cfg_file=cfg_file)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase):
        os.makedirs(tmp_dir)

        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        cache_path = snapshot_download(model_id, revision='beta')
        cache_path = snapshot_download(model_id)
        model = SbertForSequenceClassification.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),