Fix cached error
app.py CHANGED
@@ -50,10 +50,12 @@ happy_file_path = "assets/happy.jpg"
 def generate_activations(image):
     prompt = "<image>"
     inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
-    global topk_indices, cached_tensor
+    global topk_indices
+
+    cached_list = []
 
     def hook(module: torch.nn.Module, _, outputs):
-        global topk_indices, cached_tensor
+        global topk_indices
         # Maybe unpack tuple outputs
         if isinstance(outputs, tuple):
             unpack_outputs = list(outputs)
@@ -72,7 +74,7 @@ def generate_activations(image):
     result = torch.zeros_like(latents)
     # results (bs, seq, num_latents)
     result.scatter_(-1, topk.indices, topk.values)
-    cached_tensor = result.detach().cpu()
+    cached_list.append(result.detach().cpu())
     topk_indices = (
         latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
     )
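For reference, the scatter_ call in this hunk densifies the SAE's top-k activations: start from a zero tensor shaped like the latents, then write each kept value back at its latent index. A minimal self-contained sketch with toy shapes (the real app has 131072 latents):

import torch

# Toy stand-in for SAE latents: (batch, seq, num_latents).
latents = torch.randn(1, 4, 16)
topk = latents.topk(k=3, dim=-1)

# Zeros everywhere except the top-k positions, which keep their values.
result = torch.zeros_like(latents)
result.scatter_(-1, topk.indices, topk.values)

# Every value written by scatter_ can be read back at its index.
assert torch.equal(result.gather(-1, topk.indices), topk.values)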
@@ -91,10 +93,9 @@ def generate_activations(image):
     handle.remove()
 
     torch.cuda.empty_cache()
-    return topk_indices
+    return topk_indices, cached_list[0]
 
-def visualize_activations(image, feature_num):
-    global cached_tensor
+def visualize_activations(image, feature_num, cached_tensor):
     base_img_tokens = 576
     patch_size = 24
     # Using Cached tensor
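The reworked generate_activations collects the hook's output in a function-local cached_list rather than a module-level global, removes the hook handle when done, and returns the cache alongside topk_indices. A minimal sketch of that capture pattern, assuming a toy two-layer model (capture_once and the hooked layer are illustrative, not the app's real LLaVA model):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

def capture_once(x):
    cached_list = []  # local, so nothing leaks across calls or sessions

    def hook(module, _inputs, outputs):
        # Maybe unpack tuple outputs, as in app.py.
        out = outputs[0] if isinstance(outputs, tuple) else outputs
        cached_list.append(out.detach().cpu())

    handle = model[0].register_forward_hook(hook)
    try:
        model(x)
    finally:
        handle.remove()  # always detach the hook, mirroring handle.remove() above
    return cached_list[0]

print(capture_once(torch.randn(2, 8)).shape)  # torch.Size([2, 8])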
@@ -191,6 +192,7 @@ def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history
 
 
 with gr.Blocks() as demo:
+    cached_tensor = gr.State()
     gr.Markdown(
         """
 # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
@@ -210,12 +212,12 @@ with gr.Blocks() as demo:
     with gr.Row():
         clear_btn = gr.ClearButton([image, topk_features], value="Clear")
         submit_btn = gr.Button("Submit", variant="primary")
-        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
+        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor])
     with gr.Column():
         output = gr.Image(label="Activation Visualization")
         feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
         visualize_btn = gr.Button("Visualize", variant="primary")
-        visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])
+        visualize_btn.click(visualize_activations, inputs=[image, feature_num, cached_tensor], outputs=[output])
 
     dummy_text = gr.Textbox(visible=False, label="Explanation")
     gr.Examples(
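This wiring is the crux of the fix. A module-level cached_tensor is shared by every visitor to a Space (and, on ZeroGPU, a global mutated inside a GPU worker may not survive back in the main process), so one session could read another's cache or a stale None. gr.State holds one value per browser session and is threaded through events as an ordinary input/output: a callback fills it by returning an extra value, and later callbacks receive it as an extra argument. A minimal sketch of the same pattern, with hypothetical compute/reuse callbacks:

import gradio as gr

def compute(x):
    # The second return value lands in the gr.State component.
    return f"processed {x}", x

def reuse(cached):
    return f"cached value: {cached}" if cached is not None else "nothing cached yet"

with gr.Blocks() as demo:
    cached = gr.State()  # per-session, unlike a module-level global
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Compute").click(compute, inputs=[inp], outputs=[out, cached])
    gr.Button("Reuse").click(reuse, inputs=[cached], outputs=[out])

if __name__ == "__main__":
    demo.launch()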
@@ -261,7 +263,6 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    cached_tensor = None
     tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
     sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
     model, processor = maybe_load_llava_model(