""" Initial Solution Generator AZR 기반 TestTime RLVR을 위한 초기 솔루션 생성기 기존 Test-Time-RLVR의 generate_initial_solution 함수를 클래스화하여 확장 """ import re import torch from typing import Dict, Any, Optional, Tuple, List from transformers import AutoTokenizer, AutoModelForCausalLM from .config import TestTimeConfig from .logger import TestTimeLogger from .prompts import get_prompt, get_temperature, get_diversity_instruction # AZR에서 사용하는 코드 추출 함수 직접 임포트 from ..rewards.custom_evaluate import extract_code # VLLM 최적화 지원 try: from vllm import LLM, SamplingParams VLLM_AVAILABLE = True except ImportError: VLLM_AVAILABLE = False class InitialSolutionGenerator: """벤치마크 문제에 대한 초기 솔루션 생성""" def __init__(self, model, tokenizer, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None, use_vllm: bool = True): self.model = model self.tokenizer = tokenizer self.config = config self.logger = logger or TestTimeLogger() self.use_vllm = use_vllm and VLLM_AVAILABLE # VLLM 사용 가능 여부 확인 및 로깅 if use_vllm and not VLLM_AVAILABLE: self.logger.log_info("⚠️ VLLM requested but not available, falling back to HuggingFace") elif self.use_vllm: self.logger.log_info("🚀 Using VLLM for optimized inference") else: self.logger.log_info("🔧 Using HuggingFace Transformers for inference") def generate(self, problem: Dict[str, Any]) -> str: """문제에 대한 초기 솔루션 생성 (AZR 코드 평가 프롬프트 사용)""" problem_prompt = problem['prompt'] problem_id = problem.get('task_id', 'unknown') # AZR 코드 평가에서 사용하는 프롬프트 포맷 적용 # prompt = f"Please provide a self-contained Python script that solves the following problem in a markdown code block:\n\n{problem_prompt}" # 중앙 프롬프트 시스템 사용 if 'HumanEval' in problem_id: # entry_point 함수명 찾기 entry_point = problem.get('entry_point', 'unknown') # 프롬프트에서 함수가 여러 개 있는지 확인 import re function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE)) if function_count > 1: # 다중 함수 프롬프트 사용 prompt = get_prompt("solution_humaneval_multi", problem_prompt=problem_prompt, entry_point=entry_point) else: # 단일 함수 프롬프트 사용 prompt = get_prompt("solution_humaneval_basic", problem_prompt=problem_prompt) else: # MBPP 프롬프트 사용 prompt = get_prompt("solution_mbpp_basic", problem_prompt=problem_prompt) self.logger.log_info(f"🔍 Generating initial solution for {problem_id}") self.logger.log_info(f"📋 Full prompt: {prompt}") # VLLM 또는 HuggingFace 백엔드 선택 if self.use_vllm and isinstance(self.model, LLM): solution = self._generate_with_vllm(prompt) else: solution = self._generate_with_huggingface(prompt) # 마크다운 코드 블록에서 Python 코드 추출 (개선된 방식) extracted_solution = self._extract_python_code(solution) # 코드 추출 결과 로깅 if extracted_solution and extracted_solution != solution: self.logger.log_info(f"🔍 Extracted Python code from markdown block") solution = extracted_solution elif not extracted_solution: self.logger.log_info(f"🔍 No markdown code block found, using original text") # HumanEval의 경우 프롬프트에서 import 추출하여 추가 (EvalPlus 방식) if 'HumanEval' in problem_id: solution = self._add_imports_from_prompt(solution, problem_prompt) # 함수 정의 복구 (AZR 로직 그대로) solution = self._fix_function_definition(solution, prompt, problem_id) self.logger.log_info(f"✅ Generated solution ({len(solution)} chars)") self.logger.log_info(f"🔍 Solution preview: {solution[:200]}...") # 디버깅: 실제 솔루션 내용 로깅 self.logger.log_info(f"🔍 Full solution for debugging:") self.logger.log_info(f"--- START SOLUTION ---") self.logger.log_info(solution) self.logger.log_info(f"--- END SOLUTION ---") return solution def generate_diverse(self, problem: Dict[str, Any], temperature: float = 0.7, variation_id: int = 0) -> str: """다양한 솔루션 생성 (높은 temperature 사용)""" problem_prompt = problem['prompt'] problem_id = problem.get('task_id', 'unknown') # 중앙 관리 다양성 프롬프트 시스템 사용 diversity_instruction = get_diversity_instruction(variation_id) # HumanEval에 대해서는 함수 완성 요청 (다양성 버전) if 'HumanEval' in problem_id: entry_point = problem.get('entry_point', 'unknown') import re function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE)) if function_count > 1: prompt = get_prompt("diverse_humaneval_multi", diversity_instruction=diversity_instruction, problem_prompt=problem_prompt, entry_point=entry_point) else: prompt = get_prompt("diverse_humaneval_basic", diversity_instruction=diversity_instruction, problem_prompt=problem_prompt) else: # MBPP 다양성 프롬프트 사용 prompt = get_prompt("diverse_mbpp_basic", diversity_instruction=diversity_instruction, problem_prompt=problem_prompt) self.logger.log_info(f"🎨 Generating diverse solution #{variation_id+1} for {problem_id}") # 다양성 생성 메서드 사용 try: from vllm import LLM if isinstance(self.model, LLM): solution = self._generate_with_vllm_diverse(prompt, temperature) else: solution = self._generate_with_huggingface_diverse(prompt, temperature) except ImportError: solution = self._generate_with_huggingface_diverse(prompt, temperature) # 코드 추출 및 후처리 (기존과 동일) extracted_solution = self._extract_python_code(solution) if extracted_solution and extracted_solution != solution: self.logger.log_info(f"🔍 Extracted Python code from markdown block") solution = extracted_solution if 'HumanEval' in problem_id: solution = self._add_imports_from_prompt(solution, problem_prompt) solution = self._fix_function_definition(solution, prompt, problem_id) self.logger.log_info(f"✅ Generated diverse solution #{variation_id+1} ({len(solution)} chars)") return solution def _generate_with_vllm(self, prompt: str) -> str: """VLLM 백엔드로 생성 (AZR 방식)""" # AZR evaluation과 동일한 SamplingParams 설정 sampling_params = SamplingParams( temperature=0.05, max_tokens=2048, # AZR 평가 설정 top_p=1.0, # greedy mode stop=["\n```\n"], # 코드 블록 종료 시 정지 ) # VLLM 생성 outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) solution = outputs[0].outputs[0].text.replace("\t", " ") # AZR 방식 탭 처리 return solution.strip() def _generate_with_vllm_diverse(self, prompt: str, temperature: float = 0.7) -> str: """다양한 솔루션 생성용 VLLM 백엔드 (높은 temperature)""" # 다양성을 위한 SamplingParams 설정 sampling_params = SamplingParams( temperature=temperature, # 높은 temperature로 다양성 확보 max_tokens=2048, top_p=0.95, # 다양성을 위해 top_p 사용 stop=["\n```\n"], # 코드 블록 종료 시 정지 ) # VLLM 생성 outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) solution = outputs[0].outputs[0].text.replace("\t", " ") return solution.strip() def generate_batch(self, prompts: List[str], temperature: float = 0.7) -> List[str]: """배치로 여러 프롬프트 동시 처리""" # 실제 모델 타입 확인 (VLLM 로딩 실패 시 HuggingFace 모델이 로드됨) if self.use_vllm and isinstance(self.model, LLM): raw_solutions = self._generate_batch_with_vllm(prompts, temperature) else: # HuggingFace는 순차 처리 (fallback) raw_solutions = [self._generate_with_huggingface(prompt) for prompt in prompts] # 각 솔루션에 대해 후처리 수행 processed_solutions = [] for i, (prompt, solution) in enumerate(zip(prompts, raw_solutions)): # 1. 마크다운에서 Python 코드 추출 extracted = self._extract_python_code(solution) if extracted and extracted != solution: self.logger.log_info(f"🔍 Extracted Python code from markdown block for batch item {i+1}") solution = extracted # 2. HumanEval 문제인 경우 import 추가 # 프롬프트에서 problem ID 추출 (프롬프트에 포함되어 있다고 가정) if 'HumanEval' in prompt: # 프롬프트에서 원본 problem prompt 추출 시도 # 프롬프트 구조에 따라 조정 필요 solution = self._add_imports_from_prompt(solution, prompt) # 3. 함수 정의 수정 (필요한 경우) # generate_diverse와 동일한 처리 solution = self._fix_function_definition(solution, prompt) processed_solutions.append(solution) return processed_solutions def _generate_batch_with_vllm(self, prompts: List[str], temperature: float = 0.7) -> List[str]: """VLLM으로 배치 처리""" # VLLM 샘플링 파라미터 # seed를 제거하여 매번 다른 응답 생성 sampling_params = SamplingParams( temperature=temperature, top_p=0.85, max_tokens=1024, stop=[] # stop 토큰 명시적으로 비움 ) # VLLM 배치 생성 outputs = self.model.generate(prompts, sampling_params, use_tqdm=False) # 결과 추출 solutions = [] for i, output in enumerate(outputs): solution = output.outputs[0].text.replace("\t", " ") # 디버깅: finish_reason 확인 finish_reason = output.outputs[0].finish_reason if finish_reason != "stop" and i < 3: # 처음 3개만 로깅 self.logger.log_warning(f"Output {i} finish_reason: {finish_reason}, length: {len(solution)}") solutions.append(solution.strip()) return solutions def _generate_with_huggingface(self, prompt: str) -> str: """HuggingFace 백엔드로 생성 (attention mask 수정)""" # 토크나이저 처리 (attention mask 경고 수정) inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) # attention mask 명시적으로 설정 if 'attention_mask' not in inputs: inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) # 디바이스 이동 (AZR 방식 그대로) device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu') if isinstance(device, str): inputs = {k: v.to(device) for k, v in inputs.items()} else: # 모델이 이미 특정 디바이스에 있는 경우 inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()} with torch.no_grad(): # 메모리 정리 (AZR 방식 그대로) if torch.cuda.is_available(): torch.cuda.empty_cache() # AZR evaluation과 동일한 greedy 설정 outputs = self.model.generate( inputs['input_ids'], attention_mask=inputs['attention_mask'], # attention mask 명시적으로 전달 max_new_tokens=2048, # 원래 AZR 평가 설정 do_sample=False, # greedy mode (--greedy와 동일) pad_token_id=self.tokenizer.eos_token_id ) # 솔루션 추출 (AZR 방식 그대로) solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True) solution = solution[len(prompt):].strip() return solution def _generate_with_huggingface_diverse(self, prompt: str, temperature: float = 0.7) -> str: """다양한 솔루션 생성용 HuggingFace 백엔드 (높은 temperature)""" # 토크나이저 처리 inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) # attention mask 명시적으로 설정 if 'attention_mask' not in inputs: inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) # 디바이스 이동 device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu') if isinstance(device, str): inputs = {k: v.to(device) for k, v in inputs.items()} else: # 모델이 이미 특정 디바이스에 있는 경우 inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()} with torch.no_grad(): # 메모리 정리 if torch.cuda.is_available(): torch.cuda.empty_cache() # 다양성을 위한 sampling 설정 outputs = self.model.generate( inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=2048, do_sample=True, # sampling 활성화 temperature=temperature, # 높은 temperature top_p=0.95, # 다양성을 위해 top_p 사용 pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id ) # 솔루션 추출 solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True) solution = solution[len(prompt):].strip() return solution def _extract_python_code(self, solution: str) -> str: """개선된 Python 코드 추출 (AZR 방식 + 추가 패턴)""" # 1. AZR의 extract_code 함수 먼저 시도 try: extracted = extract_code(solution, language="python") if extracted: return extracted except: pass # 2. 다양한 마크다운 패턴 시도 patterns = [ r'```python\n(.*?)```', # ```python ... ``` r'```\n(.*?)```', # ``` ... ``` r'```py\n(.*?)```', # ```py ... ``` r'```Python\n(.*?)```', # ```Python ... ``` r'Here is.*?:\n\n```python\n(.*?)```', # 설명 텍스트 포함 r'Here is.*?:\n\n```\n(.*?)```', # 설명 텍스트 포함 ] for pattern in patterns: matches = re.findall(pattern, solution, re.DOTALL | re.IGNORECASE) if matches: return matches[-1].strip() # 3. def로 시작하는 함수 찾기 lines = solution.split('\n') code_lines = [] in_function = False for line in lines: if line.strip().startswith('def '): in_function = True code_lines.append(line) elif in_function and (line.startswith(' ') or line.strip() == ''): code_lines.append(line) elif in_function and line.strip() and not line.startswith(' '): # 함수 정의 끝 break if code_lines: return '\n'.join(code_lines) # 4. 원본 반환 return solution def _add_imports_from_prompt(self, solution: str, prompt: str) -> str: """HumanEval 프롬프트에서 import 문을 추출하여 솔루션에 추가 (EvalPlus 방식)""" # 이미 import가 있으면 그대로 반환 if 'from typing import' in solution or 'import typing' in solution: return solution # 프롬프트에서 import 문 추출 import_lines = [] prompt_lines = prompt.split('\n') for line in prompt_lines: stripped = line.strip() # import 문 찾기 if (stripped.startswith('from ') and 'import' in stripped) or stripped.startswith('import '): import_lines.append(line) # 함수 정의가 시작되면 중단 elif stripped.startswith('def '): break # import가 없으면 원본 반환 if not import_lines: return solution # import 추가 self.logger.log_info(f"🔧 Adding imports from prompt: {import_lines}") # 솔루션이 이미 import로 시작하는지 확인 solution_lines = solution.split('\n') first_non_empty_line = None for i, line in enumerate(solution_lines): if line.strip(): first_non_empty_line = i break # import를 맨 앞에 추가 if first_non_empty_line is not None: # 기존 import 뒤에 추가하거나 맨 앞에 추가 imports_text = '\n'.join(import_lines) + '\n\n' # 첫 번째 비어있지 않은 줄이 import인 경우 if solution_lines[first_non_empty_line].strip().startswith(('import ', 'from ')): # 마지막 import 찾기 last_import_idx = first_non_empty_line for i in range(first_non_empty_line, len(solution_lines)): if solution_lines[i].strip() and not solution_lines[i].strip().startswith(('import ', 'from ')): break if solution_lines[i].strip().startswith(('import ', 'from ')): last_import_idx = i # 마지막 import 다음에 추가 solution_lines.insert(last_import_idx + 1, '') solution_lines.insert(last_import_idx + 1, '\n'.join(import_lines)) return '\n'.join(solution_lines) else: # 맨 앞에 추가 return imports_text + solution return imports_text + solution def _fix_function_definition(self, solution: str, prompt: str, problem_id: str = "") -> str: """함수 정의가 누락된 경우 복구 + lpw 스타일 중복 처리""" # lpw 스타일: 프롬프트에서 함수 이름 추출 func_def_match = re.search(r'def\s+(\w+)\([^)]*\)(?:\s*->\s*[^:]+)?:', prompt) if not func_def_match: return solution entry_point = func_def_match.group(1) func_def_line = func_def_match.group(0) # HumanEval의 경우 전체 코드를 반환하므로 중복 처리 불필요 if 'HumanEval' in problem_id: # 이미 전체 코드가 있으므로 그대로 반환 return solution # MBPP의 경우 기존 로직 유지 # Case 1: LLM이 전체 함수를 생성한 경우 (lpw 스타일 체크) if (prompt in solution) or (f'def {entry_point}(' in solution): # 함수가 이미 포함되어 있음 self.logger.log_info(f"✅ Function definition already present for {entry_point}") return solution # Case 2: 함수 본문만 생성한 경우 - 함수 정의 추가 if solution and not solution.startswith('def '): # 함수 정의와 함수 내용을 결합 lines = solution.split('\n') fixed_lines = [func_def_line] for line in lines: if line.strip(): # 빈 줄이 아닌 경우 # if __name__ == "__main__": 부분은 함수 밖에 있어야 함 if line.strip().startswith('if __name__'): # 함수 정의 끝내고 메인 부분 시작 fixed_lines.append('') # 빈 줄 추가 fixed_lines.append(line.strip()) else: # 함수 내용은 4칸 인덴테이션 if not line.startswith(' ') and line.strip(): line = ' ' + line.lstrip() fixed_lines.append(line) else: fixed_lines.append(line) solution = '\n'.join(fixed_lines) self.logger.log_info(f"🔧 Fixed function definition for {entry_point}") return solution def generate_fallback_solution(self, problem: Dict[str, Any]) -> str: """문제 생성 실패 시 대체 솔루션 생성""" entry_point = problem.get('entry_point', 'solution') problem_description = problem.get('prompt', '') # 문제 유형별 기본 템플릿 (기존 방식) if 'similar_elements' in problem_description: # similar_elements 문제 (Mbpp/2) solution = f"""def {entry_point}(test_tup1, test_tup2): return tuple(set(test_tup1) & set(test_tup2))""" elif 'kth_element' in problem_description: # kth_element 문제 solution = f"""def {entry_point}(arr, k): return sorted(arr)[k-1]""" else: # 일반 템플릿 solution = f"""def {entry_point}(*args): # TODO: Implement this function return None""" self.logger.log_info(f"🔄 Generated fallback solution for {entry_point}") return solution def validate_syntax(self, solution: str) -> Tuple[bool, Optional[str]]: """솔루션 구문 검증""" try: compile(solution, '', 'exec') return True, None except SyntaxError as e: return False, str(e) except Exception as e: return False, str(e) def extract_function_signature(self, prompt: str) -> Optional[Dict[str, str]]: """프롬프트에서 함수 시그니처 추출""" # def function_name(args) -> return_type: 패턴 매칭 pattern = r'def\s+(\w+)\(([^)]*)\)(?:\s*->\s*([^:]+))?:' match = re.search(pattern, prompt) if match: func_name = match.group(1) args = match.group(2) return_type = match.group(3) return { 'name': func_name, 'args': args.strip(), 'return_type': return_type.strip() if return_type else None, 'full_signature': match.group(0) } return None def format_solution(self, raw_solution: str, problem: Dict[str, Any]) -> str: """솔루션 형식 정리""" # 기본 정리 solution = raw_solution.strip() # 함수 정의 확인 및 수정 if not solution.startswith('def '): signature = self.extract_function_signature(problem.get('prompt', '')) if signature: # 함수 정의 추가 lines = solution.split('\n') indented_lines = [' ' + line if line.strip() else line for line in lines] solution = signature['full_signature'] + '\n' + '\n'.join(indented_lines) # 불필요한 설명 텍스트 제거 lines = solution.split('\n') code_lines = [] in_function = False for line in lines: if line.strip().startswith('def '): in_function = True code_lines.append(line) elif in_function: code_lines.append(line) elif line.strip() and not any(keyword in line.lower() for keyword in ['explanation', 'here', 'this function', 'the solution']): code_lines.append(line) return '\n'.join(code_lines).strip() @staticmethod def extract_docstring_from_function(code: str) -> str: """함수 코드에서 docstring을 추출""" import re # 함수 정의 다음에 오는 docstring 패턴 매칭 # """...""" 또는 '''...''' 형태 docstring_patterns = [ r'def\s+\w+\([^)]*\):\s*\n\s*"""(.*?)"""', # """...""" r'def\s+\w+\([^)]*\):\s*\n\s*\'\'\'(.*?)\'\'\'', # '''...''' ] for pattern in docstring_patterns: match = re.search(pattern, code, re.DOTALL) if match: docstring = match.group(1).strip() # 여러 줄인 경우 깔끔하게 정리 lines = docstring.split('\n') cleaned_lines = [] for line in lines: cleaned_line = line.strip() if cleaned_line: cleaned_lines.append(cleaned_line) return ' '.join(cleaned_lines) # docstring이 없는 경우 기본 메시지 반환 return "Find the function that produces these outputs from these inputs." def _extract_function_code(self, code: str) -> str: """코드에서 함수 정의와 필요한 import 추출""" import re lines = code.strip().split('\n') import_lines = [] func_lines = [] in_function = False indent_level = 0 # 1. import 문 수집 for line in lines: stripped = line.strip() if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'): import_lines.append(line) # 2. 함수 정의 찾기 for line in lines: if line.strip().startswith('def '): in_function = True func_lines = [line] # 첫 줄의 들여쓰기 레벨 저장 indent_level = len(line) - len(line.lstrip()) elif in_function: # 빈 줄이거나 같은/더 깊은 들여쓰기면 함수의 일부 if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level): func_lines.append(line) else: # 함수 끝 break # 3. import + function 결합 if func_lines: result_lines = import_lines + [''] + func_lines if import_lines else func_lines return '\n'.join(result_lines) else: return code def evaluate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]: """LLM 솔루션을 벤치마크 테스트로 평가 (EvalPlus 필수)""" try: # EvalPlus 함수들 임포트 (pip으로 설치된 버전 사용) self.logger.log_info("🔄 Attempting to import EvalPlus...") from evalplus.evaluate import check_correctness from evalplus.gen.util import trusted_exec from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS from evalplus.eval import PASS self.logger.log_info("✅ Using EvalPlus for evaluation") except ImportError as e: # EvalPlus가 없으면 오류로 처리 (fallback 제거) self.logger.log_error(f"❌ EvalPlus is required but not available: {e}") import traceback self.logger.log_error(f"📋 Import traceback: {traceback.format_exc()}") return { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': f"EvalPlus import failed: {e}. Please install EvalPlus properly.", 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } except Exception as e: self.logger.log_error(f"❌ EvalPlus import failed with unexpected error: {e}") return { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': f"EvalPlus import error: {e}", 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } result = { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': None, 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } try: # 1. 함수 정의 추출 extracted_code = self._extract_function_code(solution) if not extracted_code: result['error'] = "No function definition found" return result # 2. 데이터셋 타입 결정 task_id = problem.get('task_id', '') if task_id.startswith('Mbpp'): dataset = 'mbpp' elif task_id.startswith('HumanEval'): dataset = 'humaneval' else: # 기본값 dataset = 'mbpp' # 3. expected outputs 생성 (canonical solution 사용) entry_point = problem.get('entry_point', '') canonical_solution = problem.get('canonical_solution', '') if not canonical_solution: result['error'] = "No canonical_solution found" return result # Expected outputs 계산 expected_output = {} # Base tests base_inputs = problem.get('base_input', []) if base_inputs: expected_output['base'], expected_output['base_time'] = trusted_exec( problem.get('prompt', '') + canonical_solution, base_inputs, entry_point, record_time=True, output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS ) # Plus tests plus_inputs = problem.get('plus_input', []) if plus_inputs: expected_output['plus'], expected_output['plus_time'] = trusted_exec( problem.get('prompt', '') + canonical_solution, plus_inputs, entry_point, record_time=True, output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS ) # 4. EvalPlus check_correctness 호출 evalplus_result = check_correctness( dataset=dataset, completion_id=0, problem=problem, solution=extracted_code, expected_output=expected_output, base_only=False, # Plus tests도 실행 fast_check=False, # 모든 테스트 실행 identifier=task_id ) # 5. 결과 파싱 if 'base' in evalplus_result: base_stat, base_details = evalplus_result['base'] result['base_total'] = len(base_inputs) if base_stat == PASS: result['base_passed'] = result['base_total'] else: result['base_passed'] = sum(1 for d in base_details if d) if base_details else 0 result['passed_tests'] += result['base_passed'] result['total_tests'] += result['base_total'] if 'plus' in evalplus_result: plus_stat, plus_details = evalplus_result['plus'] result['plus_total'] = len(plus_inputs) if plus_stat == PASS: result['plus_passed'] = result['plus_total'] else: result['plus_passed'] = sum(1 for d in plus_details if d) if plus_details else 0 result['passed_tests'] += result['plus_passed'] result['total_tests'] += result['plus_total'] # EvalPlus 기준: 모든 테스트 통과해야 correct result['correct'] = (result['passed_tests'] == result['total_tests']) and result['total_tests'] > 0 # 에러 메시지 설정 if not result['correct']: if base_stat != PASS: result['error'] = f"Base tests failed: {base_stat}" elif 'plus' in evalplus_result and plus_stat != PASS: result['error'] = f"Plus tests failed: {plus_stat}" # 로깅 self.logger.log_info(f"EvalPlus evaluation for {task_id}:") self.logger.log_info(f" Base: {result['base_passed']}/{result['base_total']}") self.logger.log_info(f" Plus: {result['plus_passed']}/{result['plus_total']}") self.logger.log_info(f" Total: {result['passed_tests']}/{result['total_tests']}") self.logger.log_info(f" Correct: {result['correct']}") except Exception as e: result['error'] = f"Evaluation failed: {str(e)}" import traceback self.logger.log_info(f"Evaluation traceback: {traceback.format_exc()}") return result @staticmethod def load_model_with_optimizations(model_name: str, device: str, config: TestTimeConfig, use_vllm: bool = True, tensor_parallel_size: int = 1) -> Tuple[Any, Any]: """모델과 토크나이저 로드 (AZR 스타일 최적화, VLLM 지원)""" # 토크나이저 로드 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # VLLM 사용 가능 여부 확인 및 모델 로드 if use_vllm and VLLM_AVAILABLE and device.startswith('cuda'): try: # GPU 디바이스 설정 (이미 설정된 CUDA_VISIBLE_DEVICES 우선 사용) import os if 'CUDA_VISIBLE_DEVICES' not in os.environ: gpu_id = device.split(':')[1] if ':' in device else '0' os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id else: # 이미 설정된 CUDA_VISIBLE_DEVICES 사용 gpu_id = os.environ['CUDA_VISIBLE_DEVICES'] print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {gpu_id}") # VLLM 모델 로드 (Ray Actor 환경에서 메모리 최적화) model = LLM( model=model_name, dtype=str(config.torch_dtype).split('.')[-1], # torch.float16 -> float16 trust_remote_code=True, gpu_memory_utilization=config.gpu_memory_utilization, max_model_len=getattr(config, 'max_model_len', 2048), # 충분한 길이로 증가 tensor_parallel_size=tensor_parallel_size, # GPU 개수에 맞춤 ) print(f"✅ VLLM model loaded successfully on GPU {gpu_id} (tensor_parallel_size={tensor_parallel_size})") return model, tokenizer except Exception as e: import traceback print(f"⚠️ VLLM loading failed: {e}") print(f"🔍 Full traceback: {traceback.format_exc()}") print(f"🔄 Falling back to HuggingFace") # HuggingFace 모델 로드 (기존 방식) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=config.torch_dtype, device_map=device if device.startswith('cuda') else None, trust_remote_code=True, attn_implementation="flash_attention_2" if config.use_flash_attention and device.startswith('cuda') else None, use_cache=False, # 학습용으로 캐시 비활성화 ) # Gradient checkpointing 활성화 # Gradient checkpointing 비활성화 - 추론 시에는 불필요하고 경고만 발생 # 학습이 필요한 경우 별도로 활성화해야 함 if hasattr(model, 'gradient_checkpointing_disable'): model.gradient_checkpointing_disable() print(f"✅ HuggingFace model loaded successfully") return model, tokenizer