fix multimer input for science/protein_structure

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10608748
This commit is contained in:
yuanzheng.yuanzhen
2022-12-13 14:18:39 +08:00
committed by wenmeng.zwm
parent 9172637ec8
commit 93f2c52303
3 changed files with 24 additions and 10 deletions

View File

@@ -76,7 +76,19 @@ def load_feature_for_one_target(
uniprot_msa_dir=uniprot_msa_dir,
)
else:
raise NotImplementedError
# Not for unifold-symmetry
# only for unifold-multimer
batch, _ = load_and_process(
config=config.data,
mode='predict',
seed=seed,
batch_idx=None,
data_idx=0,
is_distillation=False,
sequence_ids=sequence_ids,
monomer_feature_dir=data_folder,
uniprot_msa_dir=uniprot_msa_dir,
)
batch = UnifoldDataset.collater([batch])
return batch

View File

@@ -458,7 +458,9 @@ class UniFoldPreprocessor(Preprocessor):
def __call__(self, data: Union[str, Tuple]):
if isinstance(data, str):
data = [data, '', '', '']
data = data.strip().split()
if len(data) < 4:
data = data + [''] * (4 - len(data))
basejobname = ''.join(data)
basejobname = re.sub(r'\W+', '', basejobname)
target_id = self.add_hash(self.jobname, basejobname)
@@ -513,7 +515,7 @@ class UniFoldPreprocessor(Preprocessor):
homooligomers_num=homooligomers_num)
features = []
pair_features = []
pair_features_list = []
for idx, seq in enumerate(unique_sequences):
chain_id = PDB_CHAIN_IDS[idx]
@@ -549,12 +551,12 @@ class UniFoldPreprocessor(Preprocessor):
gzip.GzipFile(uniprot_output_path, 'wb'),
protocol=4,
)
pair_features.append(pair_feature_dict)
pair_features_list.append(pair_feature_dict)
# return features, pair_features, target_id
return {
'features': features,
'pair_features': pair_features,
'pair_features': pair_features_list,
'target_id': target_id,
'is_multimer': is_multimer,
}

View File

@@ -17,18 +17,18 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):
self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD'
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
'NIAALKNHIDKIKPIAMQIYKKYSKNIP'
'NIAALKNHIDKIKPIAMQIYKKYSKNIP NIAALKNHIDKIKPIAMQIYKKYSKNIP'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
mono_pipeline_ins = pipeline(task=self.task, model=model_dir)
_ = mono_pipeline_ins(self.protein)
model_dir1 = snapshot_download(self.model_id_multimer)
multi_pipeline_ins = pipeline(task=self.task, model=model_dir1)
_ = multi_pipeline_ins(self.protein_multimer)
model_dir = snapshot_download(self.model_id)
mono_pipeline_ins = pipeline(task=self.task, model=model_dir)
_ = mono_pipeline_ins(self.protein)
if __name__ == '__main__':
unittest.main()