import sys

import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import AutoModelWithLMHead, AutoTokenizer, EsmModel

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

from models.peptide_models import CNNModel
from modules.bindevaluator_modules import *

sys.path.append('./PeptiVerse')
from inference import PeptiVersePredictor

# Shared PeptiVerse predictor used by all property and affinity scorers below.
pred = PeptiVersePredictor(
    manifest_path="./PeptiVerse/best_models.txt",
    classifier_weight_root="./PeptiVerse/",
    device="cuda",
)
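
# The predictor exposes one scorer per property; the call pattern used throughout this file
# (the peptide sequence shown is a hypothetical example):
#   pred.predict_property("hemolysis", col="wt", input_str="GLFDIVKKVV")['score']
#   pred.predict_binding_affinity(col="wt", target_seq=..., binder_str=...)['affinity']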

class HemolysisWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("hemolysis", col="wt", input_str=seq)['score']
            # Invert so that a higher score means a lower predicted hemolysis probability.
            scores.append(1 - score)

        return torch.tensor(scores)


class NonfoulingWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("nf", col="wt", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class Solubility:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("solubility", col="wt", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)

class PermeabilityWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("permeability_penetrance", col="wt", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class HalfLifeWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("halflife", col="wt", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class AffinityWT:
    def __init__(self, target):
        self.target = target

    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_binding_affinity(col="wt", target_seq=self.target, binder_str=seq)['affinity']
            # Scale down by 10 to keep the predicted affinity comparable with the [0, 1] property scores.
            scores.append(score / 10)

        return torch.tensor(scores)

def parse_motifs(motif: str) -> torch.Tensor:
    '''Parse a comma-separated motif specification into a tensor of target residue indices.'''
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            # A dash denotes an inclusive range of residue positions.
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    return torch.tensor(result)
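
# Example (illustrative): parse_motifs("3-5,8") -> tensor([3, 4, 5, 8]). The indices select
# target residue positions from the per-residue binding probabilities produced below.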

class BindEvaluatorWT(pl.LightningModule):
    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super(BindEvaluatorWT, self).__init__()

        # Frozen ESM-2 encoder used to embed both the binder and the target sequence.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        # One logit per target position; downstream code applies a sigmoid to obtain per-residue binding probabilities.
        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc

    def get_probs(self, x_t, target_sequence):
        '''
        Inputs:
            - x_t: Shape (bsz, seq_len)
            - target_sequence: Shape (1, tgt_len)
        '''
        target_sequence = target_sequence.repeat(x_t.shape[0], 1)
        binder_attention_mask = torch.ones_like(x_t)
        target_attention_mask = torch.ones_like(target_sequence)

        # Mask out the special BOS/EOS token positions.
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0

        binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}

        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        # Force the special token positions to near-zero probability.
        logits[:, 0] = logits[:, -1] = -100
        probs = torch.sigmoid(logits)

        return probs

    def motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)

        return motif_score

    def non_motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
        mask = non_motif_probs >= 0.5
        count = mask.sum(dim=-1)

        # Mean probability over the non-motif positions predicted as binding (zero when there are none).
        non_motif_score = torch.where(count > 0,
                                      (non_motif_probs * mask).sum(dim=-1) / count,
                                      torch.zeros_like(count, dtype=non_motif_probs.dtype))

        return non_motif_score

    def scoring(self, x_t, target_sequence, motifs, penalty=False):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)

        if penalty:
            # Penalize predicted binding at target positions outside the motif.
            non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
            mask = non_motif_probs >= 0.5
            count = mask.sum(dim=-1)

            non_motif_score = count / target_sequence.shape[1]
            return motif_score, 1 - non_motif_score
        else:
            return motif_score
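
    # Note: with penalty=True, scoring() returns a (motif_score, specificity) pair, where
    # specificity is one minus the fraction of target positions predicted (prob >= 0.5)
    # to be bound outside the motif.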

class MotifModelWT(nn.Module):
    def __init__(self, bindevaluator, target_sequence, motifs, tokenizer, device, penalty=False):
        super(MotifModelWT, self).__init__()
        self.bindevaluator = bindevaluator
        self.target_sequence = target_sequence
        self.motifs = motifs
        self.penalty = penalty

        self.tokenizer = tokenizer
        self.device = device

    def forward(self, input_seqs):
        # Tokenize raw binder strings so the frozen evaluator can score them directly.
        x = self.tokenizer(input_seqs, return_tensors='pt')['input_ids'].to(self.device)
        return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)


def load_bindevaluator(checkpoint_path, device):
    # Architecture hyperparameters must match those used to train the checkpoint.
    bindevaluator = BindEvaluatorWT.load_from_checkpoint(checkpoint_path, weights_only=False,
                                                         n_layers=8, d_model=128, d_hidden=128,
                                                         n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
    bindevaluator.eval()
    for param in bindevaluator.parameters():
        param.requires_grad = False

    return bindevaluator
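
# Minimal usage sketch (checkpoint path, target sequence, and motif string are hypothetical):
#   device = "cuda"
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   bindevaluator = load_bindevaluator("checkpoints/bindevaluator_wt.ckpt", device)
#   target = tokenizer("MKTAYIAKQR", return_tensors='pt')['input_ids'].to(device)
#   motif_model = MotifModelWT(bindevaluator, target, parse_motifs("3-5,8"), tokenizer, device)
#   scores = motif_model(["GLFDIVKKVV"])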

def load_solver(checkpoint_path, vocab_size, device):
    # Training-time settings, kept for reference; only embed_dim and hidden_dim are used at load time.
    lr = 1e-4
    epochs = 200
    embed_dim = 512
    hidden_dim = 256
    epsilon = 1e-3
    batch_size = 256
    warmup_epochs = epochs // 10

    probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
    probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    probability_denoiser.eval()
    for param in probability_denoiser.parameters():
        param.requires_grad = False

    scheduler = PolynomialConvexScheduler(n=2.0)
    path = MixtureDiscreteProbPath(scheduler=scheduler)

    class WrappedModel(ModelWrapper):
        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
            # The solver expects per-token posterior probabilities, so apply a softmax to the denoiser logits.
            return torch.softmax(self.model(x, t), dim=-1)

    wrapped_probability_denoiser = WrappedModel(probability_denoiser)
    solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)

    return solver
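
# Hedged sampling sketch: the exact MixtureDiscreteEulerSolver.sample() signature is an
# assumption based on the flow_matching library and may need adjusting; x_init and vocab_size
# are hypothetical placeholders (a batch of noise token sequences and the tokenizer vocabulary size).
#   solver = load_solver("checkpoints/denoiser.pt", vocab_size=vocab_size, device="cuda")
#   samples = solver.sample(x_init=x_init, step_size=1 / 100)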

class HemolysisSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("hemolysis", col="smiles", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class NonfoulingSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("nf", col="smiles", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class PermeabilitySMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("permeability_pampa", col="smiles", input_str=seq)['score']
            # Map the PAMPA permeability prediction from an assumed [-9, -4] range onto [0, 1].
            score = (score + 9) / (-4 + 9)
            scores.append(score)

        return torch.tensor(scores)


class HalfLifeSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("halflife", col="smiles", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class ToxicitySMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("toxicity", col="smiles", input_str=seq)['score']
            scores.append(score)

        return torch.tensor(scores)


class AffinitySMILES:
    def __init__(self, target):
        self.target = target

    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_binding_affinity(col="smiles", target_seq=self.target, binder_str=seq)['affinity']
            # Scale down by 10 to keep the predicted affinity comparable with the [0, 1] property scores.
            scores.append(score / 10)

        return torch.tensor(scores)

class BindEvaluatorSMILES(pl.LightningModule):
    def __init__(self, cfg):
        super(BindEvaluatorSMILES, self).__init__()

        # Frozen ESM-2 encoder for the protein target.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
        for param in self.esm_model.parameters():
            param.requires_grad = False

        # Frozen ChemBERTa encoder for the SMILES binder.
        self.chemberta_model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k").roberta.eval()
        for param in self.chemberta_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule(cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
                                              cfg.model.n_head, cfg.model.d_k, cfg.model.d_v,
                                              cfg.model.d_inner, dropout=cfg.model.dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(cfg.model.n_head, cfg.model.d_model,
                                                                cfg.model.d_k, cfg.model.d_v,
                                                                dropout=cfg.model.dropout)

        self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)

        self.output_projection_prot = nn.Linear(cfg.model.d_model, 1)

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        binder_mask = binder_tokens["attention_mask"]
        target_mask = target_tokens["attention_mask"]

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(
                peptide_sequence,
                protein_sequence,
                peptide_mask=binder_mask,
                protein_mask=target_mask,
            )

        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc,
            key_padding_mask=binder_mask,
            query_padding_mask=target_mask,
        )

        prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
        prot_enc = self.output_projection_prot(prot_enc)
        return prot_enc

class MotifModelSMILES:
    def __init__(self, cfg, target, motifs, device, specificity):
        self.cfg = cfg
        # Probability threshold above which a non-motif target position counts as a predicted binding site.
        self.threshold = 0.918
        self.device = device
        self.specificity = specificity
        self.chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
        self.esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.target = self.esm_tokenizer(target, return_tensors='pt').to(device)
        self.motifs = parse_motifs(motifs).to(device)
        self.bindevaluator = BindEvaluatorSMILES.load_from_checkpoint(cfg.inference.ckpt, cfg=cfg, map_location=device)

    def __call__(self, smiles_seqs):
        # Target length including the BOS/EOS special tokens.
        L = self.target['input_ids'].shape[1]

        motif_scores = []
        specificity_scores = []
        for seq in smiles_seqs:
            binder = self.chemberta_tokenizer(seq, return_tensors='pt').to(self.device)
            prediction = self.bindevaluator(binder, self.target).squeeze(-1)

            probs = torch.sigmoid(prediction).squeeze(0)
            motif_score = probs[self.motifs].mean()
            motif_scores.append(motif_score)
            if self.specificity:
                non_motif_probs = probs[[i for i in range(probs.shape[0]) if i not in self.motifs]]
                mask = non_motif_probs >= self.threshold
                count = mask.sum()
                # One minus the fraction of non-special target positions predicted to bind outside the motif.
                specificity = 1 - count / (L - 2)

                specificity_scores.append(specificity)

        if self.specificity:
            return torch.tensor(motif_scores), torch.tensor(specificity_scores)
        else:
            return torch.tensor(motif_scores)
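
# Hypothetical construction (cfg, target sequence, and motif string are placeholders):
#   motif_model = MotifModelSMILES(cfg, "MKTAYIAKQR", "3-5,8", device="cuda", specificity=True)
#   motif_scores, specificity_scores = motif_model(["CC(=O)NC1=CC=C(O)C=C1"])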
|
|
|
|