|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq |
|
|
from datasets import Dataset |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
|
|
|
# Authenticate with the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable rather than
# being hard-coded, so the secret never ends up in the source file.
# When HF_TOKEN is unset, ``token`` is None and ``login`` falls back to its
# own default behaviour (per the huggingface_hub docs: prompting for a token).
import os

login(token=os.environ.get("HF_TOKEN"))
|
|
|
|
|
|
|
|
# Identifier of the checkpoint that both the tokenizer and the model are
# loaded from.
base_model = "PerceptronAI/Isaac-0.1"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# The tokenization step pads sequences, which requires a pad token. Some
# causal-LM tokenizers ship without one, so reuse EOS in that case instead of
# failing later with a "no padding token" error.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    trust_remote_code=True,
    torch_dtype="auto",
)
|
|
|
|
|
|
|
|
# Prompt/response pairs that make up the training data, written as
# (input, output) tuples and expanded into the dict records below.
_examples = (
    ("Hello, who are you?", "I am Cass2.0, your AI assistant."),
    ("Tell me a joke.", "Why did the robot cross the road? To recharge itself!"),
    ("What's your purpose?", "I help you with answers, coding, and ideas as Cass2.0."),
)

# One record per pair, keyed by "input" and "output".
data = [{"input": prompt, "output": reply} for prompt, reply in _examples]

# Wrap the records in a `datasets.Dataset`.
dataset = Dataset.from_list(data)
|
|
|
|
|
|
|
|
def tokenize(batch, tok=None, max_length=128):
    """Turn a batch of prompt/response pairs into causal-LM training features.

    Each example becomes ONE sequence: prompt tokens, then response tokens,
    then EOS (when the tokenizer defines one). ``labels`` mirrors
    ``input_ids`` position-for-position, with prompt and padding positions
    set to -100 so the loss is only computed on the response.

    The previous version tokenized prompt and response separately and used
    the padded response ids as labels. That left labels misaligned with
    ``input_ids`` and included padding tokens in the loss.

    Args:
        batch: Mapping with "input" and "output" keys, each a list of strings
            of equal length (the shape produced by ``Dataset.map`` with
            ``batched=True``).
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.
        max_length: Fixed length every sequence is truncated/padded to.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels"; each value is a
        list holding one ``max_length``-long list of ints per example.

    Note:
        No separator or chat template is inserted between prompt and
        response. NOTE(review): consider the model's chat template if it
        defines one.
        Truncation keeps the left part of the sequence, so a prompt of
        ``max_length`` tokens or more leaves no supervised positions.
    """
    if tok is None:
        tok = tokenizer  # module-level tokenizer

    ignore_index = -100  # label value skipped by the loss
    eos_id = tok.eos_token_id
    pad_id = tok.pad_token_id
    if pad_id is None:
        # Padded positions are masked out anyway, so any valid id works here.
        pad_id = eos_id if eos_id is not None else 0

    prompt_batch = tok(batch["input"])["input_ids"]
    # Skip special tokens for the response so nothing like a BOS token gets
    # injected into the middle of the concatenated sequence.
    response_batch = tok(batch["output"], add_special_tokens=False)["input_ids"]

    features = {"input_ids": [], "attention_mask": [], "labels": []}
    for prompt_ids, response_ids in zip(prompt_batch, response_batch):
        prompt_ids = list(prompt_ids)
        target_ids = list(response_ids)
        if eos_id is not None:
            target_ids.append(eos_id)  # teach the model where to stop

        ids = (prompt_ids + target_ids)[:max_length]
        labels = ([ignore_index] * len(prompt_ids) + target_ids)[:max_length]
        n_pad = max_length - len(ids)

        features["input_ids"].append(ids + [pad_id] * n_pad)
        features["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        features["labels"].append(labels + [ignore_index] * n_pad)

    return features
|
|
|
|
|
# Run `tokenize` over the dataset in batched mode.
tokenized_dataset = dataset.map(function=tokenize, batched=True)

# Training configuration; keyword arguments are grouped by concern.
training_args = TrainingArguments(
    # Optimisation
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    fp16=True,
    # Output, checkpointing and logging
    output_dir="./cass2.0",
    save_steps=50,
    save_total_limit=2,
    logging_steps=10,
    # Hub publishing
    push_to_hub=True,
    hub_model_id="cass2.0",
)
|
|
|
|
|
# Collator that assembles tokenized examples into batches.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Wire the model, data, collator and configuration into a single Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    args=training_args,
)
|
|
|
|
|
|
|
|
# --- Train ---
print("🚀 Training Cass2.0...")
trainer.train()

# --- Save locally: model first, then tokenizer, into the same directory ---
for artifact in (model, tokenizer):
    artifact.save_pretrained("./cass2.0")
print("✅ Model saved locally in './cass2.0'")

# --- Publish ---
trainer.push_to_hub()
print("🌐 Model pushed to Hugging Face Hub as 'cass2.0'")
|
|
|