File size: 4,981 Bytes
efb1801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""
Sync Strawberry Picker Models to HuggingFace Repository

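This script prepares trained strawberryPicker models for publication to the
HuggingFace strawberryPicker repository. It hashes the model checkpoints,
exports the detection model to ONNX, and writes sync_metadata.json.
Committing and pushing to HuggingFace are left to the user.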

Usage:
    python sync_to_huggingface.py [--repo-path PATH] [--dry-run]
"""

import os
import sys
import json
import shutil
import argparse
import subprocess
from pathlib import Path
from datetime import datetime
import hashlib

def calculate_file_hash(file_path):
    """Return the hex-encoded SHA-256 digest of the file at *file_path*.

    The file is read in 4096-byte chunks, so large checkpoints are never
    held in memory all at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()

def find_updated_models(repo_path):
    """Collect hash information for each known model checkpoint on disk.

    Despite the name, nothing is compared against a previous sync: every
    checkpoint that currently exists is reported with its SHA-256 digest.

    Args:
        repo_path: Root of the strawberryPicker repository (str or Path).

    Returns:
        A list of dicts with keys 'component', 'path' (a Path) and 'hash',
        ordered detection then classification. Missing checkpoints are
        simply left out.
    """
    root = Path(repo_path)
    # (component name, checkpoint location) for every model we know about.
    checkpoints = (
        ('detection', root / "detection" / "best.pt"),
        ('classification', root / "classification" / "best_enhanced_classifier.pth"),
    )

    found = []
    for component, checkpoint in checkpoints:
        if not checkpoint.exists():
            continue
        found.append({
            'component': component,
            'path': checkpoint,
            'hash': calculate_file_hash(checkpoint),
        })
    return found

def export_detection_to_onnx(model_path, output_dir):
    """Export a YOLO detection checkpoint to ONNX using the ``yolo`` CLI.

    Ultralytics' ``yolo export`` writes the ``.onnx`` file next to the source
    weights and has no ``dir=`` option. The export therefore runs without
    one, and the result is moved into *output_dir* afterwards when that
    differs from the weights' directory.

    Args:
        model_path: Path to the ``.pt`` checkpoint to export.
        output_dir: Directory that should end up containing the ``.onnx``.

    Returns:
        True if the export succeeded and the file is in *output_dir*,
        otherwise False. Failures are reported on stdout, not raised.
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    cmd = [
        "yolo", "export",
        f"model={model_path}",
        "format=onnx",
        "opset=12",
    ]

    print("Exporting detection model to ONNX...")
    try:
        # errors="replace": undecodable bytes in the CLI's progress output
        # must not be mistaken for an export failure.
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except (OSError, subprocess.SubprocessError) as e:
        # OSError covers a missing or non-executable `yolo` binary.
        print(f"Error during ONNX export: {e}")
        return False

    if result.returncode != 0:
        # Argument/validation errors may be reported on stdout, so fall back to it.
        print(f"Export failed: {result.stderr or result.stdout}")
        return False

    exported = model_path.with_suffix(".onnx")
    target = output_dir / exported.name
    if exported.resolve() != target.resolve():
        try:
            output_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(exported), str(target))
        except OSError as e:
            print(f"Error moving exported model to {target}: {e}")
            return False

    print("Successfully exported detection model to ONNX")
    return True

def update_model_metadata(repo_path, models_info):
    """Write ``sync_metadata.json`` at the repository root.

    Args:
        repo_path: Root of the repository; the metadata file is written there.
        models_info: Iterable of dicts with 'component', 'path' and 'hash'
            keys, as produced by find_updated_models(). If a component
            appears twice, the later entry wins.

    Returns:
        True once the file has been written. I/O errors propagate.
    """
    # Take one timestamp for the whole run. Previously datetime.now() was
    # called once per entry, so "last_sync" and "last_updated" disagreed.
    timestamp = datetime.now().isoformat()
    metadata = {
        "last_sync": timestamp,
        "models": {
            model['component']: {
                "hash": model['hash'],
                "path": str(model['path']),
                "last_updated": timestamp,
            }
            for model in models_info
        },
    }

    metadata_file = Path(repo_path) / "sync_metadata.json"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"Updated metadata file: {metadata_file}")
    return True

def sync_repository(repo_path, dry_run=False):
    """Run one sync pass over the strawberryPicker repository.

    Steps:
      1. Locate the model checkpoints and hash them.
      2. Export the detection checkpoint to ONNX, if present.
      3. Write ``sync_metadata.json``.

    Committing and pushing to git is left to the user.

    Args:
        repo_path: Root of the strawberryPicker repository.
        dry_run: If True, only report what was found and what would be done;
            nothing is exported or written.

    Returns:
        True on success, including "no models found" and dry runs. False if
        the ONNX export failed, so callers such as main() can exit non-zero.
    """
    print(f"Syncing strawberryPicker repository at {repo_path}")

    if dry_run:
        print("DRY RUN MODE - No changes will be made")

    updated_models = find_updated_models(repo_path)

    if not updated_models:
        print("No models found to sync")
        return True

    print(f"Found {len(updated_models)} model components:")
    for model in updated_models:
        print(f"  - {model['component']} (hash: {model['hash'][:8]}...)")

    detection_model = next(
        (m for m in updated_models if m['component'] == 'detection'), None
    )
    detection_dir = Path(repo_path) / "detection"

    if dry_run:
        if detection_model:
            print(f"Would export detection model to ONNX in {detection_dir}")
        print(f"Would write {Path(repo_path) / 'sync_metadata.json'}")
        return True

    export_ok = True
    if detection_model:
        # Previously this result was ignored, so a failed export still
        # produced "Sync completed successfully!" and exit status 0.
        export_ok = export_detection_to_onnx(detection_model['path'], detection_dir)

    # The metadata records the checkpoint hashes, which are valid whatever
    # the export outcome, so it is written in either case.
    update_model_metadata(repo_path, updated_models)

    if not export_ok:
        print("\nSync finished with errors: ONNX export of the detection model failed.")
        return False

    print("\nSync completed successfully!")
    print("Remember to:")
    print("1. Review and commit changes to git")
    print("2. Push to HuggingFace: git push origin main")
    print("3. Update READMEs with any new performance metrics")

    return True

def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]``; it is
            exposed so the CLI can be driven programmatically and tested.

    Exits with status 1 if the repository path is not an existing directory,
    or if the sync reports failure.
    """
    parser = argparse.ArgumentParser(description="Sync strawberryPicker models to HuggingFace")
    parser.add_argument("--repo-path",
                        default="/home/user/machine-learning/HuggingfaceModels/strawberryPicker",
                        help="Path to strawberryPicker repository")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be done without making changes")

    args = parser.parse_args(argv)

    # Use isdir rather than exists: a regular file at this path would pass
    # exists() and only fail later with a confusing error.
    if not os.path.isdir(args.repo_path):
        print(f"Error: Repository path {args.repo_path} does not exist or is not a directory",
              file=sys.stderr)
        sys.exit(1)

    success = sync_repository(args.repo_path, args.dry_run)

    if not success:
        sys.exit(1)

if __name__ == "__main__":
    main()