""" |
|
|
Training script for Vietnamese text classification. |
|
|
Supports both VNTC (news) and UTS2017_Bank (banking) datasets. |
|
|
This script trains a TF-IDF + Logistic Regression model on Vietnamese text classification datasets. |
|
|
""" |
import argparse
import json
import logging
import os
import time
import zipfile
from datetime import datetime
from io import BytesIO

import joblib
import numpy as np
import requests
from datasets import load_dataset
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier, RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


def setup_logging(run_name):
    """Set up logging so that all run information is saved to the runs folder."""
    runs_dir = "runs"
    os.makedirs(runs_dir, exist_ok=True)

    run_dir = os.path.join(runs_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)

    log_file = os.path.join(run_dir, "training.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
        force=True,  # reset handlers from any previous run so each run logs to its own file (Python 3.8+)
    )

    return run_dir


def load_vntc_data(split_ratio=0.2, random_state=42, n_samples=None):
    """Load and prepare the VNTC dataset.

    Args:
        split_ratio: Not used for VNTC (it has a predefined train/test split)
        random_state: Not used for VNTC (it has a predefined train/test split)
        n_samples: Optional limit on the number of samples

    Returns:
        Tuple of (X_train, y_train), (X_test, y_test)
    """
    print("Loading VNTC dataset...")

    dataset_folder = os.path.expanduser("~/.underthesea/VNTC")
    os.makedirs(dataset_folder, exist_ok=True)

    train_file = os.path.join(dataset_folder, "train.txt")
    test_file = os.path.join(dataset_folder, "test.txt")

    # Download and extract the dataset if it is not cached locally.
    if not os.path.exists(train_file) or not os.path.exists(test_file):
        print("Downloading VNTC dataset...")
        url = "https://github.com/undertheseanlp/underthesea/releases/download/resources/VNTC.zip"

        response = requests.get(url)
        response.raise_for_status()  # fail early on a bad download instead of extracting a broken archive
        with zipfile.ZipFile(BytesIO(response.content)) as zip_file:
            zip_file.extractall(dataset_folder)
        print("Dataset downloaded and extracted.")

    X_train = []
    y_train = []
    with open(train_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                parts = line.strip().split(" ", 1)
                if len(parts) == 2:
                    label = parts[0].replace("__label__", "")
                    text = parts[1]
                    y_train.append(label)
                    X_train.append(text)

    X_test = []
    y_test = []
    with open(test_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                parts = line.strip().split(" ", 1)
                if len(parts) == 2:
                    label = parts[0].replace("__label__", "")
                    text = parts[1]
                    y_test.append(label)
                    X_test.append(text)

    # Optionally subsample both splits with a fixed seed for reproducibility.
    if n_samples:
        if n_samples < len(X_train):
            X_train_array = np.array(X_train)
            y_train_array = np.array(y_train)
            indices = np.arange(len(X_train))

            np.random.seed(42)
            shuffled_indices = np.random.permutation(indices)

            sample_indices = shuffled_indices[:n_samples]
            X_train = X_train_array[sample_indices].tolist()
            y_train = y_train_array[sample_indices].tolist()

        if n_samples < len(X_test):
            X_test_array = np.array(X_test)
            y_test_array = np.array(y_test)
            indices = np.arange(len(X_test))

            np.random.seed(42)
            shuffled_indices = np.random.permutation(indices)

            sample_indices = shuffled_indices[:n_samples]
            X_test = X_test_array[sample_indices].tolist()
            y_test = y_test_array[sample_indices].tolist()

    X_train = np.array(X_train)
    y_train = np.array(y_train)
    X_test = np.array(X_test)
    y_test = np.array(y_test)

    print(f"Dataset loaded: {len(X_train)} train samples, {len(X_test)} test samples")
    print(f"Number of unique labels: {len(set(y_train))}")

    return (X_train, y_train), (X_test, y_test)
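
# Note: the loader above expects fastText-style lines, one document per line,
# with the label prefixed by "__label__". For example (category names are
# illustrative):
#
#   __label__the_thao Đội tuyển quốc gia thắng trận mở màn ...
#   __label__kinh_doanh Giá vàng trong nước tiếp tục tăng ...

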
def load_uts2017_data(split_ratio=0.2, random_state=42, n_samples=None):
    """Load and prepare the UTS2017_Bank classification dataset.

    Args:
        split_ratio: Ratio for the train/test split
        random_state: Random seed for reproducibility
        n_samples: Optional limit on the number of samples

    Returns:
        Tuple of (X_train, y_train), (X_test, y_test)
    """
    print("Loading UTS2017_Bank dataset from Hugging Face...")

    dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")

    train_data = dataset["train"]

    texts = train_data["text"]
    labels = train_data["label"]

    if n_samples and n_samples < len(texts):
        texts = texts[:n_samples]
        labels = labels[:n_samples]

    X = np.array(texts)
    y = np.array(labels)

    min_samples_per_class = 2
    unique_classes, class_counts = np.unique(y, return_counts=True)
    can_stratify = all(count >= min_samples_per_class for count in class_counts)

    if can_stratify:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=split_ratio, random_state=random_state, stratify=y
        )
    else:
        print(
            f"Warning: Some classes have fewer than {min_samples_per_class} samples. Disabling stratification."
        )
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=split_ratio, random_state=random_state
        )

    print(f"Dataset loaded: {len(X_train)} train samples, {len(X_test)} test samples")
    print(f"Number of unique labels: {len(set(y))}")

    return (X_train, y_train), (X_test, y_test)
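
# The loader above assumes the Hugging Face dataset exposes "text" and "label"
# columns. A quick way to double-check the schema before training, using the
# standard `datasets` API:
#
#   from datasets import load_dataset
#   ds = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
#   print(ds["train"].features)   # column names and types
#   print(ds["train"][0])         # one raw example

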
def get_available_models():
    """Get available classifier options."""
    return {
        "logistic": LogisticRegression(max_iter=1000, random_state=42),
        "svc_linear": SVC(kernel="linear", random_state=42, probability=True),
        "svc_rbf": SVC(kernel="rbf", random_state=42, probability=True, gamma="scale"),
        "naive_bayes": MultinomialNB(),

        "decision_tree": DecisionTreeClassifier(random_state=42, max_depth=10),
        "random_forest": RandomForestClassifier(n_estimators=100, random_state=42, max_depth=10, n_jobs=-1),

        "gradient_boost": GradientBoostingClassifier(n_estimators=100, random_state=42, max_depth=5),
        "ada_boost": AdaBoostClassifier(n_estimators=100, random_state=42),

        "mlp": MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42, early_stopping=True),
    }
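
# To experiment with another scikit-learn classifier, add an entry to the dict
# above and to the --model/--compare-models choices in main(). For example
# (illustrative only, not currently wired into the CLI):
#
#   from sklearn.ensemble import ExtraTreesClassifier
#   "extra_trees": ExtraTreesClassifier(n_estimators=100, random_state=42, n_jobs=-1),

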
def train_model(
    dataset="uts2017",
    model_name="logistic",
    max_features=20000,
    ngram_range=(1, 2),
    split_ratio=0.2,
    n_samples=None,
    export_model=False,
):
    """Train a single model with the specified parameters.

    Args:
        dataset: Dataset to use ('vntc' or 'uts2017')
        model_name: Name of the model to train (one of the keys returned by get_available_models())
        max_features: Maximum number of features for the TF-IDF vectorizer
        ngram_range: N-gram range for feature extraction
        split_ratio: Train/test split ratio
        n_samples: Optional limit on the number of samples
        export_model: If True, also export a copy of the trained model to the project root

    Returns:
        Dictionary containing training results
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = setup_logging(timestamp)

    logging.info(f"Starting training run: {timestamp}")
    logging.info(f"Model: {model_name}")
    logging.info(f"Max features: {max_features}")
    logging.info(f"N-gram range: {ngram_range}")
    if n_samples:
        logging.info(f"Sample limit: {n_samples}")

    output_folder = os.path.join(run_dir, "models")
    os.makedirs(output_folder, exist_ok=True)

    # Load the requested dataset.
    if dataset == "vntc":
        logging.info("Loading VNTC dataset...")
        (X_train, y_train), (X_test, y_test) = load_vntc_data(
            split_ratio=split_ratio, n_samples=n_samples
        )
        dataset_name = "VNTC"
    else:
        logging.info("Loading UTS2017_Bank dataset...")
        (X_train, y_train), (X_test, y_test) = load_uts2017_data(
            split_ratio=split_ratio, n_samples=n_samples
        )
        dataset_name = "UTS2017_Bank"

    # Convert labels to native Python types so the metadata below stays JSON-serializable.
    unique_labels = sorted(set(y_train.tolist()))
    label_counts_train = {label: int(np.sum(y_train == label)) for label in unique_labels}
    label_counts_test = {label: int(np.sum(y_test == label)) for label in unique_labels}

    logging.info(f"Train samples: {len(X_train)}")
    logging.info(f"Test samples: {len(X_test)}")
    logging.info(f"Unique labels: {len(unique_labels)}")
    logging.info(f"Label distribution (train): {label_counts_train}")
    logging.info(f"Label distribution (test): {label_counts_test}")

    # Resolve the classifier.
    available_models = get_available_models()
    if model_name not in available_models:
        raise ValueError(
            f"Model '{model_name}' not available. Choose from: {list(available_models.keys())}"
        )

    classifier = available_models[model_name]
    clf_name = classifier.__class__.__name__
    logging.info(f"Selected classifier: {clf_name}")

    config_name = f"{dataset_name}_{clf_name}_feat{max_features // 1000}k_ngram{ngram_range[0]}-{ngram_range[1]}"

    logging.info("=" * 60)
    logging.info(f"Training: {config_name}")
    logging.info("=" * 60)

    logging.info(
        f"Creating pipeline with max_features={max_features}, ngram_range={ngram_range}"
    )

    text_clf = Pipeline(
        [
            (
                "vect",
                CountVectorizer(max_features=max_features, ngram_range=ngram_range),
            ),
            ("tfidf", TfidfTransformer(use_idf=True)),
            ("clf", classifier),
        ]
    )

    # Train.
    logging.info("Training model...")
    start_time = time.time()
    text_clf.fit(X_train, y_train)
    train_time = time.time() - start_time
    logging.info(f"Training completed in {train_time:.2f} seconds")

    # Evaluate on the training set.
    logging.info("Evaluating on training set...")
    train_predictions = text_clf.predict(X_train)
    train_accuracy = accuracy_score(y_train, train_predictions)
    logging.info(f"Training accuracy: {train_accuracy:.4f}")

    # Evaluate on the test set.
    logging.info("Evaluating on test set...")
    start_time = time.time()
    test_predictions = text_clf.predict(X_test)
    test_accuracy = accuracy_score(y_test, test_predictions)
    prediction_time = time.time() - start_time
    logging.info(f"Test accuracy: {test_accuracy:.4f}")
    logging.info(f"Prediction time: {prediction_time:.2f} seconds")

    logging.info("Classification Report:")
    report = classification_report(y_test, test_predictions, zero_division=0)
    logging.info(report)
    print("\nClassification Report:")
    print(report)

    report_dict = classification_report(
        y_test, test_predictions, zero_division=0, output_dict=True
    )

    cm = confusion_matrix(y_test, test_predictions, labels=unique_labels)
    logging.info(f"Confusion Matrix shape: {cm.shape}")

    # Save the trained pipeline.
    model_path = os.path.join(output_folder, "model.joblib")
    joblib.dump(text_clf, model_path)
    logging.info(f"Model saved to {model_path}")
    print(f"Model saved to {model_path}")

    config_model_path = os.path.join(output_folder, f"{config_name}.joblib")
    joblib.dump(text_clf, config_model_path)
    logging.info(f"Model also saved as {config_model_path}")

    if export_model:
        run_id = os.path.basename(run_dir)
        export_filename = f"{dataset_name.lower()}_classifier_{run_id}.joblib"
        export_path = os.path.join(".", export_filename)
        joblib.dump(text_clf, export_path)
        logging.info(f"Model exported as {export_path}")
        print(f"Model exported for distribution: {export_filename}")

    # Save the label mapping.
    label_mapping_path = os.path.join(output_folder, "labels.txt")
    with open(label_mapping_path, "w", encoding="utf-8") as f:
        for label in unique_labels:
            f.write(f"{label}\n")
    logging.info(f"Label mapping saved to {label_mapping_path}")

    # Save run metadata.
    metadata = {
        "timestamp": timestamp,
        "config_name": config_name,
        "model_name": model_name,
        "classifier": clf_name,
        "max_features": max_features,
        "ngram_range": list(ngram_range),
        "split_ratio": split_ratio,
        "n_samples": n_samples,
        "train_samples": len(X_train),
        "test_samples": len(X_test),
        "unique_labels": len(unique_labels),
        "labels": unique_labels,
        "train_accuracy": float(train_accuracy),
        "test_accuracy": float(test_accuracy),
        "train_time": train_time,
        "prediction_time": prediction_time,
        "classification_report": report_dict,
        "confusion_matrix": cm.tolist(),
    }

    metadata_path = os.path.join(run_dir, "metadata.json")
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    logging.info(f"Metadata saved to {metadata_path}")

    # Print a summary.
    print("\n" + "=" * 60)
    print("Training Summary")
    print("=" * 60)
    print(f"Model: {clf_name}")
    print(f"Training samples: {len(X_train)}")
    print(f"Test samples: {len(X_test)}")
    print(f"Number of classes: {len(unique_labels)}")
    print(f"Training accuracy: {train_accuracy:.4f}")
    print(f"Test accuracy: {test_accuracy:.4f}")
    print(f"Training time: {train_time:.2f} seconds")
    print(f"Model saved to: {model_path}")
    print("=" * 60)

    return metadata
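
# A minimal sketch of reusing a trained pipeline from a finished run. The path
# follows the runs/<timestamp>/models layout created above (the timestamp and
# the example sentence are placeholders):
#
#   import joblib
#   clf = joblib.load("runs/20240101_120000/models/model.joblib")
#   print(clf.predict(["Ngân hàng thông báo lãi suất tiết kiệm mới"]))

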
def train_all_configurations(dataset="vntc", models=None, num_rows=None):
    """Train multiple model configurations and compare the results."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = setup_logging(timestamp)

    logging.info(f"Starting comparison run: {timestamp}")
    logging.info(f"Dataset: {dataset}")
    if num_rows:
        logging.info(f"Sample limit: {num_rows}")

    if models is None:
        # Default to comparing every registered classifier.
        available_models = get_available_models()
        models = list(available_models.keys())

    logging.info(f"Models to compare: {models}")

    # The heavier models get a smaller feature budget.
    configurations = []
    for model_name in models:
        if model_name in ["svc_rbf", "gradient_boost", "ada_boost", "mlp"]:
            configurations.append({
                "dataset": dataset,
                "model_name": model_name,
                "max_features": 10000,
                "ngram_range": (1, 2),
                "n_samples": num_rows,
            })
        else:
            configurations.append({
                "dataset": dataset,
                "model_name": model_name,
                "max_features": 20000,
                "ngram_range": (1, 2),
                "n_samples": num_rows,
            })

    results = []

    for config in configurations:
        print(f"\nTraining configuration: {config}")
        try:
            result = train_model(**config)
            results.append(result)
        except Exception as e:
            logging.error(f"Failed to train with config {config}: {e}")
            print(f"Error training configuration: {e}")

    comparison_path = os.path.join(run_dir, "comparison_results.json")
    with open(comparison_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    if not results:
        print("No configuration finished successfully; nothing to compare.")
        return results

    print("\n" + "=" * 80)
    print("Model Comparison Results")
    print("=" * 80)
    print(
        f"{'Model':<10} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} {'Test Acc':<12}"
    )
    print("-" * 80)

    for result in sorted(results, key=lambda x: x["test_accuracy"], reverse=True):
        model = result["classifier"][:8]
        features = f"{result['max_features'] // 1000}k"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = result["train_accuracy"]
        test_acc = result["test_accuracy"]
        print(
            f"{model:<10} {features:<10} {ngram:<10} {train_acc:<12.4f} {test_acc:<12.4f}"
        )

    print("=" * 80)

    best_model = max(results, key=lambda x: x["test_accuracy"])
    print(f"\nBest model: {best_model['config_name']}")
    print(f"Test accuracy: {best_model['test_accuracy']:.4f}")

    return results
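
# The comparison summary is also persisted as JSON, so a finished run can be
# inspected after the fact, e.g.:
#
#   import json
#   with open("runs/20240101_120000/comparison_results.json", encoding="utf-8") as f:
#       results = json.load(f)
#   print(max(results, key=lambda r: r["test_accuracy"])["config_name"])
#
# (The timestamped run folder name above is a placeholder.)

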
def train_notebook(dataset="uts2017", model_name="logistic", max_features=20000, ngram_min=1, ngram_max=2,
                   split_ratio=0.2, n_samples=None, compare=False, export_model=False):
    """
    Convenience function for training in Jupyter/Colab notebooks without argparse.

    Example usage:
        from train import train_notebook
        train_notebook(dataset="vntc", model_name="logistic", max_features=20000, export_model=True)
    """
    if compare:
        print("Training and comparing multiple configurations...")
        return train_all_configurations(dataset=dataset, num_rows=n_samples)
    else:
        dataset_name = "VNTC" if dataset == "vntc" else "UTS2017_Bank"
        print(f"Training {model_name} model on {dataset_name} dataset...")
        print(f"Configuration: max_features={max_features}, ngram=({ngram_min}, {ngram_max})")

        return train_model(
            dataset=dataset,
            model_name=model_name,
            max_features=max_features,
            ngram_range=(ngram_min, ngram_max),
            split_ratio=split_ratio,
            n_samples=n_samples,
            export_model=export_model,
        )


def main():
    """Main function with argument parsing."""
    import sys

    # Detect interactive environments (Jupyter/Colab) so stray kernel arguments can be ignored.
    in_notebook = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules or 'google.colab' in sys.modules

    parser = argparse.ArgumentParser(
        description="Train Vietnamese text classification model on VNTC or UTS2017_Bank dataset"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["vntc", "uts2017"],
        default="uts2017",
        help="Dataset to use for training (default: uts2017)",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["logistic", "svc_linear", "svc_rbf", "naive_bayes", "decision_tree", "random_forest", "gradient_boost", "ada_boost", "mlp"],
        default="logistic",
        help="Model type to train (default: logistic)",
    )
    parser.add_argument(
        "--max-features",
        type=int,
        default=20000,
        help="Maximum number of features for TF-IDF (default: 20000)",
    )
    parser.add_argument(
        "--ngram-min", type=int, default=1, help="Minimum n-gram size (default: 1)"
    )
    parser.add_argument(
        "--ngram-max", type=int, default=2, help="Maximum n-gram size (default: 2)"
    )
    parser.add_argument(
        "--split-ratio", type=float, default=0.2, help="Test split ratio (default: 0.2)"
    )
    parser.add_argument(
        "--num-rows",
        type=int,
        default=None,
        help="Limit number of rows/samples for quick testing (default: None - use all data)",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Train and compare multiple configurations",
    )
    parser.add_argument(
        "--compare-models",
        nargs="+",
        help="List of specific models to compare (e.g., --compare-models logistic random_forest svc_rbf)",
        choices=["logistic", "svc_linear", "svc_rbf", "naive_bayes", "decision_tree", "random_forest", "gradient_boost", "ada_boost", "mlp"],
    )
    parser.add_argument(
        "--compare-dataset",
        type=str,
        choices=["vntc", "uts2017"],
        default="vntc",
        help="Dataset to use for model comparison (default: vntc)",
    )
    parser.add_argument(
        "--export-model",
        action="store_true",
        help="Export a copy of the trained model to the project root for distribution/publishing",
    )

    args, unknown = parser.parse_known_args()

    if in_notebook and unknown:
        print(f"Note: Running in Jupyter/Colab environment. Ignoring kernel arguments: {unknown}")

    if args.compare or args.compare_models:
        if args.compare_models:
            print(f"Training and comparing selected models: {args.compare_models}")
            print(f"Dataset: {args.compare_dataset}")
            if args.num_rows:
                print(f"Using {args.num_rows} rows per dataset")
            train_all_configurations(dataset=args.compare_dataset, models=args.compare_models, num_rows=args.num_rows)
        else:
            print("Training and comparing all available models...")
            print(f"Dataset: {args.compare_dataset}")
            if args.num_rows:
                print(f"Using {args.num_rows} rows per dataset")
            train_all_configurations(dataset=args.compare_dataset, num_rows=args.num_rows)
    else:
        dataset_name = "VNTC" if args.dataset == "vntc" else "UTS2017_Bank"
        print(f"Training {args.model} model on {dataset_name} dataset...")
        print(
            f"Configuration: max_features={args.max_features}, ngram=({args.ngram_min}, {args.ngram_max})"
        )

        train_model(
            dataset=args.dataset,
            model_name=args.model,
            max_features=args.max_features,
            ngram_range=(args.ngram_min, args.ngram_max),
            split_ratio=args.split_ratio,
            n_samples=args.num_rows,
            export_model=args.export_model,
        )


if __name__ == "__main__":
    main()