File size: 3,889 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
import random
from typing import Dict, Iterator
import os

def get_tokenizer(model_name: str = "gpt2"):
    """Return the HF tokenizer for *model_name*, guaranteeing a pad token.

    GPT-2-style models ship without a pad token; reuse the EOS token in that
    case so downstream padding code always has a valid ``pad_token_id``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok

class StreamingDataset(IterableDataset):
    """Tokenize a (possibly streaming) dataset into fixed-length training pairs.

    Each yielded item is a ``(input_ids, labels)`` tuple of ``max_seq_len``
    tokens: sequences are truncated or right-padded with the tokenizer's pad
    token, and padding positions in ``labels`` are set to -100 so the loss
    ignores them.  A small buffer provides local shuffling of the stream.
    """

    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain", buffer_size=500, skip_samples=0):
        # dataset: iterable of dict examples with a 'text' and/or 'messages' key
        # mode: "pretrain" reads example['text']; any other mode formats chat turns
        # buffer_size: size of the local shuffle buffer
        # skip_samples: global number of yielded samples to skip (resume support)
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mode = mode
        self.buffer_size = buffer_size
        self.skip_samples = skip_samples

    def _prepare_sft_text(self, example):
        """Flatten a chat-style example into a single training string.

        Assistant turns are terminated with the EOS token; all other roles are
        rendered as "User:" lines.  Falls back to a raw 'text' field, else "".
        """
        if 'messages' in example:
            text = ""
            for msg in example['messages']:
                role = msg.get('role', '')
                content = msg.get('content', '')
                if role == 'assistant':
                    text += f"Assistant: {content}{self.tokenizer.eos_token}"
                else:
                    text += f"User: {content}\n"
            return text
        elif 'text' in example:
            return example['text']
        else:
            return ""

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Shard the stream across DataLoader workers.  Without this, every
        # worker replays the full stream and each sample is duplicated
        # num_workers times in the training batches.
        from torch.utils.data import get_worker_info
        info = get_worker_info()
        num_shards = info.num_workers if info is not None else 1
        shard_id = info.id if info is not None else 0

        # Give this worker its share of the global resume-skip (round-robin,
        # matching the sharding below; approximate under shuffling), then
        # clear it so later epochs in the same process don't skip again.
        skip = self.skip_samples // num_shards
        if shard_id < self.skip_samples % num_shards:
            skip += 1
        self.skip_samples = 0

        buffer = []
        for idx, example in enumerate(self.dataset):
            if idx % num_shards != shard_id:
                continue

            text = (example.get('text', '') if self.mode == "pretrain"
                   else self._prepare_sft_text(example))

            # Drop near-empty / junk examples.
            if len(text) < 10:
                continue

            enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                               return_tensors="pt")
            input_ids = enc['input_ids'][0]

            if len(input_ids) < 2:
                continue

            pad_len = 0
            if len(input_ids) < self.max_seq_len:
                pad_len = self.max_seq_len - len(input_ids)
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])

            labels = input_ids.clone()
            # BUG FIX: the original re-tested len(input_ids) < max_seq_len
            # *after* padding (always False), so pad positions were never
            # masked and the model was trained to predict pad/eos filler.
            if pad_len > 0:
                labels[-pad_len:] = -100

            buffer.append((input_ids, labels))

            # Local shuffle: once full, emit half the buffer in random order
            # and keep refilling it with fresh examples.
            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                for _ in range(self.buffer_size // 2):
                    item = buffer.pop()
                    if skip > 0:
                        skip -= 1
                        continue
                    yield item

        # Drain whatever remains at end of stream.
        random.shuffle(buffer)
        while buffer:
            item = buffer.pop()
            if skip > 0:
                skip -= 1
                continue
            yield item

def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size, mode="pretrain", hf_token=None, start_step=0):
    """Build a DataLoader over a streaming HF dataset, optionally resuming.

    ``start_step`` fast-forwards the stream by ``start_step * batch_size``
    samples so training can resume mid-epoch without replaying data.
    NOTE(review): ``trust_remote_code=True`` executes dataset-repo code —
    only use with dataset names from trusted sources.
    """
    stream = load_dataset(dataset_name, split=split, streaming=True,
                          trust_remote_code=True, token=hf_token)

    # Samples already consumed before the resume step.
    skip = start_step * batch_size
    if skip:
        print(f"  [Loader] Resuming: Fast-forwarding dataset by {skip} samples...")

    dataset = StreamingDataset(stream, tokenizer, config.max_seq_len,
                               mode=mode, skip_samples=skip)

    # Several workers prefetch in parallel to keep the accelerator fed.
    return DataLoader(dataset, batch_size=batch_size, pin_memory=True,
                      num_workers=4, prefetch_factor=4)