Yacine Jernite committed on
Commit
10d9a6a
·
1 Parent(s): 06103c4

missing file

Browse files
Files changed (1) hide show
  1. utils/dataset.py +193 -0
utils/dataset.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset utilities for saving and loading test results."""
2
+
3
+ from datetime import datetime
4
+
5
+ from datasets import Dataset, load_dataset
6
+ from huggingface_hub import HfApi
7
+
8
+ from utils.model_interface import extract_model_id, get_model_info
9
+
10
+
11
+ def get_username_from_token(token: str | None) -> str:
12
+ """
13
+ Get username from Hugging Face token using whoami.
14
+
15
+ Args:
16
+ token: HF token string or None
17
+
18
+ Returns:
19
+ Username string, or "yjernite" as fallback if token is None or whoami fails
20
+ """
21
+ if token is None:
22
+ return "yjernite"
23
+
24
+ try:
25
+ api = HfApi()
26
+ user_info = api.whoami(token=token)
27
+ return user_info.get("name", "yjernite")
28
+ except Exception:
29
+ return "yjernite"
30
+
31
+
32
def get_dataset_repo_id(token: str | None) -> str:
    """
    Build the ID of the dataset repository used to store test results.

    Args:
        token: HF token string or None

    Returns:
        Repo ID of the form "{username}/moderation-test-results", where the
        username is whatever get_username_from_token(token) returns
    """
    return f"{get_username_from_token(token)}/moderation-test-results"
44
+
45
+
46
def load_dataset_from_hub(token: str | None) -> tuple[list[dict], Exception | None]:
    """
    Fetch the results dataset from the Hub as a list of row dicts.

    Args:
        token: HF token string or None

    Returns:
        Tuple of (rows, error):
        - (rows, None) when the dataset was loaded
        - ([], None) when loading raised FileNotFoundError, which is treated
          as "the dataset doesn't exist yet"
        - ([], error) for any other exception (network, auth, etc.)
    """
    repo_id = get_dataset_repo_id(token)

    try:
        splits = load_dataset(repo_id, token=token)
        # Use whichever split is listed first (usually 'train')
        first_split = list(splits.keys())[0]
        rows = splits[first_split].to_list()
    except FileNotFoundError:
        # Nothing has been saved yet: not an error for callers
        return [], None
    except Exception as err:
        # Hand the failure back instead of raising so callers decide
        return [], err

    return rows, None
73
+
74
+
75
def format_categories_and_reasoning(parsed: dict) -> str:
    """
    Format categories and reasoning from parsed JSON response.

    The input comes from a model response, so its shape is not trusted:
    anything that does not provide a non-empty list under 'categories'
    yields the "no categories" message instead of raising.

    Args:
        parsed: Parsed JSON dict with 'categories' key. Each entry is normally
            a dict with 'category', 'reasoning' and optional 'policy_source';
            a non-dict entry (e.g. a bare string) is shown as the category
            name with the default explanation.

    Returns:
        Formatted markdown string
    """
    no_categories_message = "*No categories found in response*\n\nThis output expects a valid JSON response, as specified for example in the default prompt.\n\nThe raw response can be seen in the Model Response section below."

    # Valid JSON can still be a list/str/number at top level, or carry a
    # non-list 'categories' value; treat all of those as "nothing to show"
    categories = parsed.get("categories") if isinstance(parsed, dict) else None
    if not isinstance(categories, (list, tuple)) or not categories:
        return no_categories_message

    parts = ["### Categories:\n\n"]
    for cat in categories:
        if not isinstance(cat, dict):
            # Bare value such as "spam": keep it visible as the category name
            cat = {"category": cat}

        category_name = cat.get("category", "Unknown")
        reasoning_text = cat.get("reasoning", "No reasoning provided")
        policy_source = cat.get("policy_source", "")

        parts.append(f"- **Category:** {category_name}\n")
        parts.append(f" - **Explanation:** {reasoning_text}\n")
        if policy_source:
            parts.append(f" - **Policy Source:** {policy_source}\n")
        parts.append("\n\n")

    return "".join(parts)
102
+
103
+
104
def save_to_dataset(token: str | None, data: dict) -> tuple[bool, str]:
    """
    Save test result to Hugging Face dataset.

    Loads the existing examples, appends `data`, and pushes the full list
    back to the user's results repo. A dataset that does not exist yet is
    reported by load_dataset_from_hub as an empty list, so the first save
    goes through the same path and needs no separate "create" branch.

    Args:
        token: HF token string or None
        data: Dict with all test result fields

    Returns:
        Tuple of (success: bool, message: str)
    """
    try:
        repo_id = get_dataset_repo_id(token)

        # Load existing dataset and examples using shared function
        examples, load_error = load_dataset_from_hub(token)

        # Stop on a real load error (network, auth, etc.): pushing now would
        # upload only the new row in place of whatever is already stored
        if load_error is not None:
            return False, f"Failed to save: {load_error}"

        # Append new example and rebuild the dataset from all examples
        examples.append(data)
        dataset = Dataset.from_list(examples)

        # Push to hub
        dataset.push_to_hub(repo_id, token=token)
    except Exception as e:
        # Top-level boundary: report the failure to the caller as a message
        return False, f"Failed to save: {e}"

    return True, f"Saved to {repo_id}"
145
+
146
+
147
def load_dataset_examples(token: str | None) -> tuple[list[dict], list[str]]:
    """
    Load examples from Hugging Face dataset.

    Args:
        token: HF token string or None

    Returns:
        Tuple of (list of example dicts, list of formatted dropdown labels).
        Both lists are empty when loading failed or the dataset has no rows.
        Labels have the form "{input preview} - {emoji} - {model name} - #{idx}".
    """
    # Use shared loading function
    examples, load_error = load_dataset_from_hub(token)

    # Load errors and empty datasets both yield empty lists
    if load_error is not None or not examples:
        return [], []

    # Format dropdown labels
    labels = []
    for idx, example in enumerate(examples):
        # Use `or` rather than a .get() default: a key that is present but
        # None bypasses the default and would break len()/slicing below.
        # NOTE(review): rows saved without a field presumably come back with
        # an explicit None after a round trip through the Hub - confirm.
        input_text = example.get("input") or ""
        if not isinstance(input_text, str):
            input_text = str(input_text)
        model_selection = example.get("model_selection") or ""
        policy_violation = example.get("policy_violation", -1)

        # Get label emoji: 1 = violation, 0 = no violation, anything else unknown
        if policy_violation == 1:
            label_emoji = "❌"
        elif policy_violation == 0:
            label_emoji = "✅"
        else:
            label_emoji = "⚠️"

        # Extract model name, falling back to the raw id and then "Unknown"
        model_id = extract_model_id(model_selection)
        model_info = get_model_info(model_id) if model_id else None
        if model_info:
            model_name = model_info.get("name", model_id)
        else:
            model_name = model_id or "Unknown"

        # Truncate input for label
        if len(input_text) > 40:
            input_preview = input_text[:40] + "..."
        else:
            input_preview = input_text
        labels.append(f"{input_preview} - {label_emoji} - {model_name} - #{idx}")

    return examples, labels
193
+