Spaces:
Sleeping
Sleeping
| """ | |
| ReAct Agent for Cyber Knowledge Base | |
| This script creates a ReAct agent using LangGraph that can use the CyberKnowledgeBase | |
| search method as a tool to retrieve MITRE ATT&CK techniques. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| from typing import List, Dict, Any, Union, Optional | |
| from pathlib import Path | |
| # Add parent directory to path for imports | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from langchain_core.tools import tool | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain.chat_models import init_chat_model | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| # Import local modules | |
| from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase | |
| # Initialize the knowledge base | |
def init_knowledge_base(
    persist_dir: str = "./cyber_knowledge_base",
) -> CyberKnowledgeBase:
    """Load the persisted cyber knowledge base, or terminate the process.

    Args:
        persist_dir: Directory handed to
            ``CyberKnowledgeBase.load_knowledge_base``.

    Returns:
        The ``CyberKnowledgeBase`` instance once loading reports success.

    Note:
        When loading reports failure, build instructions are printed and the
        process exits with status 1 via ``sys.exit``; nothing is returned.
    """
    knowledge_base = CyberKnowledgeBase()

    # Guard clause: without a loadable knowledge base there is nothing to
    # serve, so point the user at the build script and stop.
    if not knowledge_base.load_knowledge_base(persist_dir):
        print("[WARNING] Could not load knowledge base, please build it first")
        print("Run: python src/scripts/build_cyber_database.py")
        sys.exit(1)

    print("[SUCCESS] Loaded existing knowledge base")
    return knowledge_base
def _split_metadata_list(value: Any) -> List[str]:
    """Turn a metadata field into a list of stripped, non-empty strings.

    Accepts a comma-separated string, a list/tuple of items, ``None``
    (treated as empty), or any other scalar (treated as a single item).
    """
    if value is None:
        return []
    if isinstance(value, str):
        items = value.split(",")
    elif isinstance(value, (list, tuple)):
        items = [str(item) for item in value]
    else:
        items = [str(value)]
    return [item.strip() for item in items if item.strip()]


def _format_results_as_json(results) -> List[Dict[str, Any]]:
    """Convert search result documents into JSON-serializable dictionaries.

    Args:
        results: Iterable of objects exposing ``metadata`` (a mapping) and
            ``page_content`` (a string).

    Returns:
        One dict per document with the keys ``attack_id``, ``name``,
        ``tactics``, ``platforms``, ``description`` and ``relevance_score``.
        ``tactics`` and ``platforms`` are lists of stripped, non-empty
        strings. ``description`` is the text following the last
        ``"Description: "`` marker in ``page_content``, or the whole content
        when the marker is absent. Missing ``attack_id``/``name`` fall back
        to ``"Unknown"``; a missing ``relevance_score`` is ``None``.
    """
    marker = "Description: "
    output = []
    for doc in results:
        # Tolerate documents whose metadata or content was stored as None.
        metadata = doc.metadata or {}
        content = doc.page_content or ""
        output.append(
            {
                "attack_id": metadata.get("attack_id", "Unknown"),
                "name": metadata.get("name", "Unknown"),
                "tactics": _split_metadata_list(metadata.get("tactics", "")),
                "platforms": _split_metadata_list(metadata.get("platforms", "")),
                # rpartition yields the full content when the marker is
                # missing, and the text after its last occurrence otherwise.
                "description": content.rpartition(marker)[2],
                "relevance_score": metadata.get("relevance_score"),  # From reranking
            }
        )
    return output
def create_agent(llm_client: BaseChatModel, kb: CyberKnowledgeBase):
    """Create a LangGraph ReAct agent whose only tool searches ``kb``.

    Args:
        llm_client: Chat model handed to ``create_react_agent``.
        kb: Knowledge base whose ``search`` method backs the
            ``search_techniques`` tool.

    Returns:
        The runnable returned by ``create_react_agent``.
    """

    # The tool is a closure so that it stays bound to the provided knowledge base.
    def search_techniques(
        queries: Union[str, List[str]],
        top_k: int = 5,
        rerank_query: Optional[str] = None,
    ) -> str:
        """
        Search for MITRE ATT&CK techniques using the knowledge base.

        This tool searches a vector database containing MITRE ATT&CK technique descriptions,
        including their tactics, platforms, and detailed behavioral information. Each technique
        in the database has its full description embedded for semantic similarity search.

        Args:
            queries: Single search query string OR list of query strings.
            top_k: Number of results to return per query (default: 5)
            rerank_query: Optional tag echoed in the output for transparency.

        Returns:
            JSON string with results grouped per query. Each group contains:
            - query: The original query string
            - techniques: List of technique objects (attack_id, name, tactics, platforms, description, relevance_score)
            - total_results: Number of techniques in this group
            A "message" field is added when no query returned any technique.
            On failure the JSON contains "error" and "message" fields and empty results.
        """
        try:
            # Convert single query to list for uniform processing
            query_list = [queries] if isinstance(queries, str) else list(queries)

            # Run one search per query and keep results associated with it
            results_by_query: List[Dict[str, Any]] = []
            for i, q in enumerate(query_list, 1):
                print(f"[INFO] Query {i}/{len(query_list)}: '{q}'")
                techniques = _format_results_as_json(kb.search(q, top_k=top_k))
                results_by_query.append(
                    {
                        "query": q,
                        "techniques": techniques,
                        "total_results": len(techniques),
                    }
                )

            # Same envelope for hits and misses, so the echoed fields are
            # always present for the caller.
            payload: Dict[str, Any] = {
                "results_by_query": results_by_query,
                "queries_used": query_list,
                "rerank_query": rerank_query,
            }
            if not any(group["techniques"] for group in results_by_query):
                payload["message"] = (
                    "No techniques found matching the provided queries."
                )
            return json.dumps(payload, indent=2)
        except Exception as e:
            # Tool boundary: report the failure to the agent as data instead
            # of raising, so the model can react to it.
            return json.dumps(
                {
                    "error": str(e),
                    "results_by_query": [],
                    "techniques": [],
                    "message": "Error occurred during search",
                },
                indent=2,
            )

    tools = [search_techniques]

    # Define the system prompt for the agent
    system_prompt = """
    You are a cybersecurity analyst assistant that helps answer questions about MITRE ATT&CK techniques.
    You have access to a knowledge base of MITRE ATT&CK techniques that you can search.
    Use the search_techniques tool to find relevant techniques based on the user's query.
    """

    # Create the ReAct agent
    return create_react_agent(llm_client, tools, prompt=system_prompt)
def run_test_queries(agent):
    """Invoke the agent on a fixed set of sample questions and print each trace.

    Args:
        agent: Object exposing ``invoke(state)`` that returns a mapping with a
            ``"messages"`` list.

    For every question the full returned message list is printed: human
    messages, agent messages (plus the ``function_call`` entry of
    ``additional_kwargs`` when present), and tool messages.
    """
    sample_questions = (
        "What techniques are used for credential dumping?",
        "How do attackers use process injection for defense evasion?",
        "What are common persistence techniques on Windows systems?",
    )

    for number, question in enumerate(sample_questions, start=1):
        print(f"\n\n===== Test Query {number}: '{question}' =====\n")

        # Each question starts from a fresh, single-message conversation.
        outcome = agent.invoke({"messages": [HumanMessage(content=question)]})

        print("[TRACE] Conversation messages:")
        for msg in outcome["messages"]:
            if isinstance(msg, HumanMessage):
                print(f"- [Human] {msg.content}")
                continue

            if isinstance(msg, AIMessage):
                speaker = getattr(msg, "name", None) or "agent"
                print(f"- [Agent:{speaker}] {msg.content}")
                extras = msg.additional_kwargs
                if "function_call" in extras:
                    call = extras["function_call"]
                    print(f" [ToolCall] {call.get('name')}: {call.get('arguments')}")
                continue

            if isinstance(msg, ToolMessage):
                source = getattr(msg, "name", None) or "tool"
                print(f"- [Tool:{source}] {msg.content}")
def interactive_mode(agent):
    """Run a console chat loop against the agent.

    Args:
        agent: Object exposing ``invoke(state)`` that returns a mapping whose
            ``"messages"`` list is the submitted messages followed by the
            messages produced during the turn.

    Behavior:
        - ``exit`` / ``quit`` (case-insensitive, surrounding whitespace
          ignored), end-of-input, or Ctrl-C ends the session.
        - Blank input is ignored instead of being sent to the model.
        - Only the messages produced in the current turn are printed; earlier
          turns are not repeated.
        - If a turn fails before the agent returns, the error is printed and
          that turn is left out of the conversation history.
    """
    print("\n\n===== Interactive Mode =====")
    print("Type 'exit' or 'quit' to end the session\n")

    # Conversation history carried across turns
    messages = []

    while True:
        try:
            user_input = input("\nYou: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: leave cleanly instead of a traceback.
            print("\nExiting interactive mode...")
            break

        command = user_input.strip().lower()
        if command in ("exit", "quit"):
            print("Exiting interactive mode...")
            break
        if not command:
            continue

        # Build the turn's input without touching the history yet, so a
        # failed invoke does not leave a dangling human message behind.
        pending = messages + [HumanMessage(content=user_input)]

        try:
            result = agent.invoke({"messages": pending})

            # Commit the updated history, then show only what this turn added.
            messages = list(result["messages"])
            for message in messages[len(pending):]:
                if isinstance(message, AIMessage):
                    print("\n" + "=" * 50)
                    print(f"\nAgent: {message.content}")
                    if "function_call" in message.additional_kwargs:
                        function_call = message.additional_kwargs["function_call"]
                        print("Function call:", function_call["name"])
                        print("Arguments:", function_call["arguments"])
                    print("-" * 50)
                if isinstance(message, ToolMessage):
                    print("Tool output:", message.content)
        except Exception as e:
            # Keep the session alive; the user can retry or rephrase.
            print(f"Error: {str(e)}")
def main():
    """Parse the command line, load the knowledge base and run the agent.

    Command-line flags:
        --interactive / -i: run the interactive chat loop.
        --test / -t: run the built-in test queries.

    Interactive mode is the default and also wins when both flags are given.
    """
    # The loaded knowledge base is also published as a module-level global,
    # as in the original implementation.
    global kb

    import argparse

    # Parse arguments first so that --help and invalid flags return
    # immediately, before the knowledge base and LLM client are initialized.
    parser = argparse.ArgumentParser(description="Run the Cyber KB React Agent")
    parser.add_argument(
        "--interactive", "-i", action="store_true", help="Run in interactive mode"
    )
    parser.add_argument("--test", "-t", action="store_true", help="Run test queries")
    args = parser.parse_args()

    # Initialize the knowledge base from the directory three levels above
    # this file.
    kb_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "cyber_knowledge_base",
    )
    kb = init_knowledge_base(kb_path)

    # Print KB stats
    stats = kb.get_stats()
    print(
        f"Knowledge base loaded with {stats.get('total_techniques', 'unknown')} techniques"
    )

    # Initialize the LLM client (using environment variables)
    llm_client = init_chat_model("google_genai:gemini-2.0-flash", temperature=0.2)

    # Create the agent
    agent = create_agent(llm_client, kb)

    # Test queries run only when requested on their own; every other
    # combination of flags falls back to interactive mode.
    if args.test and not args.interactive:
        run_test_queries(agent)
    else:
        interactive_mode(agent)


if __name__ == "__main__":
    main()