Spaces:
Build error
Build error
| import torch | |
| import transformers | |
| from transformers import PreTrainedTokenizerFast | |
| import tranception | |
| import datasets | |
| from tranception import config, model_pytorch | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import gradio as gr | |
# Fast tokenizer for protein sequences, loaded from a tokenizer file at a path
# relative to the current working directory, with its special tokens declared
# explicitly. It is attached to each model's config before scoring.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    **dict(
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
    ),
)
| ####################################################################################################################################### | |
| ############################################### HELPER FUNCTIONS #################################################################### | |
| ####################################################################################################################################### | |
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"


def create_all_single_mutants(sequence, AA_vocab=AA_vocab, mutation_range_start=None, mutation_range_end=None):
    """Enumerate every single amino-acid substitution within a range of positions.

    Positions are 1-based and the range is inclusive. A ``None`` bound defaults
    to the first / last residue, and bounds are clamped to the sequence.

    Args:
        sequence: starting protein sequence.
        AA_vocab: amino acids to substitute in; substitutions to the wild-type
            residue are skipped.
        mutation_range_start: first 1-based position to mutate (default 1).
        mutation_range_end: last 1-based position to mutate (default len(sequence)).

    Returns:
        DataFrame with columns ``mutant`` (e.g. ``"A12G"``: wild-type residue,
        1-based position, new residue) and ``mutated_sequence`` (the full
        sequence carrying that substitution), ordered by position and then by
        ``AA_vocab`` order.
    """
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    # Clamp to valid positions (the previous slice-based loop clamped the end implicitly).
    first_position = max(1, int(mutation_range_start))
    last_position = min(len(sequence), int(mutation_range_end))

    # Keyed by mutant code so duplicate letters in AA_vocab collapse to one row.
    all_single_mutants = {}
    for position in range(first_position, last_position + 1):
        # Use the absolute 1-based position for both the edit and the label;
        # enumerating the slice gave offsets relative to the range start.
        current_AA = sequence[position - 1]
        for mutated_AA in AA_vocab:
            if mutated_AA == current_AA:
                continue
            mutated_sequence = sequence[:position - 1] + mutated_AA + sequence[position:]
            all_single_mutants[current_AA + str(position) + mutated_AA] = mutated_sequence
    return pd.DataFrame(list(all_single_mutants.items()), columns=['mutant', 'mutated_sequence'])
def create_scoring_matrix_visual(scores, sequence, AA_vocab=AA_vocab, mutation_range_start=None, mutation_range_end=None):
    """Plot an annotated heatmap of fitness scores for all single substitutions.

    Rows are the target amino acids in ``AA_vocab`` order. Columns are the
    1-based positions ``mutation_range_start..mutation_range_end`` (inclusive;
    defaults 1 and ``len(sequence)``, clamped to the sequence). Each cell is
    annotated with the mutant code (e.g. ``A12G``) and its score. Wild-type
    cells have no score and are left blank.

    Args:
        scores: DataFrame with ``mutant``, ``position``, ``target_AA`` and
            ``avg_score`` columns, one row per scored mutant.
        sequence: starting protein sequence.
        AA_vocab: amino-acid alphabet giving the heatmap rows.
        mutation_range_start: first position to plot.
        mutation_range_end: last position to plot.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        heatmap. The figure is also saved to
        ``fitness_scoring_substitution_matrix.png``.
    """
    import numpy as np  # used below but not imported at module level in this file

    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        # Previously this default was mistakenly assigned to mutation_range_start.
        mutation_range_end = len(sequence)
    positions = list(range(max(1, int(mutation_range_start)),
                           min(len(sequence), int(mutation_range_end)) + 1))
    target_AAs = list(AA_vocab)

    # Reindex so the value grid has exactly one row per target AA and one column
    # per position, matching `labels` below. `pivot` alone drops AAs that never
    # occur as targets (e.g. the wild-type AA when the range is a single position),
    # which made the annotation array the wrong shape.
    piv = (
        scores.pivot(index='position', columns='target_AA', values='avg_score')
        .reindex(index=positions, columns=target_AAs)
        .transpose()
        .round(4)
    )

    # One dict lookup per cell instead of a full DataFrame scan per cell.
    score_by_mutant = dict(zip(scores['mutant'], scores['avg_score']))
    labels = np.empty((len(target_AAs), len(positions)), dtype=object)
    for row, target_AA in enumerate(target_AAs):
        for col, position in enumerate(positions):
            mutant = sequence[position - 1] + str(position) + target_AA
            labels[row, col] = "{} \n {:.4f}".format(mutant, float(score_by_mutant.get(mutant, 0.0)))

    # NOTE(review): width scales with the full sequence length, not the plotted range.
    fig, ax = plt.subplots(figsize=(len(sequence) * 1.2, 20))
    heat = sns.heatmap(
        piv,
        annot=labels,
        fmt="",
        cmap='RdYlGn',
        linewidths=0.30,
        ax=ax,
        # Clip the colour scale to the 2nd-98th percentile so outliers don't wash it out.
        vmin=np.percentile(scores.avg_score, 2),
        vmax=np.percentile(scores.avg_score, 98),
        cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},
    )
    heat.figure.axes[-1].yaxis.label.set_size(20)
    heat.set_title("Higher predicted scores (green) imply higher protein fitness", fontsize=30, pad=40)
    heat.set_xlabel("Sequence position", fontsize=20)
    heat.set_ylabel("Amino Acid mutation", fontsize=20)
    plt.savefig('fitness_scoring_substitution_matrix.png')
    return plt
def suggest_mutations(scores):
    """Summarise the most promising single mutations as a readable message.

    Args:
        scores: DataFrame with at least ``mutant``, ``position`` and
            ``avg_score`` columns (other columns, including strings, are ignored).

    Returns:
        A multi-line string naming the 5 highest-scoring mutants and the 5
        positions with the highest mean score among positively scored mutants.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"

    # Best mutants
    top_mutants = list(scores.sort_values(by=['avg_score'], ascending=False).head(5).mutant)
    mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))

    # Best positions. Average only the score column: averaging the whole frame
    # fails on the string columns (mutant, target_AA, ...) with pandas >= 2.0.
    positive_scores = scores[scores.avg_score > 0]
    position_avg = positive_scores.groupby('position')['avg_score'].mean()
    top_positions = list(position_avg.sort_values(ascending=False).head(5).index.astype(str))
    position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))

    return intro_message + mutant_recos + position_recos
def get_mutated_protein(sequence, mutant):
    """Return `sequence` with the substitution described by `mutant` applied.

    `mutant` is a code such as "A12G": its last character is the new residue
    and the characters between the first and the last form the 1-based position.
    The first character (wild-type residue) is not checked.
    """
    residue_index = int(mutant[1:-1]) - 1
    new_residue = mutant[-1]
    residues = list(sequence)
    residues[residue_index] = new_residue
    return ''.join(residues)
def score_and_create_matrix_all_singles(sequence, mutation_range_start=None, mutation_range_end=None, model_type="Small", scoring_mirror=False, batch_size_inference=20, num_workers=0, AA_vocab=AA_vocab):
    """Score all single substitutions of `sequence` with Tranception and summarise them.

    Args:
        sequence: starting protein sequence.
        mutation_range_start: first 1-based position to mutate (None = start).
        mutation_range_end: last 1-based position to mutate (None = end).
        model_type: Tranception checkpoint size: "Small", "Medium" or "Large".
        scoring_mirror: forwarded to ``model.score_mutants`` (the UI describes it
            as scoring the protein from both directions).
        batch_size_inference: forwarded to ``model.score_mutants``.
        num_workers: forwarded to ``model.score_mutants``.
        AA_vocab: amino acids to substitute in.

    Returns:
        ``(heatmap, recommendations)``: the value returned by
        ``create_scoring_matrix_visual`` and the message from ``suggest_mutations``.

    Raises:
        ValueError: if ``model_type`` is not a known checkpoint size (previously
            this surfaced as an UnboundLocalError on ``model``).
    """
    checkpoints = {
        "Small": "PascalNotin/Tranception_Small",
        "Medium": "PascalNotin/Tranception_Medium",
        "Large": "PascalNotin/Tranception_Large",
    }
    if model_type not in checkpoints:
        raise ValueError(
            "Unknown model_type {!r}; expected one of: {}".format(model_type, ", ".join(checkpoints))
        )
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
        pretrained_model_name_or_path=checkpoints[model_type], use_auth_token=True
    )
    # Presumably read by score_mutants to tokenize sequences — confirm in tranception.
    model.config.tokenizer = tokenizer

    all_single_mutants = create_all_single_mutants(sequence, AA_vocab, mutation_range_start, mutation_range_end)
    scores = model.score_mutants(
        DMS_data=all_single_mutants,
        target_seq=sequence,
        scoring_mirror=scoring_mirror,
        batch_size_inference=batch_size_inference,
        num_workers=num_workers,
        indel_mode=False,
    )
    # Join the mutant codes back onto the scores via the mutated sequence, then
    # derive the heatmap coordinates from codes like "A12G".
    scores = pd.merge(scores, all_single_mutants, on="mutated_sequence", how="left")
    scores["position"] = scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])

    score_heatmap = create_scoring_matrix_visual(scores, sequence, AA_vocab, mutation_range_start, mutation_range_end)
    return score_heatmap, suggest_mutations(scores)
| ####################################################################################################################################### | |
| ############################################### GRADIO INTERFACE #################################################################### | |
| ####################################################################################################################################### | |
# Text shown in the Gradio interface header and footer.
title = "Interactive in silico directed evolution with Tranception"
description = (
    "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting "
    "protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood "
    "ratios of all possible single amino acid substitutions vs. the starting sequence, and outputs a fitness "
    "heatmap and recommendations to guide the selection of the mutation to apply. Note: The current version "
    "does not leverage homolog retrieval at inference time to boost fitness prediction performance."
)
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
# Example starting proteins, written as "NAME: SEQUENCE" so each one can be
# identified at a glance.
named_examples = [
    'A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN',
    'ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL',
    'AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA',
    'P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD',
]
# Each example row fills the first interface input, the sequence textbox. Strip
# the "NAME: " prefix so the model is not asked to score the name characters as
# if they were residues.
examples = [[entry.split(": ", 1)[1]] for entry in named_examples]
# Interface input components (order of use is set in the gr.Interface call).
model_size_selection = gr.Radio(label="Tranception model size", choices=["Small","Medium","Large"], value="Small")
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full length)",value=10,precision=0)
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
# Interface output components.
# TODO: make the output plot scrollable (the heatmap figure can be very wide).
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
output_recommendations = gr.Textbox(label="Mutation recommendations")
# Build and launch the app. `.queue()` replaces the deprecated
# `enable_queue=True` Interface argument (removed in newer Gradio releases).
gr.Interface(
    fn=score_and_create_matrix_all_singles,
    # Order must match the leading positional parameters of
    # score_and_create_matrix_all_singles.
    inputs=[protein_sequence_input, mutation_range_start, mutation_range_end, model_size_selection, scoring_mirror],
    # Use the labelled output components rather than bare "plot"/"text" shortcuts.
    outputs=[output_plot, output_recommendations],
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="never",
).queue().launch(debug=True)