# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os.path
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import re
import torch
import tqdm
import argparse
import json

from data_builder import load_data, save_data
from metrics import get_roc_metrics, get_precision_recall_metrics
from model import load_tokenizer, load_model, get_model_fullname, from_pretrained

# define regex to match all <extra_id_*> tokens, where * is an integer
pattern = re.compile(r"<extra_id_\d+>")
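# e.g. pattern.split("<extra_id_0> fill one <extra_id_1> fill two <extra_id_2>")
# yields ['', ' fill one ', ' fill two ', ''] -- the fills sit between the sentinels.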

def load_mask_model(model_name, device, cache_dir):
    model_name = get_model_fullname(model_name)
    # mask filling t5 model
    print(f'Loading mask filling model {model_name}...')
    mask_model = from_pretrained(AutoModelForSeq2SeqLM, model_name, {}, cache_dir)
    mask_model = mask_model.to(device)
    return mask_model

def load_mask_tokenizer(model_name, max_length, cache_dir):
    model_name = get_model_fullname(model_name)
    tokenizer = from_pretrained(AutoTokenizer, model_name, {'model_max_length': max_length}, cache_dir)
    return tokenizer

def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
    buffer_size = 1
    tokens = text.split(' ')
    mask_string = '<<<mask>>>'

    n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
    if ceil_pct:
        n_spans = np.ceil(n_spans)
    n_spans = int(n_spans)

    n_masks = 0
    while n_masks < n_spans:
        start = np.random.randint(0, len(tokens) - span_length)
        end = start + span_length
        search_start = max(0, start - buffer_size)
        search_end = min(len(tokens), end + buffer_size)
        if mask_string not in tokens[search_start:search_end]:
            tokens[start:end] = [mask_string]
            n_masks += 1

    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
    num_filled = 0
    for idx, token in enumerate(tokens):
        if token == mask_string:
            tokens[idx] = f'<extra_id_{num_filled}>'
            num_filled += 1
    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"

    text = ' '.join(tokens)
    return text
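
# Worked example: for a 16-word text with pct=0.3, span_length=2, buffer_size=1,
# n_spans = int(0.3 * 16 / 4) = 1, so a single two-word span is replaced.
# The output varies with the random state, but might look like:
#   tokenize_and_mask("the quick brown fox jumps over the lazy dog "
#                     "while the cat sleeps by the fire", 2, 0.3)
#   -> "the quick <extra_id_0> jumps over the lazy dog while the cat sleeps by the fire"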

def count_masks(texts):
    return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
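# e.g. count_masks(["a <extra_id_0> b <extra_id_1> c", "no masks here"]) == [2, 0]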

# replace each masked span with a sample from T5 mask_model
def replace_masks(args, mask_model, mask_tokenizer, texts):
    n_expected = count_masks(texts)
    stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
    tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).to(args.device)
    outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=args.mask_top_p,
                                  num_return_sequences=1, eos_token_id=stop_id)
    return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
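
# T5 span infilling emits the fills delimited by the same sentinels, roughly:
#   "<pad> <extra_id_0> first fill <extra_id_1> second fill <extra_id_2>"
# Using the highest expected sentinel in the batch as eos_token_id stops
# generation once every expected span has been filled.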

def extract_fills(texts):
    # strip the <pad> and </s> special tokens from each text
    texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
    # return the text in between each matched mask token
    extracted_fills = [pattern.split(x)[1:-1] for x in texts]
    # remove whitespace around each fill
    extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
    return extracted_fills

def apply_extracted_fills(masked_texts, extracted_fills):
    # split masked text into tokens, only splitting on spaces (not newlines)
    tokens = [x.split(' ') for x in masked_texts]
    n_expected = count_masks(masked_texts)

    # replace each mask token with the corresponding fill
    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
        if len(fills) < n:
            tokens[idx] = []
        else:
            for fill_idx in range(n):
                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]

    # join tokens back into text
    texts = [" ".join(x) for x in tokens]
    return texts
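# Note: a text whose fills came back short is emptied to [] and joins to "",
# which is the sentinel that triggers the retry loop in perturb_texts_ below.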

def perturb_texts_(args, mask_model, mask_tokenizer, texts, ceil_pct=False):
    span_length = args.span_length
    pct = args.pct_words_masked
    masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
    raw_fills = replace_masks(args, mask_model, mask_tokenizer, masked_texts)
    extracted_fills = extract_fills(raw_fills)
    perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)

    # retry any text for which the model did not generate the expected number of fills
    attempts = 1
    while '' in perturbed_texts:
        idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']
        print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
        masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for idx, x in enumerate(texts) if idx in idxs]
        raw_fills = replace_masks(args, mask_model, mask_tokenizer, masked_texts)
        extracted_fills = extract_fills(raw_fills)
        new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
        for idx, x in zip(idxs, new_perturbed_texts):
            perturbed_texts[idx] = x
        attempts += 1
    return perturbed_texts

def perturb_texts(args, mask_model, mask_tokenizer, texts, ceil_pct=False):
    chunk_size = 10
    outputs = []
    for i in range(0, len(texts), chunk_size):
        outputs.extend(perturb_texts_(args, mask_model, mask_tokenizer, texts[i:i + chunk_size], ceil_pct=ceil_pct))
    return outputs
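# Chunking caps each mask-filling batch at 10 texts, presumably to bound the
# memory used by mask_model.generate on long inputs.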

# Get the log likelihood of each text under the scoring model
def get_ll(args, scoring_model, scoring_tokenizer, text):
    with torch.no_grad():
        tokenized = scoring_tokenizer(text, return_tensors="pt", return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids
        return -scoring_model(**tokenized, labels=labels).loss.item()

def get_lls(args, scoring_model, scoring_tokenizer, texts):
    return [get_ll(args, scoring_model, scoring_tokenizer, text) for text in texts]
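# The HuggingFace causal-LM loss is the mean token cross-entropy, so -loss is
# the average per-token log likelihood of the text (not the summed log likelihood).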

def generate_perturbs(args):
    n_perturbations = args.n_perturbations
    name = f'perturbation_{n_perturbations}'

    # load model
    mask_model = load_mask_model(args.mask_filling_model_name, args.device, args.cache_dir)
    mask_model.eval()
    try:
        n_positions = mask_model.config.n_positions
    except AttributeError:
        n_positions = 512
    mask_tokenizer = load_mask_tokenizer(args.mask_filling_model_name, n_positions, args.cache_dir)

    # load data
    data = load_data(args.dataset_file)
    n_samples = len(data["sampled"])

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # generate perturb samples
    perturbs = []
    for idx in tqdm.tqdm(range(n_samples), desc="Perturb text"):
        original_text = data["original"][idx]
        sampled_text = data["sampled"][idx]
        # perturb
        p_sampled_text = perturb_texts(args, mask_model, mask_tokenizer, [sampled_text for _ in range(n_perturbations)])
        p_original_text = perturb_texts(args, mask_model, mask_tokenizer, [original_text for _ in range(n_perturbations)])
        assert len(p_sampled_text) == n_perturbations, f"Expected {n_perturbations} perturbed samples, got {len(p_sampled_text)}"
        assert len(p_original_text) == n_perturbations, f"Expected {n_perturbations} perturbed samples, got {len(p_original_text)}"
        # result
        perturbs.append({
            "original": original_text,
            "sampled": sampled_text,
            "perturbed_sampled": p_sampled_text,
            "perturbed_original": p_original_text,
        })

    save_data(f'{args.dataset_file}.{args.mask_filling_model_name}.{name}', args, perturbs)

def experiment(args):
    n_perturbations = args.n_perturbations
    name = f'perturbation_{n_perturbations}'
    perturb_file = f'{args.dataset_file}.{args.mask_filling_model_name}.{name}.raw_data.json'
    if os.path.exists(perturb_file):
        print(f'Use existing perturbation file: {perturb_file}')
    else:
        generate_perturbs(args)

    # load model
    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.dataset, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, 'cpu', args.cache_dir)
    scoring_model.eval()
    scoring_model.to(args.device)

    # load data
    data = load_data(f'{args.dataset_file}.{args.mask_filling_model_name}.{name}')
    n_samples = len(data)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # evaluate
    results = data
    for idx in tqdm.tqdm(range(n_samples), desc=f"Computing {name} criterion"):
        original_text = results[idx]["original"]
        sampled_text = results[idx]["sampled"]
        perturbed_original = results[idx]["perturbed_original"]
        perturbed_sampled = results[idx]["perturbed_sampled"]
        # original text
        original_ll = get_ll(args, scoring_model, scoring_tokenizer, original_text)
        p_original_ll = get_lls(args, scoring_model, scoring_tokenizer, perturbed_original)
        # sampled text
        sampled_ll = get_ll(args, scoring_model, scoring_tokenizer, sampled_text)
        p_sampled_ll = get_lls(args, scoring_model, scoring_tokenizer, perturbed_sampled)
        # result
        results[idx]["original_ll"] = original_ll
        results[idx]["sampled_ll"] = sampled_ll
        results[idx]["all_perturbed_sampled_ll"] = p_sampled_ll
        results[idx]["all_perturbed_original_ll"] = p_original_ll
        results[idx]["perturbed_sampled_ll"] = np.mean(p_sampled_ll)
        results[idx]["perturbed_original_ll"] = np.mean(p_original_ll)
        results[idx]["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1
        results[idx]["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1

    # compute diffs with perturbed
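    # This is DetectGPT's normalized perturbation discrepancy (Mitchell et al., 2023):
    #   d(x) = (log p(x) - mean_i log p(x_i')) / std_i log p(x_i')
    # Model-generated text tends to sit near a local maximum of log p, so d(x)
    # is typically larger for sampled text than for human-written text.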
    predictions = {'real': [], 'samples': []}
    for res in results:
        if res['perturbed_original_ll_std'] == 0:
            res['perturbed_original_ll_std'] = 1
            print("WARNING: std of perturbed original is 0, setting to 1")
            print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}")
            print(f"Original text: {res['original']}")
        if res['perturbed_sampled_ll_std'] == 0:
            res['perturbed_sampled_ll_std'] = 1
            print("WARNING: std of perturbed sampled is 0, setting to 1")
            print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}")
            print(f"Sampled text: {res['sampled']}")
        predictions['real'].append((res['original_ll'] - res['perturbed_original_ll']) / res['perturbed_original_ll_std'])
        predictions['samples'].append((res['sampled_ll'] - res['perturbed_sampled_ll']) / res['perturbed_sampled_ll_std'])

    print(f"Real mean/std: {np.mean(predictions['real']):.2f}/{np.std(predictions['real']):.2f}, "
          f"Samples mean/std: {np.mean(predictions['samples']):.2f}/{np.std(predictions['samples']):.2f}")
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
    print(f"Criterion {name}_threshold ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")
    # results
    results_file = f'{args.output_file}.{name}.json'
    results = {
        'name': name,
        'info': {
            'pct_words_masked': args.pct_words_masked,
            'span_length': args.span_length,
            'n_perturbations': args.n_perturbations,
            'n_samples': n_samples,
        },
        'predictions': predictions,
        'raw_results': results,
        'metrics': {
            'roc_auc': roc_auc,
            'fpr': fpr,
            'tpr': tpr,
        },
        'pr_metrics': {
            'pr_auc': pr_auc,
            'precision': p,
            'recall': r,
        },
        'loss': 1 - pr_auc,
    }
    with open(results_file, 'w') as fout:
        json.dump(results, fout)
        print(f'Results written into {results_file}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', type=str, default="./exp_test/results/xsum_gpt2")
    parser.add_argument('--dataset', type=str, default="xsum")
    parser.add_argument('--dataset_file', type=str, default="./exp_test/data/xsum_gpt2")
    parser.add_argument('--pct_words_masked', type=float, default=0.3)  # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size))
    parser.add_argument('--mask_top_p', type=float, default=1.0)
    parser.add_argument('--span_length', type=int, default=2)
    parser.add_argument('--n_perturbations', type=int, default=10)
    parser.add_argument('--scoring_model_name', type=str, default="gpt2")
    parser.add_argument('--mask_filling_model_name', type=str, default="t5-small")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    experiment(args)
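
# Example invocation with the defaults above (the script filename is assumed here):
#   python detect_gpt.py --dataset xsum --dataset_file ./exp_test/data/xsum_gpt2 \
#       --output_file ./exp_test/results/xsum_gpt2 --scoring_model_name gpt2 \
#       --mask_filling_model_name t5-small --n_perturbations 10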