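"""Streaming data pipeline for Aetheris.

Wraps a Hugging Face streaming dataset in an IterableDataset that tokenizes,
pads, shuffle-buffers, and can fast-forward the stream when resuming training.
"""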
import itertools
import random
from typing import Dict, Iterator

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer
from datasets import load_dataset


def get_tokenizer(model_name: str = "gpt2"):
    """Load a tokenizer, falling back to EOS as the pad token (GPT-2 has none)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


class StreamingDataset(IterableDataset):
    """Tokenizes a streamed dataset on the fly with a small shuffle buffer."""

    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain",
                 buffer_size=500, skip_samples=0):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mode = mode                  # "pretrain" (raw text) or SFT (chat messages)
        self.buffer_size = buffer_size    # size of the in-memory shuffle buffer
        self.skip_samples = skip_samples  # samples to drop when resuming a run

    def _prepare_sft_text(self, example):
        """Flatten a chat-style example into a single training string."""
        if 'messages' in example:
            text = ""
            for msg in example['messages']:
                role = msg.get('role', '')
                content = msg.get('content', '')
                if role == 'assistant':
                    text += f"Assistant: {content}{self.tokenizer.eos_token}"
                else:
                    text += f"User: {content}\n"
            return text
        elif 'text' in example:
            return example['text']
        else:
            return ""

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        iterator = iter(self.dataset)
        skip = self.skip_samples
        worker_info = get_worker_info()
        if worker_info is not None:
            # Shard the stream across DataLoader workers: each worker takes
            # every num_workers-th example, otherwise every worker would
            # yield an identical copy of the data.
            iterator = itertools.islice(iterator, worker_info.id, None,
                                        worker_info.num_workers)
            # Split the resume skip across workers (approximate split).
            skip = self.skip_samples // worker_info.num_workers

        buffer = []
        for example in iterator:
            text = (example.get('text', '') if self.mode == "pretrain"
                    else self._prepare_sft_text(example))
            if len(text) < 10:
                continue
            enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                                 return_tensors="pt")
            input_ids = enc['input_ids'][0]
            if len(input_ids) < 2:
                continue
            pad_len = self.max_seq_len - len(input_ids)
            if pad_len > 0:
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])
            labels = input_ids.clone()
            if pad_len > 0:
                # Mask the padding so it does not contribute to the loss.
                labels[-pad_len:] = -100
            buffer.append({'input_ids': input_ids, 'labels': labels})
            # Shuffle buffer: once full, drain half of it in random order.
            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                for _ in range(self.buffer_size // 2):
                    item = buffer.pop()
                    if skip > 0:
                        # Resuming: drop samples that were already consumed.
                        skip -= 1
                        continue
                    yield item
        # Drain whatever is left in the buffer.
        random.shuffle(buffer)
        while buffer:
            item = buffer.pop()
            if skip > 0:
                skip -= 1
                continue
            yield item


def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size,
                            mode="pretrain", hf_token=None, start_step=0):
    raw_dataset = load_dataset(dataset_name, split=split, streaming=True,
                               trust_remote_code=True, token=hf_token)
    # Resuming from step N means N * batch_size samples were already consumed.
    skip_samples = start_step * batch_size
    if skip_samples > 0:
        print(f"  [Loader] Resuming: fast-forwarding dataset by {skip_samples} samples...")
    stream_ds = StreamingDataset(raw_dataset, tokenizer, config.max_seq_len,
                                 mode=mode, skip_samples=skip_samples)
    # Several workers keep the GPU fed; StreamingDataset shards the stream
    # so they do not yield duplicate samples.
    return DataLoader(stream_ds, batch_size=batch_size, pin_memory=True,
                      num_workers=4, prefetch_factor=4)
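

if __name__ == "__main__":
    # Quick smoke-test / usage sketch. Assumptions not defined elsewhere in
    # this file: a minimal Config carrying max_seq_len, and the public
    # "roneneldan/TinyStories" dataset standing in for the real corpus.
    from dataclasses import dataclass

    @dataclass
    class Config:
        max_seq_len: int = 512

    tokenizer = get_tokenizer("gpt2")
    loader = create_streaming_loader("roneneldan/TinyStories", "train",
                                     tokenizer, Config(), batch_size=8)
    batch = next(iter(loader))
    print(batch['input_ids'].shape)  # expected: torch.Size([8, 512])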