robinhad committed (verified) · Commit ed76f0e · 1 Parent(s): cba7458

Update app.py

Files changed (1):
  1. app.py  +2 -2
app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import subprocess
 import tempfile
 
-subprocess.run('pip install flash-attn==2.8.0.post2 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+# subprocess.run('pip install flash-attn==2.8.0.post2 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 import threading
 
@@ -56,7 +56,7 @@ def load_model():
         MODEL_ID,
         dtype=torch.bfloat16,  # if device == "cuda" else torch.float32,
         device_map="auto",  # if device == "cuda" else None,
-        attn_implementation="flash_attention_2",  # "kernels-community/vllm-flash-attn3"
+        attn_implementation="kernels-community/flash-attn",  # "flash_attention_2", "kernels-community/vllm-flash-attn3"
     )  # .cuda()
     print(f"Selected device:", device)
     return model, tokenizer, processor, device
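
The commit replaces a startup-time pip install of flash-attn with attn_implementation="kernels-community/flash-attn", letting transformers resolve a pre-built attention kernel from the Hugging Face Hub (via the kernels package) instead of compiling flash-attn locally. A minimal sketch of loading a model this way, assuming a recent transformers release with kernels support; the model class and MODEL_ID value are placeholders, since the full app.py is not shown here:

import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "org/model-name"  # hypothetical placeholder

# attn_implementation accepts a Hub kernel repo id: transformers fetches the
# pre-built flash-attention kernel at load time (requires the `kernels`
# package), so no local CUDA build of flash-attn is needed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="kernels-community/flash-attn",
)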