import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer
from datasets import load_dataset
import random
from typing import Dict, Iterator
def get_tokenizer(model_name: str = "gpt2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2-style tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
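
# Quick check (a sketch; fetches the GPT-2 tokenizer from the Hub):
#   tok = get_tokenizer("gpt2")
#   assert tok.pad_token == tok.eos_token == "<|endoftext|>"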

class StreamingDataset(IterableDataset):
    """Tokenizes a streaming dataset on the fly, with a small shuffle buffer."""

    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain",
                 buffer_size=500, skip_samples=0):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mode = mode                  # "pretrain" (raw text) or SFT (chat messages)
        self.buffer_size = buffer_size    # shuffle-buffer capacity
        self.skip_samples = skip_samples  # samples to drop when resuming a run
    def _prepare_sft_text(self, example):
        # Flatten a chat-format example into one training string; assistant
        # turns are terminated with EOS so the model learns where to stop.
        if 'messages' in example:
            text = ""
            for msg in example['messages']:
                role = msg.get('role', '')
                content = msg.get('content', '')
                if role == 'assistant':
                    text += f"Assistant: {content}{self.tokenizer.eos_token}"
                else:
                    text += f"User: {content}\n"
            return text
        elif 'text' in example:
            return example['text']
        else:
            return ""
    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Shard the stream across DataLoader workers: without this, every
        # worker would replay the full stream and each sample would be
        # duplicated num_workers times.
        worker = get_worker_info()
        num_workers = worker.num_workers if worker is not None else 1
        worker_id = worker.id if worker is not None else 0
        # Each worker skips its share of already-consumed samples. Resume is
        # approximate, since the shuffle buffer reorders samples anyway.
        skip_remaining = self.skip_samples // num_workers

        buffer = []
        for i, example in enumerate(self.dataset):
            if i % num_workers != worker_id:
                continue
            text = (example.get('text', '') if self.mode == "pretrain"
                    else self._prepare_sft_text(example))
            if len(text) < 10:  # drop near-empty documents
                continue
            enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                                 return_tensors="pt")
            input_ids = enc['input_ids'][0]
            if len(input_ids) < 2:
                continue
            # Right-pad to a fixed length, then mask the padding out of the loss.
            pad_len = self.max_seq_len - len(input_ids)
            if pad_len > 0:
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])
            labels = input_ids.clone()
            if pad_len > 0:
                labels[-pad_len:] = -100  # ignored by the cross-entropy loss
            buffer.append({"input_ids": input_ids, "labels": labels})
            # Once the buffer fills, shuffle it and drain half before refilling.
            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                for _ in range(self.buffer_size // 2):
                    item = buffer.pop()
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        continue
                    yield item
        # Drain whatever remains when the stream ends.
        random.shuffle(buffer)
        while buffer:
            item = buffer.pop()
            if skip_remaining > 0:
                skip_remaining -= 1
                continue
            yield item
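
# Sanity-check sketch ("wikitext" is an illustrative dataset choice, not the
# one this project trains on; assumes network access to the Hub):
#   raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", streaming=True)
#   ds = StreamingDataset(raw, get_tokenizer(), max_seq_len=128)
#   sample = next(iter(ds))
#   sample["input_ids"].shape  # torch.Size([128])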

def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size,
                            mode="pretrain", hf_token=None, start_step=0):
    raw_dataset = load_dataset(dataset_name, split=split, streaming=True,
                               trust_remote_code=True, token=hf_token)
    # When resuming, skip the samples already consumed: start_step * batch_size.
    skip_samples = start_step * batch_size
    if skip_samples > 0:
        print(f" [Loader] Resuming: fast-forwarding dataset by {skip_samples} samples...")
    stream_ds = StreamingDataset(raw_dataset, tokenizer, config.max_seq_len,
                                 mode=mode, skip_samples=skip_samples)
    # Multiple workers improve utilization; StreamingDataset shards the stream
    # per worker so samples are not duplicated across them.
    return DataLoader(stream_ds, batch_size=batch_size, pin_memory=True,
                      num_workers=4, prefetch_factor=4)
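
# Example wiring (a sketch, not part of the training entry point): the config
# object, batch size, and dataset name below are illustrative stand-ins.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(max_seq_len=512)  # hypothetical config object
    tokenizer = get_tokenizer("gpt2")
    loader = create_streaming_loader(
        "HuggingFaceFW/fineweb",               # illustrative dataset name
        split="train",
        tokenizer=tokenizer,
        config=config,
        batch_size=8,
        mode="pretrain",
    )
    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([8, 512])
    print(batch["labels"].shape)     # torch.Size([8, 512])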