da03 committed · Commit 8389537 · 1 Parent(s): a5bfa12

main.py CHANGED
@@ -110,11 +110,13 @@ def process_frame(
 ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
     """Process a single frame through the model."""
     timing = {}
-
+    print ('here5')
     # Temporal encoding
     start = time.perf_counter()
+    print ('here6')
     output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
     timing['temporal_encoder'] = time.perf_counter() - start
+    print ('here7')

     # UNet sampling
     start = time.perf_counter()
@@ -127,13 +129,16 @@ def process_frame(
         verbose=False
     )
     timing['unet'] = time.perf_counter() - start
+    print ('here8')

     # Decoding
     start = time.perf_counter()
     sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
     sample = model.decode_first_stage(sample)
+    print ('here9')
     sample = sample.squeeze(0).clamp(-1, 1)
     timing['decode'] = time.perf_counter() - start
+    print ('here10')

     # Convert to image
     sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
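The here5–here10 prints added in this commit bracket each stage of process_frame, presumably to narrow down where the Space's runtime error occurs. The surrounding code already times each stage with time.perf_counter() and stores the result in the timing dict; the sketch below shows that per-stage pattern in isolation. It is a minimal, hypothetical example: stage_timer and the time.sleep stand-ins are not part of main.py.

import time
from contextlib import contextmanager
from typing import Dict

@contextmanager
def stage_timer(timing: Dict[str, float], name: str):
    """Record the wall-clock duration of the enclosed block in timing[name]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timing[name] = time.perf_counter() - start

# Usage mirroring the per-stage measurements in process_frame
# (time.sleep stands in for the real model calls):
timing: Dict[str, float] = {}
with stage_timer(timing, 'temporal_encoder'):
    time.sleep(0.01)
with stage_timer(timing, 'unet'):
    time.sleep(0.01)
with stage_timer(timing, 'decode'):
    time.sleep(0.01)
print(timing)  # e.g. {'temporal_encoder': 0.01.., 'unet': 0.01.., 'decode': 0.01..}

A context manager like this keeps the start/stop bookkeeping in one place, but the repeated start = time.perf_counter() / timing[...] = ... pairs in the diff above are functionally equivalent.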