|
|
|
|
|
""" |
|
|
Inference script for Sonar Core 1 - Vietnamese Text Classification. |
|
|
Loads trained models from local files and performs predictions. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import joblib |
|
|
import os |
|
|
import glob |
|
|
|
|
|
|
|
|
def find_local_models():
    """Find all available local model files.

    Two locations, both relative to the current working directory, are
    searched:

    * exported models: ``*.joblib`` files directly in the working directory
      whose names start with ``vntc_classifier_`` or
      ``uts2017_bank_classifier_``;
    * run models: files matching ``runs/*/models/VNTC_*.joblib`` or
      ``runs/*/models/UTS2017_Bank_*.joblib``.

    When several files match the same dataset, the one that sorts last
    lexicographically is selected, for both locations.

    Returns:
        dict: ``{'exported': {...}, 'runs': {...}}`` where each inner dict
        maps a dataset key (``'vntc'`` / ``'uts2017_bank'``) to the selected
        file path. Keys are only present when a matching file exists, and
        are always inserted in the order ``'vntc'``, ``'uts2017_bank'``.
    """
    models = {'exported': {}, 'runs': {}}

    # os.listdir() yields entries in arbitrary order, so gather the candidates
    # first and pick deterministically instead of "whichever came last".
    joblib_files = [
        name for name in os.listdir('.') if name.endswith('.joblib')
    ]
    exported_prefixes = (
        ('vntc', 'vntc_classifier_'),
        ('uts2017_bank', 'uts2017_bank_classifier_'),
    )
    for dataset, prefix in exported_prefixes:
        candidates = [name for name in joblib_files if name.startswith(prefix)]
        if candidates:
            models['exported'][dataset] = max(candidates)

    run_patterns = (
        ('vntc', 'runs/*/models/VNTC_*.joblib'),
        ('uts2017_bank', 'runs/*/models/UTS2017_Bank_*.joblib'),
    )
    for dataset, pattern in run_patterns:
        matches = glob.glob(pattern)
        if matches:
            # Same rule as for exported files: last path in sorted order.
            models['runs'][dataset] = max(matches)

    return models
|
|
|
|
|
|
|
|
def load_model(model_path):
    """Load a serialized model from ``model_path`` using joblib.

    Progress is reported on stdout. Any exception raised while loading the
    file, or while reading the loaded object's ``classes_`` attribute for the
    summary line, is reported on stdout and turned into a ``None`` result.

    Args:
        model_path: Path of the ``.joblib`` file to load.

    Returns:
        The loaded model object, or ``None`` if loading failed.
    """
    try:
        print(f"Loading model from: {model_path}")
        loaded = joblib.load(model_path)
        class_count = len(loaded.classes_)
        print(f"Model loaded successfully. Classes: {class_count}")
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
    return loaded
|
|
|
|
|
|
|
|
def predict_text(model, text, top_k=3):
    """Classify a single text and return the best-scoring categories.

    Args:
        model: Object exposing ``predict_proba(list_of_texts)`` and a
            ``classes_`` sequence indexed consistently with the returned
            probabilities.
        text: The text to classify.
        top_k: Maximum number of ranked predictions to return (default 3).
            Fewer are returned when the model has fewer classes.

    Returns:
        tuple: ``(prediction, confidence, top_predictions)`` where
        ``top_predictions`` is a list of ``(category, probability)`` pairs in
        descending probability order and ``prediction`` / ``confidence`` are
        its first entry. If the model raises, the error is printed and
        ``(None, 0, [])`` is returned.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    # Validated outside the try block so a caller mistake is not swallowed
    # by the best-effort error handling below.
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    try:
        probabilities = model.predict_proba([text])[0]

        # argsort() is ascending; reverse it and keep the first top_k indices.
        ranked_indices = probabilities.argsort()[::-1][:top_k]
        top_predictions = [
            (model.classes_[idx], probabilities[idx]) for idx in ranked_indices
        ]

        prediction, confidence = top_predictions[0]
        return prediction, confidence, top_predictions
    except Exception as e:
        print(f"Error making prediction: {e}")
        return None, 0, []
|
|
|
|
|
|
|
|
def interactive_mode(model, dataset_name):
    """Run a read-classify-print loop on stdin until the user quits.

    Each non-empty line is classified with ``predict_text`` and the best
    prediction plus the ranked alternatives are printed. The loop ends when
    the user types ``quit``, ``exit`` or ``q`` (case-insensitive), presses
    Ctrl-C, or stdin reaches end-of-file.

    Args:
        model: Loaded classifier passed through to ``predict_text``.
        dataset_name: Dataset label shown in the banner; ``None`` is shown
            as ``UNKNOWN``.
    """
    banner = '=' * 60
    label = (dataset_name or 'unknown').upper()
    print(f"\n{banner}")
    print(f"INTERACTIVE MODE - {label} CLASSIFICATION")
    print(f"{banner}")
    print("Enter Vietnamese text to classify (type 'quit' to exit):")

    while True:
        try:
            user_input = input("\nText: ").strip()

            if user_input.lower() in ('quit', 'exit', 'q'):
                break

            if not user_input:
                continue

            prediction, confidence, top_predictions = predict_text(model, user_input)

            # Compare against None explicitly so a falsy class label
            # (e.g. 0 or "") is still reported.
            if prediction is not None:
                print(f"Predicted category: {prediction}")
                print(f"Confidence: {confidence:.3f}")
                print("Top 3 predictions:")
                for i, (category, prob) in enumerate(top_predictions, 1):
                    print(f" {i}. {category}: {prob:.3f}")

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop: once stdin is exhausted every
            # further input() call raises again, which would otherwise spin
            # forever in the generic handler below.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
|
|
|
|
|
|
|
|
def test_examples(model, dataset_name):
    """Run the model on a fixed set of example sentences and print results.

    For ``dataset_name == 'vntc'`` a set of news-style sentences is used; for
    any other value (including ``None``) a set of banking-related sentences
    is used. For every example that yields a prediction, the text, predicted
    category and confidence are printed; when the confidence is below 0.7
    the ranked alternatives are printed as well.

    Args:
        model: Loaded classifier passed through to ``predict_text``.
        dataset_name: Dataset label; selects the example set and is shown in
            the banner (``None`` is shown as ``UNKNOWN``).
    """
    if dataset_name == 'vntc':
        examples = [
            "Đội tuyển bóng đá Việt Nam giành chiến thắng 2-0",
            "Chính phủ thông qua nghị định mới về chính sách xã hội",
            "Các nhà khoa học phát hiện loại vi khuẩn mới",
            "Thị trường chứng khoán biến động mạnh",
            "Tiêm vaccine COVID-19 đạt tỷ lệ cao",
            "Công nghệ trí tuệ nhân tạo phát triển mạnh",
        ]
    else:
        examples = [
            "Tôi muốn mở tài khoản tiết kiệm mới",
            "Lãi suất vay mua nhà hiện tại là bao nhiêu?",
            "Làm thế nào để đăng ký internet banking?",
            "Chi phí chuyển tiền ra nước ngoài",
            "Ngân hàng ACB có uy tín không?",
            "Tôi cần hỗ trợ về dịch vụ ngân hàng",
        ]

    banner = '=' * 60
    label = (dataset_name or 'unknown').upper()
    print(f"\n{banner}")
    print(f"TESTING {label} MODEL WITH EXAMPLES")
    print(f"{banner}")

    for text in examples:
        prediction, confidence, top_predictions = predict_text(model, text)

        # predict_text signals failure with None; skip only in that case so
        # that falsy class labels are still printed.
        if prediction is None:
            continue

        print(f"\nText: {text}")
        print(f"Prediction: {prediction}")
        print(f"Confidence: {confidence:.3f}")

        # Low-confidence results are more useful with their alternatives.
        if confidence < 0.7:
            print("Alternative predictions:")
            for i, (category, prob) in enumerate(top_predictions[:3], 1):
                print(f" {i}. {category}: {prob:.3f}")
        print("-" * 60)
|
|
|
|
|
|
|
|
def list_available_models():
    """Print every locally discovered model together with its file size.

    Models are discovered with ``find_local_models``. Exported models and
    run models are printed in separate sections, one line per dataset with
    the file size in megabytes. If neither section has any entry, a hint on
    how to obtain a model is printed instead.
    """
    found = find_local_models()
    exported = found['exported']
    from_runs = found['runs']

    print("Available Models:")
    print("=" * 50)

    sections = (
        ("\nExported Models (Project Root):", exported),
        ("\nRuns Models (Training Directory):", from_runs),
    )
    for heading, entries in sections:
        if not entries:
            continue
        print(heading)
        for dataset, path in entries.items():
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f" {dataset}: {path} ({size_mb:.1f}MB)")

    if not (exported or from_runs):
        print("No local models found!")
        print("Train a model first using: python train.py --export-model")
        print("Or download from HuggingFace using: python use_this_model.py")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point for local model inference.

    Parses the command line, resolves which model file to use, loads it and
    then runs one of three modes:

    * ``--text``: classify the given text once and print the result;
    * ``--test-examples``: run the predefined examples;
    * neither: run the predefined examples, then offer interactive mode.

    ``--list-models`` only lists the discovered models and exits. The model
    file is taken from ``--model-path`` when given; otherwise it is looked
    up among the models found by ``find_local_models`` for ``--source``
    (``exported`` by default), using ``--dataset`` or, if that is omitted,
    the first dataset available.
    """
    parser = argparse.ArgumentParser(
        description="Inference with local Sonar Core 1 models"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        help="Path to specific model file"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["vntc", "uts2017_bank"],
        help="Dataset type (auto-detects if not specified)"
    )
    parser.add_argument(
        "--text",
        type=str,
        help="Text to classify (if not provided, enters interactive mode)"
    )
    parser.add_argument(
        "--test-examples",
        action="store_true",
        help="Test with predefined examples"
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available local models"
    )
    parser.add_argument(
        "--source",
        type=str,
        choices=["exported", "runs"],
        default="exported",
        help="Model source: exported files or runs directory (default: exported)"
    )

    args = parser.parse_args()

    if args.list_models:
        list_available_models()
        return

    models = find_local_models()

    model_path = None
    dataset_name = args.dataset

    if args.model_path:
        model_path = args.model_path

        # Without --dataset, guess the dataset from the path itself.
        if not dataset_name:
            lowered_path = args.model_path.lower()
            if 'vntc' in lowered_path:
                dataset_name = 'vntc'
            elif 'uts2017' in lowered_path or 'bank' in lowered_path:
                dataset_name = 'uts2017_bank'
            else:
                # Leaving this as None would crash later where the name is
                # upper-cased for the banners, so fall back to a placeholder.
                dataset_name = 'unknown'
                print(
                    "Could not infer dataset from model path; "
                    "use --dataset to set it explicitly"
                )
    elif args.dataset:
        available = models[args.source]
        if args.dataset not in available:
            print(f"No {args.dataset} model found in {args.source} models")
            list_available_models()
            return
        model_path = available[args.dataset]
        dataset_name = args.dataset
    else:
        available = models[args.source]
        if not available:
            print("No models found!")
            list_available_models()
            return
        # No preference given: take the first dataset that was discovered.
        dataset_name = next(iter(available))
        model_path = available[dataset_name]
        print(f"Auto-selected {dataset_name} model")

    if not model_path or not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        list_available_models()
        return

    model = load_model(model_path)
    # load_model signals failure with None; avoid truthiness, which would
    # call __len__ on container-like model objects.
    if model is None:
        return

    if args.text:
        prediction, confidence, top_predictions = predict_text(model, args.text)
        if prediction is not None:
            print(f"\nText: {args.text}")
            print(f"Prediction: {prediction}")
            print(f"Confidence: {confidence:.3f}")
            print("Top 3 predictions:")
            for i, (category, prob) in enumerate(top_predictions, 1):
                print(f" {i}. {category}: {prob:.3f}")

    elif args.test_examples:
        test_examples(model, dataset_name)

    else:
        print(f"Loaded {dataset_name} model: {os.path.basename(model_path)}")
        test_examples(model, dataset_name)

        try:
            response = input("\nEnter interactive mode? (y/n): ").strip().lower()
            if response in ('y', 'yes'):
                interactive_mode(model, dataset_name)
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin is closed or piped, so there is nobody to ask.
            print("\nExiting...")
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    raise SystemExit(main())