Commit ae20fe2
Parent: 089bc3b
Add support for disabling PyTorch compilation to work around compatibility issues in Gradio Spaces, and unify logging configuration across several files.
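The same guard pattern is applied at module import time in all three files below. As a standalone illustration, this is roughly what the guard does; a minimal sketch assuming PyTorch 2.x (where torch._dynamo ships with the core package). The verification print is illustrative, not part of the commit:

import os

# Set the kill switches before importing torch so they are visible at import time
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import torch

try:
    import torch._dynamo
    torch._dynamo.config.disable = True          # skip Dynamo graph capture entirely
    torch._dynamo.config.suppress_errors = True  # fall back to eager on capture errors
    print("dynamo disabled:", torch._dynamo.config.disable)
except ImportError:
    # Older torch builds ship without _dynamo; nothing to disable
    pass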
src/podcast_transcribe/llm/llm_base.py
CHANGED
@@ -1,9 +1,29 @@
+"""
+LLM base class definitions
+Provides the abstract base class for chat completions and the Transformers implementation
+"""
+
+import logging
 import time
 import uuid
 import torch
 from typing import List, Dict, Optional, Union, Literal
 from abc import ABC, abstractmethod
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
 
+# Disable torch._dynamo when it is available
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
+# Configure logging
+logger = logging.getLogger("llm")
 
 class BaseChatCompletion(ABC):
     """Base class for Gemma chat completions, containing shared functionality"""
@@ -308,6 +328,16 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         except ImportError:
             raise ImportError("Please install the transformers library first: pip install transformers")
 
+        # Make sure compilation features are disabled
+        os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+        os.environ["TORCH_COMPILE_DISABLE"] = "1"
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+            torch._dynamo.config.suppress_errors = True
+        except (ImportError, AttributeError):
+            pass
+
         print(f"Loading model: {self.model_name}")
         print(f"Target device: {self.device}")
         print(f"Device map: {self.device_map}")
@@ -372,6 +402,13 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
     ) -> str:
         """Generate a response using transformers"""
 
+        # Extra compilation-disabling measures to keep things working in Gradio Spaces
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+        except (ImportError, AttributeError):
+            pass
+
         # Encode the prompt
         inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
 
@@ -488,6 +525,29 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             print(f"Generation finished, output length: {len(generated_tokens)} tokens")
             return generated_text
 
+        except torch._dynamo.exc.BackendCompilerFailed as e:
+            print(f"PyTorch compiler error, retrying with compilation disabled: {e}")
+            # Force-disable compilation and retry
+            try:
+                torch._dynamo.reset()
+                torch._dynamo.config.disable = True
+                os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        inputs,
+                        **generation_config
+                    )
+
+                generated_tokens = outputs[0][len(inputs[0]):]
+                generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                print(f"Generation finished after disabling compilation, output length: {len(generated_tokens)} tokens")
+                return generated_text
+
+            except Exception as retry_e:
+                print(f"Still failed after disabling compilation: {retry_e}")
+                raise e
         except RuntimeError as e:
             if "CUDA error" in str(e):
                 print(f"CUDA error, trying CPU inference instead: {e}")
@@ -517,10 +577,42 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             else:
                 raise e
         except Exception as e:
-            print(f"Error while generating response: {e}")
-            import traceback
-            traceback.print_exc()
-            raise
+            # Handle other compiler-related errors
+            if "BackendCompilerFailed" in str(e) or "dynamo" in str(e).lower() or "inductor" in str(e).lower():
+                print(f"Compiler-related error detected, trying to disable compilation entirely: {e}")
+                try:
+                    # Force-disable all compilation features
+                    os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+                    os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+                    # Reset the compilation state if possible
+                    try:
+                        torch._dynamo.reset()
+                        torch._dynamo.config.disable = True
+                        torch._dynamo.config.suppress_errors = True
+                    except:
+                        pass
+
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            inputs,
+                            **generation_config
+                        )
+
+                    generated_tokens = outputs[0][len(inputs[0]):]
+                    generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                    print(f"Generation finished with compilation fully disabled, output length: {len(generated_tokens)} tokens")
+                    return generated_text
+
+                except Exception as final_e:
+                    print(f"All retries failed: {final_e}")
+                    raise e
+            else:
+                print(f"Error while generating response: {e}")
+                import traceback
+                traceback.print_exc()
+                raise
 
     def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
         """Get model information"""
src/podcast_transcribe/llm/llm_gemma_transfomers.py
CHANGED
@@ -1,6 +1,20 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from typing import List, Dict, Optional, Union, Literal
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# Disable torch._dynamo when it is available
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
 from .llm_base import TransformersBaseChatCompletion
 
 
src/podcast_transcribe/llm/llm_router.py
CHANGED
@@ -6,6 +6,19 @@ LLM model invocation router
 import logging
 import torch
 from typing import Dict, Any, Optional, List, Union
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# Disable torch._dynamo when it is available
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
 
 import spaces
 from .llm_base import BaseChatCompletion