#!/usr/bin/env python3
"""
Fine-tune Zephyr 7B on CyberSecurity Dataset Collection
Runs on Hugging Face Spaces infrastructure
"""
import os
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from huggingface_hub import login
# Configuration
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
OUTPUT_MODEL_NAME = "Jcalemcg/zephyr-7b-cybersecurity-finetuned"
# CyberSecurity datasets from thelordofweb collection
CYBERSECURITY_DATASETS = [
"AlicanKiraz0/All-CVE-Records-Training-Dataset",
"AlicanKiraz0/Cybersecurity-Dataset-v1",
"Bouquets/Cybersecurity-LLM-CVE",
"CyberNative/CyberSecurityEval",
"Mohabahmed03/Alpaca_Dataset_CyberSecurity_Smaller",
"CyberNative/github_cybersecurity_READMEs",
"AlicanKiraz0/Cybersecurity-Dataset-Heimdall-v1.1",
"jcordon5/cybersecurity-rules",
"Bouquets/DeepSeek-V3-Distill-Cybersecurity-en",
"Seerene/cybersecurity_dataset",
"ahmedds10/finetuning_alpaca_Cybersecurity",
"Tiamz/cybersecurity-instruction-dataset",
"OhWayTee/Cybersecurity-News_3",
"Trendyol/All-CVE-Chat-MultiTurn-1999-2025-Dataset",
"Vanessasml/cyber-reports-news-analysis-llama2-3k",
"Vanessasml/cybersecurity_32k_instruction_input_output",
"Vanessasml/enisa_cyber_news_dataset",
"Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
]
def format_instruction(example):
"""Format examples into Zephyr chat format"""
if "instruction" in example and "output" in example:
prompt = f"<|user|>\n{example['instruction']}"
        # Guard against records where "input" exists but is None
        if (example.get("input") or "").strip():
prompt += f"\n{example['input']}"
prompt += f"</s>\n<|assistant|>\n{example['output']}</s>"
return {"text": prompt}
elif "question" in example and "answer" in example:
return {"text": f"<|user|>\n{example['question']}</s>\n<|assistant|>\n{example['answer']}</s>"}
elif "prompt" in example and "completion" in example:
return {"text": f"<|user|>\n{example['prompt']}</s>\n<|assistant|>\n{example['completion']}</s>"}
elif "text" in example:
return {"text": example["text"]}
elif "messages" in example:
formatted_text = ""
for msg in example["messages"]:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "user":
formatted_text += f"<|user|>\n{content}</s>\n"
elif role == "assistant":
formatted_text += f"<|assistant|>\n{content}</s>\n"
return {"text": formatted_text}
return {"text": str(example)}
def load_datasets():
"""Load and prepare cybersecurity datasets"""
print("=" * 70)
print("LOADING CYBERSECURITY DATASETS")
print("=" * 70)
all_datasets = []
for dataset_name in CYBERSECURITY_DATASETS:
try:
print(f"\nLoading: {dataset_name}")
dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
formatted = dataset.map(
format_instruction,
remove_columns=dataset.column_names,
desc="Formatting"
)
if len(formatted) > 10000:
formatted = formatted.shuffle(seed=42).select(range(10000))
all_datasets.append(formatted)
print(f"βœ“ {len(formatted)} examples loaded")
except Exception as e:
print(f"βœ— Failed: {e}")
    if not all_datasets:
        raise RuntimeError("No datasets could be loaded; check network access and HF_TOKEN.")
    combined = concatenate_datasets(all_datasets)
print(f"\n{'='*70}")
print(f"TOTAL DATASET SIZE: {len(combined):,} examples")
print(f"{'='*70}\n")
combined = combined.shuffle(seed=42)
return combined.train_test_split(test_size=0.05, seed=42)
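# With 18 source datasets each capped at 10,000 examples, the combined corpus
# tops out around 180k rows; test_size=0.05 holds out roughly 5% of whatever
# actually loads for evaluation.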
def setup_model():
"""Setup model with QLoRA"""
print("Setting up Zephyr 7B with QLoRA...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)
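    # Attach LoRA adapters (rank 16, alpha 32) to every attention and MLP
    # projection; only these low-rank adapter weights are trained, while the
    # 4-bit base model stays frozen.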
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model, tokenizer
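# After training, the pushed LoRA adapter can be merged into the base model
# for standalone inference. A minimal sketch, assuming the adapter repo
# OUTPUT_MODEL_NAME exists on the Hub (run separately, not as part of this
# script):
#
#   from peft import AutoPeftModelForCausalLM
#   merged = AutoPeftModelForCausalLM.from_pretrained(
#       OUTPUT_MODEL_NAME, torch_dtype=torch.float16
#   ).merge_and_unload()
#   merged.save_pretrained("./zephyr-7b-cybersecurity-merged")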
def main():
print("\n" + "=" * 70)
print("ZEPHYR 7B CYBERSECURITY FINE-TUNING")
print("=" * 70 + "\n")
# Login to Hugging Face
hf_token = os.getenv("HF_TOKEN")
if hf_token:
print("Logging in to Hugging Face...")
login(token=hf_token)
print("βœ“ Logged in successfully\n")
else:
print("Warning: HF_TOKEN not found in environment")
# Load data
datasets = load_datasets()
train_data = datasets["train"]
eval_data = datasets["test"]
# Setup model
model, tokenizer = setup_model()
# Tokenize
print("\nTokenizing datasets...")
def tokenize(examples):
return tokenizer(examples["text"], truncation=True, max_length=2048, padding="max_length")
train_data = train_data.map(tokenize, batched=True, remove_columns=train_data.column_names)
eval_data = eval_data.map(tokenize, batched=True, remove_columns=eval_data.column_names)
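    # Note: the DataCollatorForLanguageModeling below copies input_ids into
    # labels and masks pad_token_id positions to -100. Since pad_token was set
    # to eos_token, the closing </s> tokens are masked out of the loss as well,
    # a known caveat of this padding setup.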
# Training config
training_args = TrainingArguments(
output_dir="./output",
num_train_epochs=3,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
fp16=True,
save_strategy="steps",
save_steps=500,
eval_strategy="steps",
eval_steps=500,
logging_steps=50,
warmup_steps=100,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
save_total_limit=3,
load_best_model_at_end=True,
push_to_hub=True,
hub_model_id=OUTPUT_MODEL_NAME,
hub_strategy="every_save",
report_to="tensorboard",
)
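    # Effective batch size: 4 per device x 4 gradient accumulation steps
    # = 16 sequences per optimizer update (per GPU).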
# Train
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=eval_data,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70 + "\n")
trainer.train()
print("\nSaving model...")
trainer.save_model()
model.push_to_hub(OUTPUT_MODEL_NAME)
tokenizer.push_to_hub(OUTPUT_MODEL_NAME)
print("\n" + "=" * 70)
print("βœ“ TRAINING COMPLETE")
print(f"βœ“ Model: {OUTPUT_MODEL_NAME}")
print("=" * 70)
if __name__ == "__main__":
main()
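# Usage sketch for the finished adapter (hedged: assumes the training run and
# Hub push above completed; the prompt and generation settings are
# illustrative):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   tok = AutoTokenizer.from_pretrained(MODEL_NAME)
#   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
#   model = PeftModel.from_pretrained(base, OUTPUT_MODEL_NAME)
#   prompt = "<|user|>\nExplain what SQL injection is.</s>\n<|assistant|>\n"
#   inputs = tok(prompt, return_tensors="pt").to(model.device)
#   print(tok.decode(model.generate(**inputs, max_new_tokens=256)[0]))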