da03 committed · Commit 1a5d2e7
1 Parent(s): 8f6c968

main.py CHANGED
@@ -13,7 +13,7 @@ import os
 import time

 DEBUG = False
-DEBUG_TEACHER_FORCING =
+DEBUG_TEACHER_FORCING = False
 app = FastAPI()

 # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -156,7 +156,7 @@ def load_initial_images(width, height):
             initial_images.append(np.array(img))
         else:
             #assert False
-            for i in range(
+            for i in range(32):
                 initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
     return initial_images

@@ -202,10 +202,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     print ('length of previous_frames', len(previous_frames))

     # Prepare the image sequence for the model
-    assert len(initial_images) ==
-    image_sequence = previous_frames[-
+    assert len(initial_images) == 32
+    image_sequence = previous_frames[-32:] # Take the last 32 frames
     i = 1
-    while len(image_sequence) <
+    while len(image_sequence) < 32:
         image_sequence.insert(0, initial_images[-i])
         i += 1
     #image_sequence.append(initial_images[len(image_sequence)])
@@ -213,18 +213,23 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     # Convert the image sequence to a tensor and concatenate in the channel dimension
     image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
     image_sequence_tensor = image_sequence_tensor.to(device)
+    data_mean = -0.54
+    data_std = 6.78
+    data_min = -27.681446075439453
+    data_max = 30.854148864746094
+    image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std

     # Prepare the prompt based on the previous actions
     action_descriptions = []
     #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
-    initial_actions = ['0:0'] *
+    initial_actions = ['0:0'] * 32
     #initial_actions = ['N N N N N : N N N N N'] * 7
     def unnorm_coords(x, y):
         return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)

     # Process initial actions if there are not enough previous actions
-    while len(previous_actions) <
-        assert False
+    while len(previous_actions) < 33:
+        #assert False
         x, y = map(int, initial_actions.pop(0).split(':'))
         previous_actions.insert(0, ("N", unnorm_coords(x, y)))
     prev_x = 0
@@ -242,7 +247,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     previous_actions = [('move', (16, 328)), ('move', (304, 96)), ('move', (240, 192)), ('move', (152, 56)), ('left_click', (288, 176)), ('left_click', (56, 376)), ('move', (136, 360)), ('move', (112, 48))]
     prompt = 'L + 0 0 5 6 : + 0 1 2 8 N + 0 4 0 0 : + 0 0 6 4 N + 0 5 0 4 : + 0 1 2 8 N + 0 4 2 4 : + 0 1 2 0 N + 0 3 2 0 : + 0 1 0 4 N + 0 2 8 0 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4'
     previous_actions = [('left_click', (56, 128)), ('left_click', (400, 64)), ('move', (504, 128)), ('move', (424, 120)), ('left_click', (320, 104)), ('left_click', (280, 104)), ('move', (272, 104)), ('move', (272, 104))]
-    for action_type, pos in previous_actions[-
+    for action_type, pos in previous_actions[-33:]:
         #print ('here3', action_type, pos)
         if action_type == 'move':
             action_type = 'N'
@@ -287,14 +292,14 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
         else:
             assert False

-    prompt = " ".join(action_descriptions[-
+    prompt = " ".join(action_descriptions[-33:])
     print(prompt)
     #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
     #x, y, action_type = parse_action_string(action_descriptions[-1])
     #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
     leftclick_maps = []
     pos_maps = []
-    for j in range(1,
+    for j in range(1, 34):
         print ('fsfs', action_descriptions[-j])
         x, y, action_type = parse_action_string(action_descriptions[-j])
         pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
@@ -318,6 +323,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)
     print (new_frame.max(), new_frame.min())
+    new_frame = new_frame * data_std + data_mean
     new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))

     # Draw the trace of previous actions
@@ -429,6 +435,7 @@ async def websocket_endpoint(websocket: WebSocket):
     #mouse_position = (x, y)

     #previous_actions.append((action_type, mouse_position))
+
     if not DEBUG_TEACHER_FORCING:
         previous_actions = []

@@ -448,6 +455,8 @@ async def websocket_endpoint(websocket: WebSocket):
         # Format action string
         previous_actions.append((action_type, (x*8, y*8)))
     try:
+        previous_actions = []
+        previous_frames = []
         while True:
             try:
                 # Receive user input with a timeout
@@ -483,36 +492,37 @@ async def websocket_endpoint(websocket: WebSocket):
                 x, y, action_type = parse_action_string(position)
                 mouse_position = (x, y)
                 previous_actions.append((action_type, mouse_position))
-                if
+                if True:
                     previous_actions.append((action_type, mouse_position))
                     #previous_actions = [(action_type, mouse_position)]
-                if not DEBUG_TEACHER_FORCING:
-                    x, y = mouse_position
-                    x = x//8 * 8
-                    y = y // 8 * 8
-                    assert x % 8 == 0
-                    assert y % 8 == 0
-                    mouse_position = (x, y)
-                    #mouse_position = (x//8, y//8)
-                    previous_actions.append((action_type, mouse_position))
+                #if not DEBUG_TEACHER_FORCING:
+                # x, y = mouse_position
+                # x = x//8 * 8
+                # y = y // 8 * 8
+                # assert x % 8 == 0
+                # assert y % 8 == 0
+                # mouse_position = (x, y)
+                # #mouse_position = (x//8, y//8)
+                # previous_actions.append((action_type, mouse_position))
                 # Log the start time
                 start_time = time.time()

                 # Predict the next frame based on the previous frames and actions
-                if DEBUG_TEACHER_FORCING:
-                    print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
-
+                #if DEBUG_TEACHER_FORCING:
+                # print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
+                print ('previous_actions', previous_actions)
                 next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
                 # Load and append the corresponding ground truth image instead of model output
-                print ('here4', len(previous_frames))
-                if DEBUG_TEACHER_FORCING:
-                    img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
-                    previous_frames.append(np.array(img))
-                else:
-                    assert False
-                    previous_frames.append(next_frame_append)
-                    pass
-
+                #print ('here4', len(previous_frames))
+                #if DEBUG_TEACHER_FORCING:
+                # img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
+                # previous_frames.append(np.array(img))
+                #else:
+                # assert False
+                # previous_frames.append(next_frame_append)
+                # pass
+                previous_frames = []
+                previous_actions = []

                 # Convert the numpy array to a base64 encoded image
                 img = Image.fromarray(next_frame)
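
For orientation: beyond widening the frame and action context to 32 steps, the commit's main functional change is the standardization pair around the model call in predict_next_frame. Inputs that are already normalized to (-1, 1) are additionally z-scored with fixed dataset statistics, and the model output is un-scored before denormalize_image runs. A minimal sketch of that round trip; the helper names standardize and destandardize are hypothetical, only the constants come from the diff:

import numpy as np

# Dataset statistics recorded in the commit (data_min/data_max are stored there but unused).
DATA_MEAN = -0.54
DATA_STD = 6.78

def standardize(frames: np.ndarray) -> np.ndarray:
    # Applied to the (-1, 1)-normalized input tensor before the model call.
    return (frames - DATA_MEAN) / DATA_STD

def destandardize(frames: np.ndarray) -> np.ndarray:
    # Applied to the raw model output before denormalize_image.
    return frames * DATA_STD + DATA_MEAN

x = np.random.randn(32, 3, 256, 256).astype(np.float32)
assert np.allclose(destandardize(standardize(x)), x, atol=1e-5)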
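The 32-frame history handling in predict_next_frame is easy to misread in diff form: short histories are left-padded with the tail of initial_images so the model always sees a full context. The same logic as a self-contained sketch, assuming frames are numpy arrays; build_image_sequence is a hypothetical name for illustration:

import numpy as np

CONTEXT_LEN = 32  # the commit raises the frame context to 32

def build_image_sequence(previous_frames, initial_images, context_len=CONTEXT_LEN):
    # Keep the last `context_len` frames, left-padding a short history with the
    # most recent initial images, preserving their original order.
    assert len(initial_images) == context_len
    sequence = list(previous_frames[-context_len:])
    i = 1
    while len(sequence) < context_len:
        sequence.insert(0, initial_images[-i])  # newest pad lands just before the real frames
        i += 1
    return sequence

blank = [np.zeros((384, 512, 3), dtype=np.uint8) for _ in range(CONTEXT_LEN)]
one_frame = [np.full((384, 512, 3), 255, dtype=np.uint8)]
assert len(build_image_sequence(one_frame, blank)) == CONTEXT_LEN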
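Finally, the hardcoded debug prompt ('L + 0 0 5 6 : + 0 1 2 8 ...') implies the action-token format: a type token ('L' for left_click, 'N' otherwise, matching the move-to-'N' rewrite in the diff) followed by each coordinate as a sign plus four spaced, zero-padded digits. A sketch of an encoder that reproduces that string; the function names and the mapping for unlisted action types are inferences from the debug data, not confirmed elsewhere in the diff:

def encode_coord(v: int) -> str:
    # 56 -> '+ 0 0 5 6': sign token, then four zero-padded digits, space-separated.
    sign = '+' if v >= 0 else '-'
    return ' '.join([sign] + list(f"{abs(v):04d}"))

def encode_action(action_type: str, pos: tuple) -> str:
    # Map the action to its single-letter token, then encode x and y around ':'.
    token = {'move': 'N', 'left_click': 'L'}.get(action_type, 'N')
    x, y = pos
    return f"{token} {encode_coord(x)} : {encode_coord(y)}"

assert encode_action('left_click', (56, 128)) == 'L + 0 0 5 6 : + 0 1 2 8'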