| | |
| | """ |
| | Sync Strawberry Picker Models to HuggingFace Repository |
| | |
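This script automates preparing trained strawberryPicker models in a local clone
of the HuggingFace strawberryPicker repository for publishing: it locates the
model checkpoints, re-exports the detection model to ONNX, and records file
hashes in sync_metadata.json. Committing and pushing to HuggingFace remain
manual steps.

Usage:
    python sync_to_huggingface.py [--repo-path PATH] [--dry-run]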
| | """ |
| |
|
| | import os |
| | import sys |
| | import json |
| | import shutil |
| | import argparse |
| | import subprocess |
| | from pathlib import Path |
| | from datetime import datetime |
| | import hashlib |
| |
|
def calculate_file_hash(file_path):
    """Return the hex-encoded SHA-256 digest of the file at *file_path*.

    The file is opened in binary mode and fed to the hash in 4096-byte
    chunks, so large checkpoints are never loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                # An empty read means end of file.
                break
            digest.update(chunk)
    return digest.hexdigest()
| |
|
def find_updated_models(repo_path):
    """Collect the model checkpoints that are present in the repository.

    Looks for ``detection/best.pt`` and
    ``classification/best_enhanced_classifier.pth`` under *repo_path*.
    Every checkpoint that exists is reported. No comparison against
    previously recorded hashes is made here, so "updated" means "present".

    Args:
        repo_path: Root of the strawberryPicker repository (str or Path).

    Returns:
        A list of dicts with keys ``'component'``, ``'path'`` (a Path) and
        ``'hash'`` (SHA-256 hex digest). Detection comes first. The list is
        empty when neither checkpoint exists.
    """
    root = Path(repo_path)
    # (component name, checkpoint location under the repository root)
    candidates = (
        ("detection", root / "detection" / "best.pt"),
        ("classification", root / "classification" / "best_enhanced_classifier.pth"),
    )
    return [
        {'component': name, 'path': checkpoint, 'hash': calculate_file_hash(checkpoint)}
        for name, checkpoint in candidates
        if checkpoint.exists()
    ]
| |
|
def export_detection_to_onnx(model_path, output_dir):
    """Export the YOLO detection checkpoint to ONNX using the ``yolo`` CLI.

    Runs ``yolo export model=<model_path> format=onnx opset=12``. Ultralytics
    writes ``<checkpoint stem>.onnx`` next to the checkpoint. When
    *output_dir* is a different directory, the exported file is moved there
    afterwards.

    Args:
        model_path: Path to the ``.pt`` detection checkpoint.
        output_dir: Directory that should end up containing the ``.onnx`` file.

    Returns:
        True if the export succeeded, False otherwise. Errors are printed,
        not raised.
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    # NOTE(review): ``dir=`` is not an Ultralytics export argument (it is not
    # in the default config), and the CLI rejects unrecognised keys. Passing
    # it made every export fail, so output placement is handled below instead.
    cmd = [
        "yolo", "export",
        f"model={model_path}",
        "format=onnx",
        "opset=12",
    ]

    print("Exporting detection model to ONNX...")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        # For example, the ``yolo`` executable is not installed or not on PATH.
        print(f"Error during ONNX export: {e}")
        return False

    if result.returncode != 0:
        # Some tools report errors only on stdout; show whichever has content.
        print(f"Export failed: {result.stderr or result.stdout}")
        return False

    exported = model_path.with_suffix(".onnx")
    target = output_dir / exported.name
    if exported.exists() and exported.resolve() != target.resolve():
        try:
            output_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(exported), str(target))
        except OSError as e:
            print(f"Error moving exported model to {output_dir}: {e}")
            return False

    print("Successfully exported detection model to ONNX")
    return True
| |
|
def update_model_metadata(repo_path, models_info):
    """Write ``sync_metadata.json`` describing the models just synced.

    The file is overwritten, not merged. It holds a top-level ``last_sync``
    timestamp and one entry per model under ``models``, keyed by component.

    Args:
        repo_path: Root of the strawberryPicker repository.
        models_info: Iterable of dicts with ``'component'``, ``'path'`` and
            ``'hash'`` keys, as produced by ``find_updated_models``. If a
            component appears more than once, the last entry wins.

    Returns:
        True once the file has been written. I/O errors propagate.
    """
    # Take a single (naive, local-time) timestamp so that last_sync and
    # every last_updated value agree exactly.
    timestamp = datetime.now().isoformat()
    metadata = {
        "last_sync": timestamp,
        "models": {
            model['component']: {
                "hash": model['hash'],
                "path": str(model['path']),
                "last_updated": timestamp,
            }
            for model in models_info
        },
    }

    metadata_file = Path(repo_path) / "sync_metadata.json"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"Updated metadata file: {metadata_file}")
    return True
| |
|
def sync_repository(repo_path, dry_run=False):
    """Prepare the local strawberryPicker repository for a HuggingFace push.

    Steps:
      1. Locate the model checkpoints.
      2. Re-export the detection model to ONNX.
      3. Rewrite ``sync_metadata.json``.

    Committing and pushing remain manual; reminders are printed at the end.

    Args:
        repo_path: Root of the strawberryPicker repository.
        dry_run: When True, only report what was found and what would be
            done. Nothing is exported or written.

    Returns:
        True on success, including the "no models found" case and dry runs.
        False if the ONNX export of the detection model failed; metadata is
        not written in that case.
    """
    print(f"Syncing strawberryPicker repository at {repo_path}")

    if dry_run:
        print("DRY RUN MODE - No changes will be made")

    updated_models = find_updated_models(repo_path)

    if not updated_models:
        print("No models found to sync")
        return True

    print(f"Found {len(updated_models)} model components:")
    for model in updated_models:
        print(f" - {model['component']} (hash: {model['hash'][:8]}...)")

    detection_model = next((m for m in updated_models if m['component'] == 'detection'), None)
    detection_dir = Path(repo_path) / "detection"

    if dry_run:
        if detection_model:
            print(f"Would export {detection_model['path']} to ONNX in {detection_dir}")
        print("Would update sync_metadata.json")
        return True

    # Stop on a failed export so the success banner and the recorded
    # last_sync are only produced for a complete sync.
    if detection_model and not export_detection_to_onnx(detection_model['path'], detection_dir):
        print("\nSync aborted: ONNX export of the detection model failed")
        return False

    update_model_metadata(repo_path, updated_models)

    print("\nSync completed successfully!")
    print("Remember to:")
    print("1. Review and commit changes to git")
    print("2. Push to HuggingFace: git push origin main")
    print("3. Update READMEs with any new performance metrics")

    return True
| |
|
def main():
    """Command-line entry point: parse arguments, validate the repo path, run the sync.

    Exits with status 1 when the repository path is not an existing
    directory, or when the sync reports failure.
    """
    parser = argparse.ArgumentParser(description="Sync strawberryPicker models to HuggingFace")
    parser.add_argument("--repo-path",
                        default="/home/user/machine-learning/HuggingfaceModels/strawberryPicker",
                        help="Path to strawberryPicker repository")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be done without making changes")

    args = parser.parse_args()

    # A regular file would pass an existence check, but it cannot contain the
    # detection/ and classification/ subdirectories the sync looks for.
    if not os.path.isdir(args.repo_path):
        print(f"Error: Repository path {args.repo_path} does not exist or is not a directory",
              file=sys.stderr)
        sys.exit(1)

    success = sync_repository(args.repo_path, args.dry_run)

    if not success:
        sys.exit(1)


if __name__ == "__main__":
    main()