Skip to content

第87天:代码助手系统-核心功能开发

学习目标

  • 掌握代码补全实现
  • 学习代码分析实现
  • 理解代码生成实现
  • 掌握代码导航实现
  • 学习代码重构实现

代码补全实现

智能代码补全

python
class SmartCodeCompleter:
    """LLM-backed code completer with an in-memory result cache.

    Completions are requested from a chat-completions style client and cached
    per (language, cursor position, code), so an identical request is answered
    from the cache instead of calling the model again.
    """

    def __init__(self, llm_client):
        """Store the client and start with an empty cache.

        Args:
            llm_client: Object exposing ``chat.completions.create(...)``.
        """
        self.llm_client = llm_client
        # Maps a cache key (see _generate_cache_key) to a list of completions.
        self.completion_cache = {}

    async def complete_code(
        self,
        code: str,
        language: str,
        position: Dict,
        context: Optional[Dict] = None
    ) -> Dict:
        """Return completion candidates for the cursor position.

        Args:
            code: Full source text being edited.
            language: Language name; interpolated into the system prompt.
            position: Dict with a 1-based ``line`` and a ``column``.
            context: Optional extra context, forwarded to the prompt builder.

        Returns:
            ``{"success": True, "completions": [...]}`` on success, or
            ``{"success": False, "error": "<message>"}`` when the model call
            or the parsing of its answer raises.
        """
        cache_key = self._generate_cache_key(code, language, position)

        if cache_key in self.completion_cache:
            return {
                "success": True,
                "completions": self.completion_cache[cache_key]
            }

        prompt = self._build_completion_prompt(
            code,
            language,
            position,
            context
        )

        try:
            # NOTE(review): this call is not awaited, so it runs synchronously
            # inside the coroutine — confirm the client is a blocking one and
            # whether that is acceptable for the calling event loop.
            completion = self.llm_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "system",
                        "content": f"你是一个专业的{language}代码补全助手,提供准确、高效的代码补全"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=0.2,
                max_tokens=500
            )

            completion_text = completion.choices[0].message.content
            completions = self._parse_completions(completion_text, language)

            # Only successful results are cached; failures are retried.
            self.completion_cache[cache_key] = completions

            return {
                "success": True,
                "completions": completions
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _generate_cache_key(
        self,
        code: str,
        language: str,
        position: Dict
    ) -> str:
        """Return a hex digest identifying (language, line, column, code)."""
        import hashlib

        key = f"{language}:{position['line']}:{position['column']}:{code}"
        # MD5 is used only to get a short dictionary key here.
        return hashlib.md5(key.encode()).hexdigest()

    def _build_completion_prompt(
        self,
        code: str,
        language: str,
        position: Dict,
        context: Optional[Dict]
    ) -> str:
        """Build the user prompt: a numbered context window plus the cursor line.

        The window covers up to 10 lines before and 5 lines after the 1-based
        ``position['line']``. ``context`` is accepted but not used in the
        prompt text.
        """
        lines = code.split('\n')
        line_no = position['line']
        start_line = max(0, line_no - 10)
        end_line = min(len(lines), line_no + 5)

        # Number each context line with its real 1-based line number in
        # ``code`` rather than its offset inside the window.
        context_code = '\n'.join(
            f"{number:4d}: {text}"
            for number, text in enumerate(
                lines[start_line:end_line],
                start=start_line + 1
            )
        )

        # A cursor line outside the file yields an empty current line instead
        # of raising IndexError.
        current_line = lines[line_no - 1] if 0 < line_no <= len(lines) else ''

        prompt = f"""请补全以下{language}代码:

上下文代码:
{context_code}

当前行(光标位置):
{current_line}

请提供最可能的代码补全,保持代码风格一致,并考虑上下文语义。"""

        return prompt

    def _parse_completions(
        self,
        completion_text: str,
        language: str
    ) -> List[Dict]:
        """Turn each non-blank line of the model answer into a completion dict.

        The score falls by 0.1 per answer line (blank lines included in the
        count) and is clamped at 0.0 so long answers never score negative.
        """
        completions = []

        for i, line in enumerate(completion_text.split('\n')):
            line = line.strip()
            if not line:
                continue

            completions.append({
                "id": f"completion_{i}",
                "text": line,
                "kind": self._detect_completion_kind(line, language),
                "score": max(0.0, 1.0 - (i * 0.1))
            })

        return completions

    def _detect_completion_kind(
        self,
        completion: str,
        language: str
    ) -> str:
        """Classify a completion line by simple substring checks.

        Checks run in order and the first match wins, so a line containing
        ``=`` is reported as "variable" before the statement keywords are
        considered. ``language`` is not consulted.
        """
        if "def " in completion or "function " in completion:
            return "function"
        elif "class " in completion:
            return "class"
        elif "import " in completion or "from " in completion:
            return "import"
        elif "=" in completion:
            return "variable"
        elif any(keyword in completion for keyword in ["if ", "for ", "while ", "try", "except"]):
            return "statement"
        else:
            return "other"

上下文感知补全

python
class ContextAwareCompleter:
    """Suggest completions from the symbols already present in the code.

    Relies on two injected collaborators: a parser whose ``parse`` coroutine
    and an analyzer whose ``analyze`` coroutine both return dicts carrying a
    ``success`` flag and a ``result`` payload.
    """

    # (key in the parsed result, id prefix, completion kind, score).
    # Order matters: it decides which symbols survive the 10-item cut.
    _SYMBOL_GROUPS = (
        ("functions", "func", "function", 0.9),
        ("classes", "class", "class", 0.8),
        ("variables", "var", "variable", 0.7),
    )

    def __init__(self, code_parser, semantic_analyzer):
        """Keep references to the parser and the semantic analyzer."""
        self.code_parser = code_parser
        self.semantic_analyzer = semantic_analyzer

    async def get_contextual_completions(
        self,
        code: str,
        language: str,
        position: Dict
    ) -> Dict:
        """Parse and analyze ``code``, then list its symbols as completions.

        Returns:
            ``{"success": True, "completions": [...]}``, or a dict with
            ``"success": False`` and an ``"error"`` message when parsing or
            semantic analysis reports failure.
        """
        parsed_code = await self.code_parser.parse(code, language)

        if not parsed_code.get("success"):
            return {
                "success": False,
                "error": "代码解析失败"
            }

        semantic_analysis = await self.semantic_analyzer.analyze(
            code,
            language
        )

        if not semantic_analysis.get("success"):
            return {
                "success": False,
                "error": "语义分析失败"
            }

        contextual_completions = await self._generate_contextual_completions(
            parsed_code["result"],
            semantic_analysis["result"],
            language,
            position
        )

        return {
            "success": True,
            "completions": contextual_completions
        }

    async def _generate_contextual_completions(
        self,
        parsed_code: Dict,
        semantic_analysis: Dict,
        language: str,
        position: Dict
    ) -> List[Dict]:
        """Build at most 10 completions from the parsed functions, classes, variables.

        ``semantic_analysis``, ``language`` and ``position`` are accepted for
        interface compatibility but do not influence the result.
        """
        completions = []

        for key, prefix, kind, score in self._SYMBOL_GROUPS:
            for symbol in parsed_code.get(key, []):
                name = symbol['name']
                completions.append({
                    "id": f"{prefix}_{name}",
                    "text": name,
                    "kind": kind,
                    "score": score
                })

        return completions[:10]

代码分析实现

静态代码分析

python
class StaticCodeAnalyzer:
    """Line-based static checks for Python, JavaScript and TypeScript source."""

    # Top-level modules whose plain ``import`` is not reported for Python.
    _COMMON_MODULES = ("os", "sys", "re", "json")
    # Python lines longer than this many characters are reported.
    _MAX_LINE_LENGTH = 100

    def __init__(self):
        """Register one analyzer coroutine per supported language name."""
        self.analyzers = {
            "python": self._analyze_python,
            "javascript": self._analyze_javascript,
            "typescript": self._analyze_typescript
        }

    async def analyze(
        self,
        code: str,
        language: str
    ) -> Dict:
        """Run the analyzer registered for ``language`` over ``code``.

        Returns:
            ``{"success": True, "issues": [...]}``, or a dict with
            ``"success": False`` and an ``"error"`` message when the language
            is unsupported or the analyzer raises.
        """
        if language not in self.analyzers:
            return {
                "success": False,
                "error": f"不支持的语言: {language}"
            }

        analyzer = self.analyzers[language]

        try:
            issues = await analyzer(code)

            return {
                "success": True,
                "issues": issues
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    @staticmethod
    def _imported_modules(statement: str) -> List[str]:
        """Return the module names of a stripped ``import a, b as c`` line."""
        # Drop a trailing comment, then take the first token of each
        # comma-separated part so that ``as`` aliases are ignored.
        body = statement[len('import '):].split('#', 1)[0]
        names = []
        for part in body.split(','):
            tokens = part.split()
            if tokens:
                names.append(tokens[0])
        return names

    async def _analyze_python(self, code: str) -> List[Dict]:
        """Report over-long lines, debug prints and uncommon imports.

        Issue ids embed the 0-based line index; ``line`` is 1-based.
        """
        issues = []

        for i, raw_line in enumerate(code.split('\n')):
            line = raw_line.strip()

            # Length includes indentation (trailing whitespace is ignored),
            # so deeply indented long lines are reported too.
            if len(raw_line.rstrip()) > self._MAX_LINE_LENGTH:
                issues.append({
                    "id": f"line_too_long_{i}",
                    "line": i + 1,
                    "column": 0,
                    "message": "代码行过长",
                    "severity": "warning"
                })

            if line.startswith('print(') and 'debug' in line.lower():
                issues.append({
                    "id": f"debug_print_{i}",
                    "line": i + 1,
                    "column": 0,
                    "message": "调试打印语句",
                    "severity": "info"
                })

            # Only real ``import ...`` statements are inspected; the text
            # "import " inside a string or comment is not an import.
            if line.startswith('import '):
                uncommon = [
                    name
                    for name in self._imported_modules(line)
                    if name.split('.')[0] not in self._COMMON_MODULES
                ]
                if uncommon:
                    issues.append({
                        "id": f"import_{i}",
                        "line": i + 1,
                        "column": 0,
                        "message": f"导入模块: {', '.join(uncommon)}",
                        "severity": "info"
                    })

        return issues

    def _find_console_logs(self, code: str) -> List[Dict]:
        """Report every line that contains a ``console.log(`` call."""
        return [
            {
                "id": f"console_log_{i}",
                "line": i + 1,
                "column": 0,
                "message": "控制台打印语句",
                "severity": "info"
            }
            for i, line in enumerate(code.split('\n'))
            if 'console.log(' in line
        ]

    async def _analyze_javascript(self, code: str) -> List[Dict]:
        """Report console logging statements in JavaScript code."""
        return self._find_console_logs(code)

    async def _analyze_typescript(self, code: str) -> List[Dict]:
        """Report console logging statements in TypeScript code."""
        return self._find_console_logs(code)

代码质量评估

python
class CodeQualityEvaluator:
    """Score code with three line-based heuristics and derive a letter grade.

    Every metric is a ratio in [0.0, 1.0] computed over the non-blank lines of
    the code; code without any non-blank line scores 1.0 on every metric.
    """

    def __init__(self):
        """Register the metric coroutines under their report names."""
        self.metrics = {
            "complexity": self._calculate_complexity,
            "readability": self._calculate_readability,
            "maintainability": self._calculate_maintainability
        }

    async def evaluate(
        self,
        code: str,
        language: str
    ) -> Dict:
        """Compute all metrics, their mean and the matching grade.

        Returns:
            Dict with ``success``, ``metrics`` (name -> score),
            ``overall_score`` (arithmetic mean) and ``grade``.
        """
        metrics = {}

        for metric_name, metric_func in self.metrics.items():
            metrics[metric_name] = await metric_func(code, language)

        overall_score = sum(metrics.values()) / len(metrics)

        return {
            "success": True,
            "metrics": metrics,
            "overall_score": overall_score,
            "grade": self._calculate_grade(overall_score)
        }

    @staticmethod
    def _non_blank_lines(code: str) -> List[str]:
        """Return the stripped lines of ``code`` that are not empty."""
        # ``str.split`` never returns an empty list, so emptiness has to be
        # decided after blank lines are filtered out.
        stripped = (line.strip() for line in code.split('\n'))
        return [line for line in stripped if line]

    async def _calculate_complexity(
        self,
        code: str,
        language: str
    ) -> float:
        """Score 1.0 for no branching, falling as branching keywords appear.

        Each line containing a control-flow keyword adds 1 and each line
        containing ``&&`` or ``||`` adds 0.5; the total is compared with half
        the number of non-blank lines.
        """
        lines = self._non_blank_lines(code)
        if not lines:
            return 1.0

        complexity = 0.0

        for line in lines:
            if any(keyword in line for keyword in ["if ", "for ", "while ", "try", "except", "switch", "case"]):
                complexity += 1

            if "&&" in line or "||" in line:
                complexity += 0.5

        max_complexity = len(lines) * 0.5

        return max(0.0, 1.0 - (complexity / max_complexity))

    async def _calculate_readability(
        self,
        code: str,
        language: str
    ) -> float:
        """Return the share of non-blank lines that are short and evenly spaced.

        A line counts as readable when its stripped form is under 80
        characters and contains no run of two spaces.
        """
        lines = self._non_blank_lines(code)
        if not lines:
            return 1.0

        readable_lines = sum(
            1 for line in lines if len(line) < 80 and '  ' not in line
        )

        return readable_lines / len(lines)

    async def _calculate_maintainability(
        self,
        code: str,
        language: str
    ) -> float:
        """Return the share of non-blank lines carrying a comment marker or a
        function definition keyword (``#``, ``//``, ``def ``, ``function ``).
        """
        lines = self._non_blank_lines(code)
        if not lines:
            return 1.0

        maintainable_lines = sum(
            1
            for line in lines
            if '#' in line or '//' in line or 'def ' in line or 'function ' in line
        )

        return maintainable_lines / len(lines)

    def _calculate_grade(self, score: float) -> str:
        """Map a score to a letter: A >= 0.8, B >= 0.6, C >= 0.4, otherwise D."""
        if score >= 0.8:
            return "A"
        elif score >= 0.6:
            return "B"
        elif score >= 0.4:
            return "C"
        else:
            return "D"

代码生成实现

单元测试生成

python
class TestGenerator:
    """Ask an LLM to write unit tests for a piece of code."""

    # The name starts with "Test" but this is not a test case; this attribute
    # stops pytest from trying to collect the class when it is imported into
    # a test module.
    __test__ = False

    def __init__(self, llm_client):
        """Store the client.

        Args:
            llm_client: Object exposing ``chat.completions.create(...)``.
        """
        self.llm_client = llm_client

    async def generate_tests(
        self,
        code: str,
        language: str,
        framework: str = "pytest"
    ) -> Dict:
        """Generate unit tests for ``code`` with the given test framework.

        Returns:
            ``{"success": True, "test_code": <model answer>}``, or a dict with
            ``"success": False`` and an ``"error"`` message when the model
            call raises.
        """
        prompt = self._build_test_prompt(
            code,
            language,
            framework
        )

        try:
            # NOTE(review): the call is not awaited, so it runs synchronously
            # inside this coroutine — confirm a blocking client is intended.
            completion = self.llm_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "system",
                        "content": f"你是一个专业的{language}测试工程师,擅长生成高质量的单元测试"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=0.3,
                max_tokens=1000
            )

            test_code = completion.choices[0].message.content

            return {
                "success": True,
                "test_code": test_code
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _build_test_prompt(
        self,
        code: str,
        language: str,
        framework: str
    ) -> str:
        """Build the user prompt, embedding ``code`` in a fenced block.

        Without the code in the prompt the model has nothing to write tests
        for, so the source is always included between the heading and the
        instructions.
        """
        prompt = f"""请为以下{language}代码生成单元测试:

```{language}
{code}
```

请使用{framework}测试框架,生成全面的测试用例,包括:
1. 正常情况测试
2. 边界情况测试
3. 异常情况测试

测试代码应包含适当的断言和测试描述。"""

        return prompt

文档生成

python
class DocumentationGenerator:
    """Ask an LLM to write documentation comments for a piece of code."""

    def __init__(self, llm_client):
        """Store the client.

        Args:
            llm_client: Object exposing ``chat.completions.create(...)``.
        """
        self.llm_client = llm_client

    async def generate_docstring(
        self,
        code: str,
        language: str
    ) -> Dict:
        """Generate documentation for ``code``.

        Returns:
            ``{"success": True, "docstring": <model answer>}``, or a dict with
            ``"success": False`` and an ``"error"`` message when the model
            call raises.
        """
        prompt = self._build_docstring_prompt(
            code,
            language
        )

        try:
            # NOTE(review): the call is not awaited, so it runs synchronously
            # inside this coroutine — confirm a blocking client is intended.
            completion = self.llm_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "system",
                        "content": f"你是一个专业的{language}文档工程师,擅长生成清晰、准确的文档"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=0.3,
                max_tokens=500
            )

            docstring = completion.choices[0].message.content

            return {
                "success": True,
                "docstring": docstring
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _build_docstring_prompt(
        self,
        code: str,
        language: str
    ) -> str:
        """Build the user prompt for the documentation style of ``language``.

        Python requests a Google-style docstring, JavaScript and TypeScript
        request JSDoc, and any other language requests generic documentation.
        The code is embedded in a fenced block and followed by the list of
        sections the documentation should contain.
        """
        if language == "python":
            heading = "请为以下Python代码生成docstring:"
            fence = "python"
            instruction = "请使用Google风格的docstring,包含:"
        elif language == "javascript" or language == "typescript":
            heading = f"请为以下{language}代码生成JSDoc文档:"
            fence = language
            instruction = "请使用标准的JSDoc格式,包含:"
        else:
            heading = f"请为以下{language}代码生成文档:"
            fence = language
            instruction = "请生成清晰、准确的文档,包含:"

        prompt = f"""{heading}

```{fence}
{code}
```

{instruction}
1. 函数/类描述
2. 参数说明
3. 返回值说明
4. 异常说明(如果有)
5. 示例(如果合适)"""

        return prompt
    

代码导航实现

定义导航

python
class CodeNavigator:
    """Jump from a cursor position to the definition of the symbol under it."""

    # (key in the parsed result, kind reported for a match); searched in order.
    _DEFINITION_GROUPS = (
        ("functions", "function"),
        ("classes", "class"),
        ("variables", "variable"),
    )

    def __init__(self, code_parser):
        """Store the parser; its ``parse`` coroutine is awaited per lookup."""
        self.code_parser = code_parser

    async def find_definition(
        self,
        code: str,
        language: str,
        position: Dict
    ) -> Dict:
        """Locate the definition of the symbol at ``position``.

        Args:
            code: Full source text.
            language: Language name passed to the parser.
            position: Dict with a 1-based ``line`` and a 0-based ``column``.

        Returns:
            ``{"success": True, "definition": {...}}``, or a dict with
            ``"success": False`` and an ``"error"`` message when parsing
            fails, no symbol is under the cursor, or no definition matches.
        """
        parsed_code = await self.code_parser.parse(code, language)

        if not parsed_code.get("success"):
            return {
                "success": False,
                "error": "代码解析失败"
            }

        symbol_name = await self._extract_symbol_at_position(
            code,
            position
        )

        if not symbol_name:
            return {
                "success": False,
                "error": "未找到符号"
            }

        definition = await self._find_symbol_definition(
            parsed_code["result"],
            symbol_name
        )

        if not definition:
            return {
                "success": False,
                "error": "未找到定义"
            }

        return {
            "success": True,
            "definition": definition
        }

    @staticmethod
    def _is_word_char(char: str) -> bool:
        """Return True for characters that can be part of an identifier."""
        # Underscores belong to identifiers; ``str.isalnum`` alone would
        # split ``my_var`` into ``my`` and ``var``.
        return char.isalnum() or char == '_'

    async def _extract_symbol_at_position(
        self,
        code: str,
        position: Dict
    ) -> Optional[str]:
        """Return the identifier touching the cursor, or None.

        None is returned when ``position['line']`` is outside the code or no
        identifier character is adjacent to the cursor. A column past the end
        of the line is treated as the end of the line.
        """
        lines = code.split('\n')
        line_no = position['line']

        # Line numbers are 1-based; 0 or negative values must not wrap around
        # to the end of the file through negative indexing.
        if not 1 <= line_no <= len(lines):
            return None

        line = lines[line_no - 1]
        column = min(max(position['column'], 0), len(line))

        start = column
        while start > 0 and self._is_word_char(line[start - 1]):
            start -= 1

        end = column
        while end < len(line) and self._is_word_char(line[end]):
            end += 1

        return line[start:end] or None

    async def _find_symbol_definition(
        self,
        parsed_code: Dict,
        symbol_name: str
    ) -> Optional[Dict]:
        """Return the first function, class or variable named ``symbol_name``.

        Functions are searched first, then classes, then variables. The
        result holds ``kind``, ``name``, ``line`` (the entry's ``lineno``)
        and ``column`` (always 0); None is returned when nothing matches.
        """
        for key, kind in self._DEFINITION_GROUPS:
            for entry in parsed_code.get(key, []):
                if entry["name"] == symbol_name:
                    return {
                        "kind": kind,
                        "name": entry["name"],
                        "line": entry["lineno"],
                        "column": 0
                    }

        return None

引用查找

python
class ReferenceFinder:
    """Find textual references to a symbol and classify each occurrence."""

    def __init__(self):
        """No state is needed; kept so existing construction code still works."""
        pass

    async def find_references(
        self,
        code: str,
        language: str,
        symbol_name: str
    ) -> Dict:
        """List every whole-word occurrence of ``symbol_name`` in ``code``.

        Occurrences embedded in a longer identifier (``foo`` inside
        ``foobar``) are not reported. An empty ``symbol_name`` yields no
        references.

        Returns:
            ``{"success": True, "references": [...]}`` where each reference
            has ``id``, a 1-based ``line``, a 0-based ``column``, ``symbol``
            and ``kind``.
        """
        references = []

        # An empty needle matches at every offset and would never let the
        # search position advance.
        if not symbol_name:
            return {
                "success": True,
                "references": references
            }

        width = len(symbol_name)

        for i, line in enumerate(code.split('\n')):
            start = line.find(symbol_name)

            while start != -1:
                end = start + width

                if self._is_whole_word(line, start, end):
                    references.append({
                        "id": f"ref_{i}_{start}",
                        "line": i + 1,
                        "column": start,
                        "symbol": symbol_name,
                        "kind": self._detect_reference_kind(
                            line,
                            start,
                            symbol_name,
                            language
                        )
                    })
                    next_start = end
                else:
                    # Step one character so a later candidate that overlaps
                    # this rejected match is still examined.
                    next_start = start + 1

                start = line.find(symbol_name, next_start)

        return {
            "success": True,
            "references": references
        }

    @staticmethod
    def _is_word_char(char: str) -> bool:
        """Return True for characters that can be part of an identifier."""
        return char.isalnum() or char == '_'

    def _is_whole_word(self, line: str, start: int, end: int) -> bool:
        """Return True when ``line[start:end]`` is not part of a longer word.

        A side is only checked when the match itself begins or ends with an
        identifier character on that side.
        """
        if (
            start > 0
            and self._is_word_char(line[start])
            and self._is_word_char(line[start - 1])
        ):
            return False

        if (
            end < len(line)
            and self._is_word_char(line[end - 1])
            and self._is_word_char(line[end])
        ):
            return False

        return True

    def _detect_reference_kind(
        self,
        line: str,
        start: int,
        symbol_name: str,
        language: str
    ) -> str:
        """Classify the occurrence at ``line[start]``.

        Returns "definition" when the word right before the symbol is ``def``
        or ``class``, "call" when it is followed by ``(``, "member" when it
        is followed by ``.``, and "reference" otherwise. ``language`` is not
        consulted.
        """
        # Compare whole tokens: stripping the prefix and then testing for a
        # trailing "def " could never match, because the strip removes the
        # very space being tested for.
        words_before = line[:start].split()
        line_after = line[start + len(symbol_name):].strip()

        if words_before and words_before[-1] in ("def", "class"):
            return "definition"
        elif line_after.startswith('('):
            return "call"
        elif line_after.startswith('.'):
            return "member"
        else:
            return "reference"

代码重构实现

代码重构器

python
class CodeRefactor:
    """Ask an LLM to refactor a piece of code in a chosen way."""

    # Instruction text sent to the model for each supported refactor type.
    _REFACTOR_INSTRUCTIONS = {
        "extract_function": "将重复代码提取为函数",
        "rename_variables": "重命名变量,使用更具描述性的名称",
        "simplify_conditions": "简化复杂的条件语句",
        "optimize_loops": "优化循环结构",
        "add_type_hints": "添加类型提示",
        "improve_readability": "提高代码可读性"
    }
    # Instruction used when the refactor type is not listed above.
    _DEFAULT_INSTRUCTION = "改进代码质量"

    def __init__(self, llm_client):
        """Store the client.

        Args:
            llm_client: Object exposing ``chat.completions.create(...)``.
        """
        self.llm_client = llm_client

    async def refactor(
        self,
        code: str,
        language: str,
        refactor_type: str
    ) -> Dict:
        """Refactor ``code`` according to ``refactor_type``.

        Returns:
            ``{"success": True, "refactored_code": <model answer>}``, or a
            dict with ``"success": False`` and an ``"error"`` message when
            the model call raises.
        """
        prompt = self._build_refactor_prompt(
            code,
            language,
            refactor_type
        )

        try:
            # NOTE(review): the call is not awaited, so it runs synchronously
            # inside this coroutine — confirm a blocking client is intended.
            completion = self.llm_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "system",
                        "content": f"你是一个专业的{language}代码重构工程师"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=0.3,
                max_tokens=1000
            )

            refactored_code = completion.choices[0].message.content

            return {
                "success": True,
                "refactored_code": refactored_code
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _build_refactor_prompt(
        self,
        code: str,
        language: str,
        refactor_type: str
    ) -> str:
        """Build the user prompt, embedding ``code`` in a fenced block.

        Unknown ``refactor_type`` values fall back to a generic "improve code
        quality" instruction. The code is always included: without it the
        model has nothing to refactor.
        """
        instruction = self._REFACTOR_INSTRUCTIONS.get(
            refactor_type,
            self._DEFAULT_INSTRUCTION
        )

        prompt = f"""请重构以下{language}代码,{instruction}

```{language}
{code}
```

请提供重构后的完整代码,并解释所做的改动。"""

        return prompt

实践练习

练习1:实现代码补全

python
def implement_code_completion():
    """Create the completers for exercise 1.

    Returns:
        A ``(SmartCodeCompleter, ContextAwareCompleter)`` tuple.
    """
    # NOTE(review): ``openai``, ``code_parser`` and ``semantic_analyzer`` are
    # not defined in this snippet — presumably supplied by the surrounding
    # module; confirm before running. The API key is a placeholder.
    client = openai.OpenAI(api_key="your-api-key")
    return (
        SmartCodeCompleter(client),
        ContextAwareCompleter(code_parser, semantic_analyzer),
    )

练习2:实现代码分析

python
def implement_code_analysis():
    """Create the analysis tools for exercise 2.

    Returns:
        A ``(StaticCodeAnalyzer, CodeQualityEvaluator)`` tuple.
    """
    return StaticCodeAnalyzer(), CodeQualityEvaluator()

练习3:实现代码生成

python
def implement_code_generation():
    """Create the generators for exercise 3, sharing one LLM client.

    Returns:
        A ``(TestGenerator, DocumentationGenerator)`` tuple.
    """
    # NOTE(review): ``openai`` is not imported in this snippet — presumably
    # imported by the surrounding module. The API key is a placeholder.
    client = openai.OpenAI(api_key="your-api-key")
    return TestGenerator(client), DocumentationGenerator(client)

总结

本节我们学习了代码助手系统的核心功能开发:

  1. 代码补全实现
  2. 代码分析实现
  3. 代码生成实现
  4. 代码导航实现
  5. 代码重构实现

核心功能是代码助手系统的基础,需要准确、高效地帮助开发者提高编码效率。

参考资源