# BirdScopeAI / langgraph_agent / test_agent.py
# Committed by: facemelter
# Commit message: Initial commit to hf space for hackathon
# Commit: ff0e97f (verified)
"""
Test suite for bird classifier agents.
"""
import asyncio
import sys
from pathlib import Path
# Put the directory two levels above this file (the one that contains the
# ``langgraph_agent`` package) at the front of sys.path, unless it is already
# listed, so ``from langgraph_agent import ...`` resolves from any working directory.
parent_dir = Path(__file__).parent.parent
if not any(entry == str(parent_dir) for entry in sys.path):
    sys.path.insert(0, str(parent_dir))
from langgraph_agent import AgentFactory
async def test_classifier_agent():
    """Smoke-test the basic classifier agent against sample image URLs.

    Builds an agent with ``AgentFactory.create_classifier_agent()``. For each
    hard-coded Unsplash image URL it sends one user message asking the agent to
    classify the bird, then prints the content of the last message in the
    returned state.

    This is an ``async`` function. Run it through this module's ``__main__``
    entry point, or under a pytest asyncio plugin; plain pytest does not
    await coroutines.

    Raises:
        AssertionError: if the agent returns no messages, or an empty final
            reply, for any URL.
    """
    print("\n" + "=" * 70)
    print("Test Suite: Basic Classifier Agent")
    print("=" * 70 + "\n")

    # Create agent
    agent = await AgentFactory.create_classifier_agent()

    test_urls = [
        "https://images.unsplash.com/photo-1555169062-013468b47731?w=400",
        "https://images.unsplash.com/photo-1445820200644-69f87d946277?w=400",
    ]

    for i, url in enumerate(test_urls, 1):
        print(f"\n[TEST {i}/{len(test_urls)}]")
        print("=" * 70)

        result = await agent.ainvoke({
            "messages": [{
                "role": "user",
                "content": f"Classify the bird in this image: {url}"
            }]
        })

        messages = result["messages"]
        # Fail with a clear message instead of an IndexError on an empty history.
        assert messages, f"Agent returned no messages for {url}"
        reply = messages[-1].content
        print(f"\n[RESULT]: {reply}\n")
        # A test that only prints can never fail; require an actual answer.
        assert reply, f"Agent returned an empty reply for {url}"

    print("\n[ALL TESTS COMPLETE!]\n")
async def test_multi_server_agent():
    """Smoke-test the multi-server agent (classifier + eBird) across two turns.

    Builds an agent with ``AgentFactory.create_multi_server_agent(with_memory=True)``
    and sends two user messages with the same ``config``, which means the same
    ``thread_id``:

    1. Ask what bird is shown at an image URL.
    2. Ask, without naming the bird, where "this bird" can be seen near Boston.
       A meaningful answer depends on the agent remembering turn 1.

    The content of the last message after each turn is printed.

    This is an ``async`` function. Run it through this module's ``__main__``
    entry point (``multi`` argument), or under a pytest asyncio plugin.

    Raises:
        AssertionError: if either turn returns no messages or an empty final
            reply.
    """
    print("\n" + "=" * 70)
    print("Test Suite: Multi-Server Agent")
    print("=" * 70 + "\n")

    # Create agent with memory
    agent = await AgentFactory.create_multi_server_agent(with_memory=True)
    # Both turns share this thread id so the follow-up can see turn 1.
    # NOTE(review): if the factory's checkpointer persists across runs, reruns
    # reuse this fixed thread id and accumulate history. Confirm.
    config = {"configurable": {"thread_id": "test_session"}}

    # Test 1: Classify bird
    print("\n[TEST 1]: Classify bird from URL")
    print("=" * 70)

    result1 = await agent.ainvoke({
        "messages": [{
            "role": "user",
            "content": "What bird is this? https://images.unsplash.com/photo-1555169062-013468b47731?w=400"
        }]
    }, config)

    messages1 = result1["messages"]
    assert messages1, "TEST 1: agent returned no messages"
    reply1 = messages1[-1].content
    print(f"\n[RESULT]: {reply1}\n")
    assert reply1, "TEST 1: agent returned an empty reply"

    # Test 2: Ask follow-up (tests memory)
    print("\n[TEST 2]: Follow-up question (tests memory)")
    print("=" * 70)

    result2 = await agent.ainvoke({
        "messages": [{
            "role": "user",
            "content": "Where can I see this bird near Boston (42.36, -71.06)?"
        }]
    }, config)

    messages2 = result2["messages"]
    assert messages2, "TEST 2: agent returned no messages"
    reply2 = messages2[-1].content
    print(f"\n[RESULT]: {reply2}\n")
    assert reply2, "TEST 2: agent returned an empty reply"

    print("\n[ALL TESTS COMPLETE!]\n")
if __name__ == "__main__":
    # Usage:
    #   python test_agent.py         -> basic classifier suite
    #   python test_agent.py multi   -> multi-server (classifier + eBird) suite
    # Any other first argument falls back to the classifier suite.
    run_multi = len(sys.argv) > 1 and sys.argv[1] == "multi"
    suite = test_multi_server_agent if run_multi else test_classifier_agent
    asyncio.run(suite())