SHAP_DEMO / download_imagenet_labels.py
xxnithicxx's picture
Fix requirements.txt
326e833
"""
Download ImageNet class labels
This ensures proper class names are displayed in the demo
"""
import os
import json
import urllib.request
def _extract_class_names(class_idx, num_classes=1000):
    """Return the readable label for every class index in a parsed class-index mapping.

    ``class_idx`` must map the string keys ``"0"`` .. ``str(num_classes - 1)``
    to sequences whose element ``[1]`` is the human-readable label.

    Raises:
        KeyError: if an index key is missing.
        IndexError / TypeError: if the mapping or one of its entries has the wrong shape.
    """
    return [class_idx[str(i)][1] for i in range(num_classes)]


def _print_example_classes(class_names):
    """Print a handful of sample labels so the user can eyeball that the file is sane."""
    print("\nExample classes:")
    for i in (0, 100, 200, 300, 400):
        print(f" - class_{i}: {class_names[i]}")


def download_imagenet_labels(json_path="imagenet_class_index.json", *, urls=None, timeout=30):
    """Make sure a valid ImageNet class-index JSON file exists at ``json_path``.

    If the file already exists and parses into 1000 class names, nothing is
    downloaded. Otherwise each URL in ``urls`` is tried in order. A response is
    only written to ``json_path`` after it has been parsed and validated, so a
    bad mirror response (error page, truncated body, wrong schema) never
    overwrites the target file.

    Args:
        json_path: Where the class-index JSON lives / should be written.
        urls: Mirrors to try, in order. Defaults to the TensorFlow and
            Keras S3 locations.
        timeout: Per-request network timeout in seconds.

    Returns:
        True if a valid file is present afterwards (pre-existing or freshly
        downloaded), False if every source failed.
    """
    if urls is None:
        urls = (
            "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json",
            "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
        )

    # Reuse an existing file if it is valid; otherwise fall through and re-download.
    if os.path.exists(json_path):
        print(f"✓ {json_path} already exists")
        try:
            with open(json_path, encoding="utf-8") as f:
                class_names = _extract_class_names(json.load(f))
        except (OSError, ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            print(f"⚠ File exists but is invalid: {e}")
            print("Downloading fresh copy...")
        else:
            print(f"✓ Verified: {len(class_names)} class names loaded")
            _print_example_classes(class_names)
            return True

    print("Downloading ImageNet class labels...")
    for url in urls:
        print(f"Trying: {url}")
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                payload = response.read()
            # Validate in memory first so a bad response never clobbers json_path.
            class_names = _extract_class_names(json.loads(payload))
            with open(json_path, "wb") as f:
                f.write(payload)
        except Exception as e:  # best-effort: any failure just moves on to the next mirror
            print(f"✗ Failed: {e}")
            continue
        print(f"✓ Successfully downloaded {len(class_names)} class names")
        _print_example_classes(class_names)
        return True

    print("\n✗ Could not download from any source")
    print("The demo will still work but will show 'class_0', 'class_1', etc.")
    return False
if __name__ == "__main__":
    # Command-line entry point: framed banner, fetch the labels, then report
    # the outcome between two divider lines.
    divider = "=" * 60
    print(divider, "ImageNet Class Labels Downloader", divider, "", sep="\n")

    labels_ok = download_imagenet_labels()

    print("", divider, sep="\n")
    if not labels_ok:
        print("FAILED! But the demo will still work with placeholder names.")
    else:
        print("SUCCESS! You can now run the demo with proper class names:")
        print(" python gradio_shap_demo.py")
    print(divider)