# CASS-2.0 / cass2.0_builder.py
# DSDUDEd's picture
# Update cass2.0_builder.py
# 0d95666 verified
# cass2.0_builder.py
import os

import torch
from datasets import Dataset
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
# 1️⃣ Hugging Face login
# Never commit an access token to source control. Export it in the shell
# instead (`export HF_TOKEN=hf_...`) and it is picked up here.
# When HF_TOKEN is unset, `token` is None and `login` falls back to asking
# for the token interactively.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
# 2️⃣ Load base model and tokenizer
base_model = "PerceptronAI/Isaac-0.1"

# NOTE(security): trust_remote_code=True runs Python code shipped inside the
# model repository. Only keep it enabled for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# Padding is requested later in this script; tokenizers that ship without a
# pad token would raise at that point, so reuse the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True, torch_dtype="auto")
# 3️⃣ Custom fine-tuning dataset (expandable)
# Each pair is (user prompt, expected assistant reply); append more pairs to
# grow the fine-tuning set.
_EXAMPLE_PAIRS = (
    ("Hello, who are you?", "I am Cass2.0, your AI assistant."),
    ("Tell me a joke.", "Why did the robot cross the road? To recharge itself!"),
    ("What's your purpose?", "I help you with answers, coding, and ideas as Cass2.0."),
)

# One record per pair, with the "input" / "output" keys the tokenization step reads.
data = [{"input": prompt, "output": reply} for prompt, reply in _EXAMPLE_PAIRS]
dataset = Dataset.from_list(data)
# 4️⃣ Tokenize dataset
def tokenize(batch, max_length=128):
    """Turn a batch of prompt/answer pairs into causal-LM training features.

    A causal language model is trained on a single token sequence whose
    labels line up position-by-position with ``input_ids``. Each example is
    therefore encoded as ``prompt tokens + answer tokens (+ EOS)``; the answer
    follows the prompt directly, with no separator inserted. Positions that
    belong to the prompt or to padding get the label ``-100`` so that only
    the answer tokens contribute to the loss.

    Args:
        batch: Mapping with parallel lists under the ``"input"`` and
            ``"output"`` keys (the shape ``Dataset.map(batched=True)`` passes).
        max_length: Fixed length every example is truncated or padded to.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
        list of ``max_length``-long integer lists.
    """
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Padded positions are masked out of both attention and loss, so the
        # exact filler id does not matter.
        pad_id = eos_id if eos_id is not None else 0

    features = {"input_ids": [], "attention_mask": [], "labels": []}
    for prompt, answer in zip(batch["input"], batch["output"]):
        prompt_ids = list(tokenizer(prompt)["input_ids"])
        # No special tokens here: the answer is a continuation, not a new sequence.
        answer_ids = list(tokenizer(answer, add_special_tokens=False)["input_ids"])
        if eos_id is not None:
            # Teach the model where the reply ends.
            answer_ids.append(eos_id)

        ids = (prompt_ids + answer_ids)[:max_length]
        labels = ([-100] * len(prompt_ids) + answer_ids)[:max_length]
        pad = max_length - len(ids)

        features["input_ids"].append(ids + [pad_id] * pad)
        features["attention_mask"].append([1] * len(ids) + [0] * pad)
        features["labels"].append(labels + [-100] * pad)
    return features


tokenized_dataset = dataset.map(
    tokenize,
    batched=True,
    # Drop the raw text columns so only model features reach the collator.
    remove_columns=dataset.column_names,
)
# 5️⃣ Training setup
# Hyper-parameters, checkpointing and Hub settings for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./cass2.0",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=50,
    save_total_limit=2,
    logging_steps=10,
    learning_rate=5e-5,
    # fp16 mixed precision needs a CUDA device; train in full precision otherwise.
    fp16=torch.cuda.is_available(),
    push_to_hub=True,  # Enable push to HF hub
    hub_model_id="cass2.0",  # Name of your model on Hugging Face
)
# Collator that assembles the tokenized examples into batches for the trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Collect the trainer's inputs in one place before constructing it.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": tokenized_dataset,
    "tokenizer": tokenizer,
    "data_collator": data_collator,
}
trainer = Trainer(**trainer_kwargs)
# 6️⃣ Train the model
print("🚀 Training Cass2.0...")
trainer.train()

# 7️⃣ Save locally
# The model and the tokenizer are written into the same folder.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./cass2.0")
print("✅ Model saved locally in './cass2.0'")

# 8️⃣ Push to Hugging Face Hub
trainer.push_to_hub()
print("🌐 Model pushed to Hugging Face Hub as 'cass2.0'")