Spaces:
Runtime error
Runtime error
Commit
·
f1ee247
1
Parent(s):
ee927fc
Update main.py
Browse files
main.py
CHANGED
|
@@ -85,7 +85,23 @@ def denormalize_image(image, source_range=(-1, 1)):
|
|
| 85 |
return (image * 255).clip(0, 255).astype(np.uint8)
|
| 86 |
else:
|
| 87 |
raise ValueError(f"Unsupported source range: {source_range}")
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
| 90 |
width, height = 256, 256
|
| 91 |
initial_images = load_initial_images(width, height)
|
|
@@ -103,7 +119,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 103 |
action_descriptions = []
|
| 104 |
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
| 105 |
initial_actions = ['0:0'] * 7
|
| 106 |
-
initial_actions = ['N N N N N : N N N N N'] * 7
|
| 107 |
def unnorm_coords(x, y):
|
| 108 |
return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
| 109 |
|
|
@@ -121,7 +137,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 121 |
if DEBUG:
|
| 122 |
norm_x = x
|
| 123 |
norm_y = y
|
| 124 |
-
action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
|
|
|
| 125 |
prev_x = norm_x
|
| 126 |
prev_y = norm_y
|
| 127 |
elif action_type == "left_click":
|
|
@@ -180,7 +197,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 180 |
#positions = positions[1:]
|
| 181 |
mouse_position = position.split('~')
|
| 182 |
mouse_position = [int(item) for item in mouse_position]
|
| 183 |
-
mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 184 |
|
| 185 |
#previous_actions.append((action_type, mouse_position))
|
| 186 |
previous_actions = [(action_type, mouse_position))]
|
|
|
|
| 85 |
return (image * 255).clip(0, 255).astype(np.uint8)
|
| 86 |
else:
|
| 87 |
raise ValueError(f"Unsupported source range: {source_range}")
|
| 88 |
+
|
| 89 |
+
def format_action(action_str, is_padding=False):
    """Render an ``x~y`` action string as space-separated sign and digit tokens.

    Args:
        action_str: Two integers joined by ``~``, e.g. ``"5~-10"``. Not
            parsed when ``is_padding`` is true.
        is_padding: When true, return the fixed placeholder
            ``"N N N N N : N N N N N"`` instead of formatting ``action_str``.

    Returns:
        A string such as ``"+ 0 0 0 5 : - 0 0 1 0"``: for each coordinate a
        sign token followed by the digits of its absolute value (zero-padded
        to at least four digits), with the two halves joined by ``" : "``.

    Raises:
        ValueError: If ``action_str`` does not split into exactly two
            integer-parsable parts.
    """
    if is_padding:
        return "N N N N N : N N N N N"

    def _tokens(value):
        # Zero is treated as non-negative, so it is emitted with a "+" token.
        sign = '+' if value >= 0 else '-'
        digits = f"{abs(value):04d}"
        # One token per character: the sign first, then each digit.
        return ' '.join([sign, *digits])

    # Split the x~y coordinates and parse each half as an integer.
    first, second = (int(part) for part in action_str.split('~'))
    return f"{_tokens(first)} : {_tokens(second)}"
|
| 104 |
+
|
| 105 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
| 106 |
width, height = 256, 256
|
| 107 |
initial_images = load_initial_images(width, height)
|
|
|
|
| 119 |
action_descriptions = []
|
| 120 |
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
| 121 |
initial_actions = ['0:0'] * 7
|
| 122 |
+
#initial_actions = ['N N N N N : N N N N N'] * 7
|
| 123 |
def unnorm_coords(x, y):
|
| 124 |
return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
| 125 |
|
|
|
|
| 137 |
if DEBUG:
|
| 138 |
norm_x = x
|
| 139 |
norm_y = y
|
| 140 |
+
#action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
| 141 |
+
action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}'), pos=='0~0')
|
| 142 |
prev_x = norm_x
|
| 143 |
prev_y = norm_y
|
| 144 |
elif action_type == "left_click":
|
|
|
|
| 197 |
#positions = positions[1:]
|
| 198 |
mouse_position = position.split('~')
|
| 199 |
mouse_position = [int(item) for item in mouse_position]
|
| 200 |
+
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
| 201 |
|
| 202 |
#previous_actions.append((action_type, mouse_position))
|
| 203 |
previous_actions = [(action_type, mouse_position))]
|