longjava2024 committed on
Commit
cd0fd85
·
verified ·
1 Parent(s): d78e827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -6,6 +6,7 @@ import re
6
  from io import BytesIO
7
  import types
8
  import sys
 
9
 
10
  # Force CPU-only & disable bitsandbytes CUDA checks in this environment
11
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
@@ -18,6 +19,7 @@ from PIL import Image
18
  from torchvision.transforms.functional import InterpolationMode
19
  import gradio as gr
20
 
 
21
  # Stub bitsandbytes and flash_attn to avoid GPU driver checks in CPU-only environments
22
  fake_bnb = types.ModuleType("bitsandbytes")
23
  def _bnb_unavailable(*args, **kwargs):
@@ -27,6 +29,8 @@ fake_bnb._bnb_unavailable = _bnb_unavailable
27
  sys.modules["bitsandbytes"] = fake_bnb
28
 
29
  fake_flash = types.ModuleType("flash_attn")
 
 
30
  sys.modules["flash_attn"] = fake_flash
31
 
32
  from transformers import AutoModel, AutoTokenizer
 
6
  from io import BytesIO
7
  import types
8
  import sys
9
+ import importlib.machinery
10
 
11
  # Force CPU-only & disable bitsandbytes CUDA checks in this environment
12
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
19
  from torchvision.transforms.functional import InterpolationMode
20
  import gradio as gr
21
 
22
+ # Stub bitsandbytes and flash_attn to avoid GPU driver checks in CPU-only environments
23
  # Stub bitsandbytes and flash_attn to avoid GPU driver checks in CPU-only environments
24
  fake_bnb = types.ModuleType("bitsandbytes")
25
  def _bnb_unavailable(*args, **kwargs):
 
29
  sys.modules["bitsandbytes"] = fake_bnb
30
 
31
  fake_flash = types.ModuleType("flash_attn")
32
+ # Give the stub a valid __spec__: importlib.util.find_spec('flash_attn') raises ValueError when a module in sys.modules has __spec__ set to None (the types.ModuleType default)
33
+ fake_flash.__spec__ = importlib.machinery.ModuleSpec("flash_attn", loader=None)
34
  sys.modules["flash_attn"] = fake_flash
35
 
36
  from transformers import AutoModel, AutoTokenizer