# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import os.path as path
import glob
import json
import argparse
import numpy as np

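# The JSON files handled here come in two layouts:
#   1) a paired dict: {'original': [...], 'sampled': [...]}
#   2) a list of perturbation records, each with keys 'original', 'sampled',
#      'perturbed_original' and 'perturbed_sampled'
# stats_str() and convert_data() below support both.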
def stats_str(data):
    """Summarize the mean word counts of a dataset for logging."""
    if isinstance(data, dict):
        mean_orig = np.mean([len(v.split()) for v in data['original']])
        mean_samp = np.mean([len(v.split()) for v in data['sampled']])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled).'
    else:
        mean_orig = np.mean([len(v['original'].split()) for v in data])
        mean_samp = np.mean([len(v['sampled'].split()) for v in data])
        mean_perturb_orig = np.mean([np.mean([len(p.split()) for p in v['perturbed_original']]) for v in data])
        mean_perturb_samp = np.mean([np.mean([len(p.split()) for p in v['perturbed_sampled']]) for v in data])
        return f'{mean_orig:.0f} words (original), {mean_samp:.0f} words (sampled), ' \
               f'{mean_perturb_orig:.0f} words (perturb original), {mean_perturb_samp:.0f} words (perturb sampled).'
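
# Example (illustrative):
#   stats_str({'original': ['a b c'], 'sampled': ['d e']})
#   -> '3 words (original), 2 words (sampled).'
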
def save_data(output_file, args, data):
    # write args to file
    args_file = f"{output_file}.args.json"
    with open(args_file, "w") as fout:
        json.dump(args, fout, indent=4)
    print(f"Args written into {args_file}")
    # write the data to a json file in the save folder
    data_file = f"{output_file}.raw_data.json"
    with open(data_file, "w") as fout:
        json.dump(data, fout, indent=4)
    print(f"Raw data written into {data_file}: {stats_str(data)}")
def load_data(input_file):
    # load args from file
    args_file = f"{input_file}.args.json"
    with open(args_file, "r") as fin:
        args = json.load(fin)
    print(f"Args loaded from {args_file}")
    # load the data from file
    data_file = f"{input_file}.raw_data.json"
    with open(data_file, "r") as fin:
        data = json.load(fin)
    print(f"Raw data loaded from {data_file}: {stats_str(data)}")
    return args, data

def convert_data(input_file, output_file, max_words):
    """Truncate every text field of a dataset to at most max_words words."""
    def _reduce(text):
        # keep whole lines until the word budget is spent,
        # trimming the last kept line if it overflows
        lines = []
        nwords = 0
        for line in text.split('\n'):
            if nwords >= max_words:
                break
            words = line.split()
            words = words[:max_words - nwords]
            lines.append(' '.join(words))
            nwords += len(words)
        return '\n'.join(lines)
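
    # Example (illustrative): with max_words=5,
    #   _reduce('one two three\nfour five six') -> 'one two three\nfour five'
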
    args, data = load_data(input_file)
    if isinstance(data, dict):
        data['original'] = [_reduce(x) for x in data['original']]
        data['sampled'] = [_reduce(x) for x in data['sampled']]
    else:
        for item in data:
            item['original'] = _reduce(item['original'])
            item['sampled'] = _reduce(item['sampled'])
            item['perturbed_original'] = [_reduce(x) for x in item['perturbed_original']]
            item['perturbed_sampled'] = [_reduce(x) for x in item['perturbed_sampled']]
    save_data(output_file, args, data)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="./exp_gpt3to4/data/")
    parser.add_argument('--output_path', type=str, default="./exp_maxlen150/data/")
    parser.add_argument('--max_words', type=int, default=150)
    args = parser.parse_args()

    # make sure the output folder exists before writing into it
    os.makedirs(args.output_path, exist_ok=True)
    # convert every dataset found in the input folder
    for file_name in glob.glob(f'{args.input_path}/*.raw_data.json'):
        print(file_name)
        file_name = path.basename(file_name).replace('.raw_data.json', '')
        convert_data(path.join(args.input_path, file_name), path.join(args.output_path, file_name), args.max_words)
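
# Typical invocation (illustrative; assumes this script is saved as data_converter.py):
#   python data_converter.py --input_path ./exp_gpt3to4/data/ \
#       --output_path ./exp_maxlen150/data/ --max_words 150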