import sys

import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmModel

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

from models.peptide_models import CNNModel
from modules.bindevaluator_modules import *
# from models.uaa_models import *

sys.path.append('./PeptiVerse')
from inference import PeptiVersePredictor

pred = PeptiVersePredictor(
    manifest_path="./PeptiVerse/best_models.txt",  # best model list
    classifier_weight_root="./PeptiVerse/",  # repo root (where training_classifiers/ lives)
    device="cuda",  # or "cpu"
)


# Property scorers over wild-type (WT) amino-acid sequences. Each maps a list
# of sequences to a tensor of scores where higher is better.
class HemolysisWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("hemolysis", col="wt", input_str=seq)['score']
            scores.append(1 - score)  # invert so non-hemolytic sequences score high
        return torch.tensor(scores)


class NonfoulingWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("nf", col="wt", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


# Earlier heuristic solubility scorer (fraction of non-hydrophobic tokens),
# kept for reference; superseded by the classifier-based version below:
# class Solubility:
#     def __init__(self):
#         self.hydrophobic = list("AVLIMFWPavilmfwpŶƘṂŁĊ")
#
#     def __call__(self, aa_seqs: list):
#         scores = []
#         for seq in aa_seqs:
#             if len(seq) == 0:
#                 scores.append(0)
#                 continue
#             score = len([tok for tok in seq if tok not in self.hydrophobic]) / len(seq)
#             scores.append(score)
#         return torch.tensor(scores)


class Solubility:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("solubility", col="wt", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class PermeabilityWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("permeability_penetrance", col="wt", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class HalfLifeWT:
    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_property("halflife", col="wt", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class AffinityWT:
    def __init__(self, target):
        self.target = target

    def __call__(self, input_seqs):
        scores = []
        for seq in input_seqs:
            score = pred.predict_binding_affinity(col="wt", target_seq=self.target, binder_str=seq)['affinity']
            scores.append(score / 10)  # scale predicted affinity into roughly [0, 1]
        return torch.tensor(scores)


def parse_motifs(motif: str) -> torch.Tensor:
    """Parse a comma-separated motif spec such as "3, 7-9, 12" into an index tensor."""
    parts = motif.split(',')
    result = []
    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))  # ranges are inclusive
        else:
            result.append(int(part))
    # result = [pos-1 for pos in result]
    return torch.tensor(result)
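
# Illustrative check of the motif parser (the spec below is made up):
#
#   >>> parse_motifs("3, 7-9, 12")
#   tensor([ 3,  7,  8,  9, 12])
#
# The indices are applied directly to the tokenized target, where position 0
# is the BOS token; the commented-out `pos-1` line above is where a 1-indexed
# residue numbering would be shifted if needed.
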
class BindEvaluatorWT(pl.LightningModule):
    def __init__(self, n_layers, d_model, d_hidden, n_head, d_k, d_v, d_inner,
                 dropout=0.2, learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super(BindEvaluatorWT, self).__init__()
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden, n_head,
                                               d_k, d_v, d_inner, dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v,
                                                                dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 1)
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight
        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # initial threshold
        self.historical_memory = 0.9
        # [binding-site weight, non-binding-site weight]
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, protein_sequence)
        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        prot_enc = self.output_projection_prot(prot_enc)
        return prot_enc

    def get_probs(self, x_t, target_sequence):
        '''
        Inputs:
        - x_t: shape (bsz, seq_len)
        - target_sequence: shape (1, tgt_len)
        Returns per-position binding probabilities of shape (bsz, tgt_len).
        '''
        target_sequence = target_sequence.repeat(x_t.shape[0], 1)
        # All-ones masks assume an unpadded, fixed-length batch; zero out the
        # BOS/EOS positions so special tokens are ignored.
        binder_attention_mask = torch.ones_like(x_t)
        target_attention_mask = torch.ones_like(target_sequence)
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
        binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}
        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        logits[:, 0] = logits[:, -1] = -100  # acts as float('-inf'): special tokens never count as binding sites
        probs = torch.sigmoid(logits)
        return probs  # shape (bsz, tgt_len)

    def motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)  # mean probability over motif positions
        return motif_score

    def non_motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
        mask = non_motif_probs >= 0.5
        count = mask.sum(dim=-1)
        # Mean probability over confidently predicted off-motif positions (0 if there are none).
        non_motif_score = torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count, torch.zeros_like(count))
        return non_motif_score

    def scoring(self, x_t, target_sequence, motifs, penalty=False):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)
        if penalty:
            non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
            mask = non_motif_probs >= 0.5
            count = mask.sum(dim=-1)
            # non_motif_score = 1 - torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count, torch.zeros_like(count))
            non_motif_score = count / target_sequence.shape[1]  # fraction of off-motif positions predicted as binding
            return motif_score, 1 - non_motif_score
        else:
            return motif_score


class MotifModelWT(nn.Module):
    def __init__(self, bindevaluator, target_sequence, motifs, tokenizer, device, penalty=False):
        super(MotifModelWT, self).__init__()
        self.bindevaluator = bindevaluator
        self.target_sequence = target_sequence
        self.motifs = motifs
        self.penalty = penalty
        self.tokenizer = tokenizer
        self.device = device

    def forward(self, input_seqs):
        x = self.tokenizer(input_seqs, return_tensors='pt')['input_ids'].to(self.device)
        return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
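
# Minimal usage sketch for the WT motif scorer; the checkpoint path and the
# target/binder sequences are placeholders, not shipped assets. Note that
# get_probs assumes an unpadded batch, so binder sequences should share one
# length:
#
#   device = "cuda"
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   bindevaluator = load_bindevaluator("checkpoints/bindevaluator.ckpt", device)
#   target = tokenizer("MKTAYIAKQRQISFVKSHFSRQ", return_tensors="pt")["input_ids"].to(device)
#   motifs = parse_motifs("3, 7-9").to(device)
#   motif_model = MotifModelWT(bindevaluator, target, motifs, tokenizer, device, penalty=True)
#   motif_score, specificity = motif_model(["GLFDIVKKVV", "KLAKLAKKLA"])  # penalty=True returns both scores
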
def load_bindevaluator(checkpoint_path, device):
    bindevaluator = BindEvaluatorWT.load_from_checkpoint(
        checkpoint_path, weights_only=False,
        n_layers=8, d_model=128, d_hidden=128,
        n_head=8, d_k=64, d_v=128, d_inner=64,
    ).to(device)
    bindevaluator.eval()
    for param in bindevaluator.parameters():
        param.requires_grad = False
    return bindevaluator


def load_solver(checkpoint_path, vocab_size, device):
    # Denoiser architecture hyperparameters; these must match the checkpoint.
    embed_dim = 512
    hidden_dim = 256
    probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
    probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    probability_denoiser.eval()
    for param in probability_denoiser.parameters():
        param.requires_grad = False

    # instantiate a convex path object
    scheduler = PolynomialConvexScheduler(n=2.0)
    path = MixtureDiscreteProbPath(scheduler=scheduler)

    class WrappedModel(ModelWrapper):
        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
            return torch.softmax(self.model(x, t), dim=-1)

    wrapped_probability_denoiser = WrappedModel(probability_denoiser)
    solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
    return solver


# def load_uaa_solver(checkpoint_path, vocab_size, device):
#     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
#     model_args = checkpoint['args']
#     probability_denoiser = MDLM(
#         vocab_size=model_args.vocab_size,
#         seq_len=model_args.seq_len,
#         model_dim=model_args.model_dim,
#         n_heads=model_args.n_heads,
#         n_layers=model_args.n_layers
#     ).to(device)
#     probability_denoiser.load_state_dict(checkpoint['model_state_dict'])
#     probability_denoiser.eval()
#     for param in probability_denoiser.parameters():
#         param.requires_grad = False
#
#     # instantiate a convex path object
#     scheduler = PolynomialConvexScheduler(n=2.0)
#     path = MixtureDiscreteProbPath(scheduler=scheduler)
#
#     class WrappedModel(ModelWrapper):
#         def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
#             return torch.softmax(self.model(x, t), dim=-1)
#
#     wrapped_probability_denoiser = WrappedModel(probability_denoiser)
#     solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
#     return solver
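
# Sketch of drawing samples from the loaded solver, assuming the
# flow_matching MixtureDiscreteEulerSolver.sample API (x_init drawn from the
# source distribution, then integrated over t in [0, 1]); the vocabulary
# size, sequence length, and checkpoint path below are placeholders:
#
#   solver = load_solver("checkpoints/denoiser.pt", vocab_size=33, device="cuda")
#   x_init = torch.randint(0, 33, (64, 12), device="cuda")  # 64 random length-12 token sequences
#   samples = solver.sample(x_init=x_init, step_size=1 / 100)  # shape (64, 12)
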
# Property scorers over peptide SMILES strings, mirroring the WT scorers above.
class HemolysisSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("hemolysis", col="smiles", input_str=seq)['score']
            scores.append(score)  # note: unlike HemolysisWT, the raw score is used here
        return torch.tensor(scores)


class NonfoulingSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("nf", col="smiles", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class PermeabilitySMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("permeability_pampa", col="smiles", input_str=seq)['score']
            score = (score + 9) / (-4 + 9)  # rescale from [-9, -4] to [0, 1]
            scores.append(score)
        return torch.tensor(scores)


class HalfLifeSMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("halflife", col="smiles", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class ToxicitySMILES:
    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_property("toxicity", col="smiles", input_str=seq)['score']
            scores.append(score)
        return torch.tensor(scores)


class AffinitySMILES:
    def __init__(self, target):
        self.target = target

    def __call__(self, smiles_seqs):
        scores = []
        for seq in smiles_seqs:
            score = pred.predict_binding_affinity(col="smiles", target_seq=self.target, binder_str=seq)['affinity']
            scores.append(score / 10)  # scale predicted affinity into roughly [0, 1]
        return torch.tensor(scores)


class BindEvaluatorSMILES(pl.LightningModule):
    def __init__(self, cfg):
        super(BindEvaluatorSMILES, self).__init__()
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
        for param in self.esm_model.parameters():
            param.requires_grad = False
        # AutoModelWithLMHead is deprecated in transformers; AutoModelForMaskedLM
        # loads the same RoBERTa MLM checkpoint, and we keep only its encoder.
        self.chemberta_model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k").roberta.eval()
        for param in self.chemberta_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule(cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
                                              cfg.model.n_head, cfg.model.d_k, cfg.model.d_v,
                                              cfg.model.d_inner, dropout=cfg.model.dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(cfg.model.n_head, cfg.model.d_model,
                                                                cfg.model.d_k, cfg.model.d_v,
                                                                dropout=cfg.model.dropout)
        self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)
        self.output_projection_prot = nn.Linear(cfg.model.d_model, 1)

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        binder_mask = binder_tokens["attention_mask"]  # [B, Ls]
        target_mask = target_tokens["attention_mask"]  # [B, Lp]
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(
                peptide_sequence, protein_sequence,
                peptide_mask=binder_mask,
                protein_mask=target_mask,
            )
        # final cross-attention: protein queries attend to binder keys
        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc,
            key_padding_mask=binder_mask,
            query_padding_mask=target_mask,
        )
        prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
        prot_enc = self.output_projection_prot(prot_enc)
        return prot_enc


class MotifModelSMILES:
    def __init__(self, cfg, target, motifs, device, specificity):
        self.cfg = cfg
        self.threshold = 0.918
        self.device = device
        self.specificity = specificity
        self.chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
        self.esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.target = self.esm_tokenizer(target, return_tensors='pt').to(device)
        self.motifs = parse_motifs(motifs).to(device)
        # eval() disables dropout at inference time
        self.bindevaluator = BindEvaluatorSMILES.load_from_checkpoint(cfg.inference.ckpt, cfg=cfg, map_location=device).eval()

    def __call__(self, smiles_seqs):
        L = self.target['input_ids'].shape[1]
        motif_scores = []
        specificity_scores = []
        for seq in smiles_seqs:
            binder = self.chemberta_tokenizer(seq, return_tensors='pt').to(self.device)
            prediction = self.bindevaluator(binder, self.target).squeeze(-1)
            probs = torch.sigmoid(prediction).squeeze(0)  # (1, L) -> (L,)
            motif_score = probs[self.motifs].mean()
            motif_scores.append(motif_score)
            if self.specificity:
                non_motif_probs = probs[[i for i in range(probs.shape[0]) if i not in self.motifs]]
                mask = non_motif_probs >= self.threshold
                count = mask.sum()
                specificity = 1 - count / (L - 2)  # exclude BOS/EOS from the position count
                specificity_scores.append(specificity)
        if self.specificity:
            return torch.tensor(motif_scores), torch.tensor(specificity_scores)
        else:
            return torch.tensor(motif_scores)
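
# Usage sketch for the SMILES-side scorers; `cfg` is assumed to be the
# OmegaConf/Hydra-style config used to train BindEvaluatorSMILES (it must
# provide cfg.model.* and cfg.inference.ckpt), and the target sequence,
# motif spec, and SMILES string below are placeholders:
#
#   motif_model = MotifModelSMILES(cfg, target="MKTAYIAKQRQISFVKSHFSRQ",
#                                  motifs="3, 7-9", device="cuda", specificity=True)
#   motif_scores, specificity_scores = motif_model(["CC(C)C[C@@H](N)C(=O)O"])
#   affinity = AffinitySMILES("MKTAYIAKQRQISFVKSHFSRQ")(["CC(C)C[C@@H](N)C(=O)O"])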