From 93f2c523032a36dfdbe9104db44409e47bf0f4c4 Mon Sep 17 00:00:00 2001 From: "yuanzheng.yuanzhen" Date: Tue, 13 Dec 2022 14:18:39 +0800 Subject: [PATCH] fix multimer input for science/protein_structure Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10608748 --- .../science/protein_structure_pipeline.py | 14 +++++++++++++- modelscope/preprocessors/science/uni_fold.py | 10 ++++++---- tests/pipelines/test_unifold.py | 10 +++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/modelscope/pipelines/science/protein_structure_pipeline.py b/modelscope/pipelines/science/protein_structure_pipeline.py index e326f50b..f5056c9a 100644 --- a/modelscope/pipelines/science/protein_structure_pipeline.py +++ b/modelscope/pipelines/science/protein_structure_pipeline.py @@ -76,7 +76,19 @@ def load_feature_for_one_target( uniprot_msa_dir=uniprot_msa_dir, ) else: - raise NotImplementedError + # Not for unifold-symmetry + # only for unifold-multimer + batch, _ = load_and_process( + config=config.data, + mode='predict', + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + monomer_feature_dir=data_folder, + uniprot_msa_dir=uniprot_msa_dir, + ) batch = UnifoldDataset.collater([batch]) return batch diff --git a/modelscope/preprocessors/science/uni_fold.py b/modelscope/preprocessors/science/uni_fold.py index ae72433c..c8400dfb 100644 --- a/modelscope/preprocessors/science/uni_fold.py +++ b/modelscope/preprocessors/science/uni_fold.py @@ -458,7 +458,9 @@ class UniFoldPreprocessor(Preprocessor): def __call__(self, data: Union[str, Tuple]): if isinstance(data, str): - data = [data, '', '', ''] + data = data.strip().split() + if len(data) < 4: + data = data + [''] * (4 - len(data)) basejobname = ''.join(data) basejobname = re.sub(r'\W+', '', basejobname) target_id = self.add_hash(self.jobname, basejobname) @@ -513,7 +515,7 @@ class UniFoldPreprocessor(Preprocessor): homooligomers_num=homooligomers_num) features = [] - pair_features = [] + pair_features_list = [] for idx, seq in enumerate(unique_sequences): chain_id = PDB_CHAIN_IDS[idx] @@ -549,12 +551,12 @@ class UniFoldPreprocessor(Preprocessor): gzip.GzipFile(uniprot_output_path, 'wb'), protocol=4, ) - pair_features.append(pair_feature_dict) + pair_features_list.append(pair_feature_dict) # return features, pair_features, target_id return { 'features': features, - 'pair_features': pair_features, + 'pair_features': pair_features_list, 'target_id': target_id, 'is_multimer': is_multimer, } diff --git a/tests/pipelines/test_unifold.py b/tests/pipelines/test_unifold.py index 47bb7874..22e29cb2 100644 --- a/tests/pipelines/test_unifold.py +++ b/tests/pipelines/test_unifold.py @@ -17,18 +17,18 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ - 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' + 'NIAALKNHIDKIKPIAMQIYKKYSKNIP NIAALKNHIDKIKPIAMQIYKKYSKNIP' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_by_direct_model_download(self): - model_dir = snapshot_download(self.model_id) - mono_pipeline_ins = pipeline(task=self.task, model=model_dir) - _ = mono_pipeline_ins(self.protein) - model_dir1 = snapshot_download(self.model_id_multimer) multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) _ = multi_pipeline_ins(self.protein_multimer) + model_dir = snapshot_download(self.model_id) + mono_pipeline_ins = pipeline(task=self.task, model=model_dir) + _ = mono_pipeline_ins(self.protein) + if __name__ == '__main__': unittest.main()