# moPPIt/models/peptiverse_classifiers.py
import torch
import torch.nn as nn
import pytorch_lightning as pl
from modules.bindevaluator_modules import *
from transformers import AutoModelWithLMHead, AutoTokenizer, EsmModel
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL
from models.peptide_models import CNNModel
# from models.uaa_models import *
import sys
sys.path.append('./PeptiVerse')
from inference import PeptiVersePredictor
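# Shared PeptiVerse predictor, instantiated once at import time and reused by
# every property classifier below.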
pred = PeptiVersePredictor(
manifest_path="./PeptiVerse/best_models.txt", # best model list
classifier_weight_root="./PeptiVerse/", # repo root (where training_classifiers/ lives)
device="cuda", # or "cpu"
)
class HemolysisWT:
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_property("hemolysis", col="wt", input_str=seq)['score']
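            # invert so higher = safer (the predictor presumably scores hemolysis probability)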
scores.append(1 - score)
return torch.tensor(scores)
class NonfoulingWT:
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_property("nf", col="wt", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
# class Solubility:
# def __init__(self):
# self.hydrophobic = list("AVLIMFWPavilmfwpŶƘṂŁĊ")
# def __call__(self, aa_seqs: list):
# scores = []
# for seq in aa_seqs:
# if len(seq) == 0:
# scores.append(0)
# continue
# score = len([tok for tok in seq if tok not in self.hydrophobic]) / len(seq)
# scores.append(score)
# return torch.tensor(scores)
class Solubility:
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_property("solubility", col="wt", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class PermeabilityWT:
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_property("permeability_penetrance", col="wt", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class HalfLifeWT:
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_property("halflife", col="wt", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class AffinityWT:
def __init__(self, target):
self.target = target
def __call__(self, input_seqs):
scores = []
for seq in input_seqs:
score = pred.predict_binding_affinity(col="wt", target_seq=self.target, binder_str=seq)['affinity']
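            # scale the predicted affinity (assumed to lie on a ~0-10, pKd-like scale) into [0, 1]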
scores.append(score / 10)
return torch.tensor(scores)
def parse_motifs(motif: str) -> torch.Tensor:
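    '''Parse a motif spec such as "3-5,8" into a position tensor, e.g. tensor([3, 4, 5, 8]).'''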
parts = motif.split(',')
result = []
for part in parts:
part = part.strip()
if '-' in part:
start, end = map(int, part.split('-'))
result.extend(range(start, end + 1))
else:
result.append(int(part))
# result = [pos-1 for pos in result]
return torch.tensor(result)
class BindEvaluatorWT(pl.LightningModule):
def __init__(self, n_layers, d_model, d_hidden, n_head,
d_k, d_v, d_inner, dropout=0.2,
learning_rate=0.00001, max_epochs=15, kl_weight=1):
super(BindEvaluatorWT, self).__init__()
self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
self.esm_model.eval()
# freeze all the esm_model parameters
for param in self.esm_model.parameters():
param.requires_grad = False
self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
n_head, d_k, d_v, d_inner, dropout=dropout)
self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
d_k, d_v, dropout=dropout)
self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
self.output_projection_prot = nn.Linear(d_model, 1)
self.learning_rate = learning_rate
self.max_epochs = max_epochs
self.kl_weight = kl_weight
self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weight, non-binding-site weight
def forward(self, binder_tokens, target_tokens):
peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
        prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                protein_sequence)
prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
prot_enc = self.final_ffn(prot_enc)
prot_enc = self.output_projection_prot(prot_enc)
return prot_enc
def get_probs(self, x_t, target_sequence):
'''
Inputs:
        - x_t: Shape (bsz, seq_len)
- target_sequence: Shape (1, tgt_len)
'''
# pdb.set_trace()
target_sequence = target_sequence.repeat(x_t.shape[0], 1)
binder_attention_mask = torch.ones_like(x_t)
target_attention_mask = torch.ones_like(target_sequence)
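        # zero out attention on the CLS/EOS special tokens added by the ESM tokenizer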
binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}
logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        logits[:, 0] = logits[:, -1] = -100  # effectively -inf: exclude special-token positions
probs = torch.sigmoid(logits)
return probs # shape (bsz, tgt_len)
def motif_score(self, x_t, target_sequence, motifs):
probs = self.get_probs(x_t, target_sequence)
motif_probs = probs[:, motifs]
motif_score = motif_probs.sum(dim=-1) / len(motifs)
# pdb.set_trace()
return motif_score
def non_motif_score(self, x_t, target_sequence, motifs):
probs = self.get_probs(x_t, target_sequence)
non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
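        # flag off-motif positions predicted as binding sites (prob >= 0.5)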
mask = non_motif_probs >= 0.5
count = mask.sum(dim=-1)
        non_motif_score = torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count.clamp(min=1), torch.zeros_like(count, dtype=non_motif_probs.dtype))
return non_motif_score
def scoring(self, x_t, target_sequence, motifs, penalty=False):
probs = self.get_probs(x_t, target_sequence)
motif_probs = probs[:, motifs]
motif_score = motif_probs.sum(dim=-1) / len(motifs)
# pdb.set_trace()
if penalty:
non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
mask = non_motif_probs >= 0.5
count = mask.sum(dim=-1)
# non_motif_score = 1 - torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count, torch.zeros_like(count))
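            # penalty term: fraction of off-motif target positions predicted as binding sites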
non_motif_score = count / target_sequence.shape[1]
return motif_score, 1 - non_motif_score
else:
return motif_score
class MotifModelWT(nn.Module):
def __init__(self, bindevaluator, target_sequence, motifs, tokenizer, device, penalty=False):
super(MotifModelWT, self).__init__()
self.bindevaluator = bindevaluator
self.target_sequence = target_sequence
self.motifs = motifs
self.penalty = penalty
self.tokenizer = tokenizer
self.device = device
def forward(self, input_seqs):
x = self.tokenizer(input_seqs, return_tensors='pt')['input_ids'].to(self.device)
return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
def load_bindevaluator(checkpoint_path, device):
bindevaluator = BindEvaluatorWT.load_from_checkpoint(checkpoint_path, weights_only=False, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
bindevaluator.eval()
for param in bindevaluator.parameters():
param.requires_grad = False
return bindevaluator
def load_solver(checkpoint_path, vocab_size, device):
    # architecture hyperparameters; must match the trained checkpoint
    embed_dim = 512
    hidden_dim = 256
probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
probability_denoiser.eval()
for param in probability_denoiser.parameters():
param.requires_grad = False
# instantiate a convex path object
scheduler = PolynomialConvexScheduler(n=2.0)
path = MixtureDiscreteProbPath(scheduler=scheduler)
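    # expose the denoiser to the solver as per-token probabilities (softmax over the vocabulary)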
class WrappedModel(ModelWrapper):
def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
return torch.softmax(self.model(x, t), dim=-1)
wrapped_probability_denoiser = WrappedModel(probability_denoiser)
solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
return solver
# def load_uaa_solver(checkpoint_path, vocab_size, device):
# checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
# model_args = checkpoint['args']
# probability_denoiser = MDLM(
# vocab_size=model_args.vocab_size,
# seq_len=model_args.seq_len,
# model_dim=model_args.model_dim,
# n_heads=model_args.n_heads,
# n_layers=model_args.n_layers
# ).to(device)
# probability_denoiser.load_state_dict(checkpoint['model_state_dict'])
# probability_denoiser.eval()
# for param in probability_denoiser.parameters():
# param.requires_grad = False
# # instantiate a convex path object
# scheduler = PolynomialConvexScheduler(n=2.0)
# path = MixtureDiscreteProbPath(scheduler=scheduler)
# class WrappedModel(ModelWrapper):
# def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
# return torch.softmax(self.model(x, t), dim=-1)
# wrapped_probability_denoiser = WrappedModel(probability_denoiser)
# solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
# return solver
class HemolysisSMILES:
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_property("hemolysis", col="smiles", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class NonfoulingSMILES:
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_property("nf", col="smiles", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class PermeabilitySMILES:
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_property("permeability_pampa", col="smiles", input_str=seq)['score']
score = (score + 9) / (-4 + 9)
scores.append(score)
return torch.tensor(scores)
class HalfLifeSMILES:
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_property("halflife", col="smiles", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class ToxicitySMILES:
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_property("toxicity", col="smiles", input_str=seq)['score']
scores.append(score)
return torch.tensor(scores)
class AffinitySMILES:
def __init__(self, target):
self.target = target
def __call__(self, smiles_seqs):
scores = []
for seq in smiles_seqs:
score = pred.predict_binding_affinity(col="smiles", target_seq=self.target, binder_str=seq)['affinity']
scores.append(score / 10)
return torch.tensor(scores)
class BindEvaluatorSMILES(pl.LightningModule):
def __init__(self, cfg):
super(BindEvaluatorSMILES, self).__init__()
self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
for param in self.esm_model.parameters():
param.requires_grad = False
self.chemberta_model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k").roberta.eval()
for param in self.chemberta_model.parameters():
param.requires_grad = False
self.repeated_module = RepeatedModule(cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
cfg.model.n_head, cfg.model.d_k, cfg.model.d_v, cfg.model.d_inner, dropout=cfg.model.dropout)
self.final_attention_layer = MultiHeadAttentionSequence(cfg.model.n_head, cfg.model.d_model,
cfg.model.d_k, cfg.model.d_v, dropout=cfg.model.dropout)
self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)
self.output_projection_prot = nn.Linear(cfg.model.d_model, 1)
def forward(self, binder_tokens, target_tokens):
peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
protein_sequence = self.esm_model(**target_tokens).last_hidden_state
binder_mask = binder_tokens["attention_mask"] # [B, Ls]
target_mask = target_tokens["attention_mask"] # [B, Lp]
prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(
peptide_sequence,
protein_sequence,
peptide_mask=binder_mask,
protein_mask=target_mask,
)
# final cross-attention: protein queries attend to binder keys
prot_enc, final_prot_seq_attention = self.final_attention_layer(
prot_enc, sequence_enc, sequence_enc,
key_padding_mask=binder_mask,
query_padding_mask=target_mask,
)
prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
prot_enc = self.output_projection_prot(prot_enc)
return prot_enc
class MotifModelSMILES:
def __init__(self, cfg, target, motifs, device, specificity):
self.cfg = cfg
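        # probability threshold above which an off-motif position counts as a predicted binding site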
self.threshold = 0.918
self.device = device
self.specificity = specificity
self.chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
self.esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
self.target = self.esm_tokenizer(target, return_tensors='pt').to(device)
self.motifs = parse_motifs(motifs).to(device)
        self.bindevaluator = BindEvaluatorSMILES.load_from_checkpoint(cfg.inference.ckpt, cfg=cfg, map_location=device).eval()
def __call__(self, smiles_seqs):
L = self.target['input_ids'].shape[1]
motif_scores = []
specificity_scores = []
for seq in smiles_seqs:
binder = self.chemberta_tokenizer(seq, return_tensors='pt').to(self.device)
prediction = self.bindevaluator(binder, self.target).squeeze(-1)
# pdb.set_trace()
            probs = torch.sigmoid(prediction).squeeze(0)  # (L,)
motif_score = probs[self.motifs].mean()
motif_scores.append(motif_score)
if self.specificity:
non_motif_probs = probs[[i for i in range(probs.shape[0]) if i not in self.motifs]]
mask = non_motif_probs >= self.threshold
count = mask.sum()
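                # specificity: fraction of off-motif positions (L-2 excludes CLS/EOS) not flagged as binders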
specificity = 1 - count / (L-2)
specificity_scores.append(specificity)
        if self.specificity:
            return torch.stack(motif_scores), torch.stack(specificity_scores)
        else:
            return torch.stack(motif_scores)
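
# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything here is illustrative: the target protein
# and binder sequences are arbitrary placeholders, and running it requires
# the PeptiVerse checkpoints loaded by `pred` at the top of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    target = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical target sequence
    binders = ["GIGAVLKVLTTGLPALIS", "FLPLIGRVLSGIL"]  # hypothetical binders

    hemolysis = HemolysisWT()
    solubility = Solubility()
    affinity = AffinityWT(target)

    print("hemolysis (higher = safer):", hemolysis(binders))
    print("solubility:", solubility(binders))
    print("normalized affinity:", affinity(binders))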