Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
313bb52
1
Parent(s):
94b146f
main.py
CHANGED
|
@@ -203,23 +203,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 203 |
prev_y = 0
|
| 204 |
#print ('here')
|
| 205 |
|
| 206 |
-
|
| 207 |
-
#print ('here2')
|
| 208 |
-
# Use the predefined actions for image_81
|
| 209 |
-
debug_actions = [
|
| 210 |
-
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
| 211 |
-
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
| 212 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 213 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 214 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 215 |
-
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 216 |
-
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 217 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1'
|
| 218 |
-
]
|
| 219 |
-
previous_actions = []
|
| 220 |
-
for action in debug_actions[-8:]:
|
| 221 |
-
x, y, action_type = parse_action_string(action)
|
| 222 |
-
previous_actions.append((action_type, (x, y)))
|
| 223 |
|
| 224 |
for action_type, pos in previous_actions: #[-8:]:
|
| 225 |
print ('here3', action_type, pos)
|
|
@@ -302,6 +286,31 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 302 |
positions = ['815~335']
|
| 303 |
#positions = ['787~342']
|
| 304 |
positions = ['300~800']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
#positions = positions[:4]
|
| 306 |
try:
|
| 307 |
while True:
|
|
@@ -325,9 +334,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 325 |
#mouse_position = position.split('~')
|
| 326 |
#mouse_position = [int(item) for item in mouse_position]
|
| 327 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
previous_actions.append((action_type, mouse_position))
|
| 330 |
-
previous_actions = [(action_type, mouse_position)]
|
| 331 |
|
| 332 |
# Log the start time
|
| 333 |
start_time = time.time()
|
|
@@ -336,7 +349,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 336 |
next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
|
| 337 |
# Load and append the corresponding ground truth image instead of model output
|
| 338 |
#img = Image.open(f"image_{len(previous_frames)%7}.png")
|
| 339 |
-
|
| 340 |
|
| 341 |
# Convert the numpy array to a base64 encoded image
|
| 342 |
img = Image.fromarray(next_frame)
|
|
|
|
| 203 |
prev_y = 0
|
| 204 |
#print ('here')
|
| 205 |
|
| 206 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
for action_type, pos in previous_actions: #[-8:]:
|
| 209 |
print ('here3', action_type, pos)
|
|
|
|
| 286 |
positions = ['815~335']
|
| 287 |
#positions = ['787~342']
|
| 288 |
positions = ['300~800']
|
| 289 |
+
|
| 290 |
+
if DEBUG_TEACHER_FORCING:
|
| 291 |
+
#print ('here2')
|
| 292 |
+
# Use the predefined actions for image_81
|
| 293 |
+
debug_actions = [
|
| 294 |
+
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
| 295 |
+
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
| 296 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 297 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 298 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 299 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 300 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
| 301 |
+
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
| 302 |
+
]
|
| 303 |
+
previous_actions = []
|
| 304 |
+
for action in debug_actions[-8:]:
|
| 305 |
+
x, y, action_type = parse_action_string(action)
|
| 306 |
+
previous_actions.append((action_type, (x, y)))
|
| 307 |
+
positions = [
|
| 308 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
|
| 309 |
+
'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
|
| 310 |
+
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
| 311 |
+
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
| 312 |
+
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
| 313 |
+
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1']
|
| 314 |
#positions = positions[:4]
|
| 315 |
try:
|
| 316 |
while True:
|
|
|
|
| 334 |
#mouse_position = position.split('~')
|
| 335 |
#mouse_position = [int(item) for item in mouse_position]
|
| 336 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 337 |
+
if DEBUG_TEACHER_FORCING:
|
| 338 |
+
position = positions[0]
|
| 339 |
+
positions = positions[1:]
|
| 340 |
+
x, y, action_type = parse_action_string(position)
|
| 341 |
+
mouse_position = (x, y)
|
| 342 |
previous_actions.append((action_type, mouse_position))
|
| 343 |
+
#previous_actions = [(action_type, mouse_position)]
|
| 344 |
|
| 345 |
# Log the start time
|
| 346 |
start_time = time.time()
|
|
|
|
| 349 |
next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
|
| 350 |
# Load and append the corresponding ground truth image instead of model output
|
| 351 |
#img = Image.open(f"image_{len(previous_frames)%7}.png")
|
| 352 |
+
previous_frames.append(next_frame_append)
|
| 353 |
|
| 354 |
# Convert the numpy array to a base64 encoded image
|
| 355 |
img = Image.fromarray(next_frame)
|