Skip to content

第23天:MCP Client开发

学习目标

  • 掌握MCP Client的架构设计
  • 学会连接MCP Server
  • 学会调用MCP工具
  • 能够开发MCP Client

核心内容

1. MCP Client概述

MCP Client的定义

MCP Client是与MCP Server进行交互的客户端,负责发送请求、接收响应、处理工具调用等操作。MCP Client可以是AI模型(如Claude)、应用程序或其他系统。

MCP Client的职责

  1. 连接管理:建立和维护与MCP Server的连接
  2. 工具发现:获取Server提供的工具列表
  3. 工具调用:调用Server上的工具
  4. 响应处理:处理Server返回的响应
  5. 错误处理:处理通信和执行错误
  6. 会话管理:管理与Server的会话
  7. 认证管理:处理认证和授权
  8. 缓存管理:缓存工具信息和响应结果

2. Client架构设计

2.1 架构组件

MCP Client核心组件

  1. 连接管理器:管理与Server的连接
  2. 请求构建器:构建MCP请求
  3. 响应处理器:处理MCP响应
  4. 工具发现器:发现Server上的工具
  5. 工具调用器:调用Server上的工具
  6. 错误处理器:处理错误
  7. 认证管理器:处理认证和授权
  8. 缓存管理器:管理缓存
  9. 会话管理器:管理会话

2.2 架构设计

MCP Client架构

┌────────────────┐     ┌────────────────┐     ┌────────────────┐
│   Application  │────>│   MCP Client   │────>│   MCP Server   │
│   (AI Model)   │<────│                │<────│                │
└────────────────┘     └────────────────┘     └────────────────┘
        │                      │                      │
        ▼                      ▼                      ▼
┌────────────────┐     ┌────────────────┐     ┌────────────────┐
│   Connection   │     │   Request/     │     │   Tool/        │
│   Manager      │     │   Response     │     │   Resource     │
└────────────────┘     │   Handler      │     │   Manager      │
                       └────────────────┘     └────────────────┘

3. 连接管理

3.1 连接类型

MCP Client连接类型

  1. HTTP连接:基于HTTP/HTTPS的连接
  2. WebSocket连接:基于WebSocket的实时连接
  3. TCP连接:基于TCP的直接连接
  4. UDP连接:基于UDP的无连接通信

3.2 HTTP连接实现

HTTP连接管理

python
import httpx
from typing import Optional, Dict, Any
import time
import logging

class ConnectionManager:
    """连接管理器"""
    def __init__(self, base_url: str, timeout: int = 30, max_retries: int = 3):
        """初始化连接管理器"""
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.max_retries = max_retries
        self.client = None
        self.last_connected = None
        self.connection_attempts = 0
        self.logger = logging.getLogger(__name__)
    
    def connect(self) -> bool:
        """建立连接"""
        try:
            # 关闭现有连接
            if self.client:
                self.client.close()
            
            # 创建新的HTTP客户端
            self.client = httpx.Client(
                base_url=self.base_url,
                timeout=httpx.Timeout(self.timeout),
                follow_redirects=True
            )
            
            # 测试连接
            response = self.client.get("/health")
            response.raise_for_status()
            
            self.last_connected = time.time()
            self.connection_attempts = 0
            self.logger.info(f"Connected to MCP Server at {self.base_url}")
            return True
        except Exception as e:
            self.connection_attempts += 1
            self.logger.error(f"Failed to connect to MCP Server: {str(e)}")
            return False
    
    def disconnect(self):
        """关闭连接"""
        if self.client:
            self.client.close()
            self.client = None
            self.logger.info("Disconnected from MCP Server")
    
    def is_connected(self) -> bool:
        """检查连接是否活跃"""
        if not self.client:
            return False
        
        # 检查连接时间
        if self.last_connected and (time.time() - self.last_connected) > 3600:
            return False
        
        # 测试连接
        try:
            response = self.client.get("/health")
            return response.status_code == 200
        except:
            return False
    
    def ensure_connected(self) -> bool:
        """确保连接活跃"""
        if not self.is_connected():
            return self.connect()
        return True
    
    def get_client(self) -> Optional[httpx.Client]:
        """获取HTTP客户端"""
        if self.ensure_connected():
            return self.client
        return None
    
    def request(self, method: str, endpoint: str, **kwargs) -> Optional[httpx.Response]:
        """发送HTTP请求"""
        for attempt in range(self.max_retries):
            try:
                client = self.get_client()
                if not client:
                    continue
                
                url = f"{self.base_url}{endpoint}"
                response = client.request(method, url, **kwargs)
                self.last_connected = time.time()
                return response
            except Exception as e:
                self.logger.warning(f"Request failed (attempt {attempt+1}/{self.max_retries}): {str(e)}")
                if attempt < self.max_retries - 1:
                    time.sleep(1)  # 等待后重试
                else:
                    self.logger.error(f"Request failed after {self.max_retries} attempts")
        
        return None

4. 工具发现

4.1 工具发现流程

工具发现流程

  1. 发送工具列表请求:向Server发送mcp.list_tools请求
  2. 接收工具列表:接收Server返回的工具列表
  3. 解析工具信息:解析工具的名称、描述、参数等信息
  4. 缓存工具信息:缓存工具信息,避免重复请求
  5. 更新工具信息:定期更新工具信息

4.2 工具发现实现

工具发现

python
from typing import Dict, List, Optional
import json
import time

class ToolDiscoverer:
    """工具发现器"""
    def __init__(self, connection_manager: ConnectionManager, cache_ttl: int = 3600):
        """初始化工具发现器"""
        self.connection_manager = connection_manager
        self.cache_ttl = cache_ttl  # 缓存有效期(秒)
        self.tools_cache = None
        self.cache_timestamp = None
        self.logger = logging.getLogger(__name__)
    
    def discover_tools(self) -> Optional[List[Dict[str, Any]]]:
        """发现工具"""
        # 检查缓存
        if self._is_cache_valid():
            return self.tools_cache
        
        # 发送工具列表请求
        try:
            client = self.connection_manager.get_client()
            if not client:
                self.logger.error("Cannot connect to MCP Server")
                return None
            
            # 构建请求
            request_data = {
                "jsonrpc": "2.0",
                "method": "mcp.list_tools",
                "params": {},
                "id": 1
            }
            
            # 发送请求
            response = client.post(
                "/mcp",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )
            
            # 检查响应
            if response.status_code != 200:
                self.logger.error(f"Failed to discover tools: HTTP {response.status_code}")
                return None
            
            # 解析响应
            response_data = response.json()
            if "error" in response_data:
                self.logger.error(f"Failed to discover tools: {response_data['error']['message']}")
                return None
            
            tools = response_data.get("result", {}).get("tools", [])
            
            # 更新缓存
            self.tools_cache = tools
            self.cache_timestamp = time.time()
            
            self.logger.info(f"Discovered {len(tools)} tools from MCP Server")
            return tools
        except Exception as e:
            self.logger.error(f"Failed to discover tools: {str(e)}")
            return None
    
    def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """获取特定工具"""
        tools = self.discover_tools()
        if not tools:
            return None
        
        for tool in tools:
            if tool.get("name") == tool_name:
                return tool
        
        return None
    
    def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
        """根据分类获取工具"""
        tools = self.discover_tools()
        if not tools:
            return []
        
        category_tools = []
        for tool in tools:
            tool_name = tool.get("name")
            if tool_name and tool_name.startswith(f"{category}."):
                category_tools.append(tool)
        
        return category_tools
    
    def refresh_tools(self) -> Optional[List[Dict[str, Any]]]:
        """刷新工具列表"""
        # 清除缓存
        self.tools_cache = None
        self.cache_timestamp = None
        
        # 重新发现工具
        return self.discover_tools()
    
    def _is_cache_valid(self) -> bool:
        """检查缓存是否有效"""
        if self.tools_cache is None or self.cache_timestamp is None:
            return False
        
        return (time.time() - self.cache_timestamp) < self.cache_ttl

5. 工具调用

5.1 工具调用流程

工具调用流程

  1. 构建工具调用请求:构建mcp.call_tool请求
  2. 发送请求:向Server发送请求
  3. 接收响应:接收Server返回的响应
  4. 解析响应:解析响应结果
  5. 处理结果:处理工具执行结果
  6. 处理错误:处理执行错误

5.2 工具调用实现

工具调用

python
from typing import Dict, Any, Optional
import json
import time

class ToolInvoker:
    """工具调用器"""
    def __init__(self, connection_manager: ConnectionManager, tool_discoverer: ToolDiscoverer):
        """初始化工具调用器"""
        self.connection_manager = connection_manager
        self.tool_discoverer = tool_discoverer
        self.logger = logging.getLogger(__name__)
    
    def invoke_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """调用工具"""
        # 验证工具是否存在
        tool = self.tool_discoverer.get_tool(tool_name)
        if not tool:
            self.logger.error(f"Tool not found: {tool_name}")
            return None
        
        # 验证参数
        if not self._validate_arguments(tool, arguments):
            self.logger.error(f"Invalid arguments for tool: {tool_name}")
            return None
        
        # 构建请求
        request_data = {
            "jsonrpc": "2.0",
            "method": "mcp.call_tool",
            "params": {
                "tool_name": tool_name,
                "arguments": arguments
            },
            "id": int(time.time() * 1000)  # 使用时间戳作为ID
        }
        
        # 发送请求
        try:
            client = self.connection_manager.get_client()
            if not client:
                self.logger.error("Cannot connect to MCP Server")
                return None
            
            # 发送请求
            response = client.post(
                "/mcp",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )
            
            # 检查响应
            if response.status_code != 200:
                self.logger.error(f"Failed to invoke tool: HTTP {response.status_code}")
                return None
            
            # 解析响应
            response_data = response.json()
            if "error" in response_data:
                error_message = response_data["error"].get("message", "Unknown error")
                self.logger.error(f"Failed to invoke tool: {error_message}")
                return {
                    "success": False,
                    "error": error_message
                }
            
            # 提取结果
            result = response_data.get("result", {})
            
            # 处理结果
            tool_result = {
                "success": result.get("status") == "success",
                "output": result.get("output"),
                "error": result.get("error")
            }
            
            self.logger.info(f"Tool invoked successfully: {tool_name}")
            return tool_result
        except Exception as e:
            self.logger.error(f"Failed to invoke tool: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }
    
    def _validate_arguments(self, tool: Dict[str, Any], arguments: Dict[str, Any]) -> bool:
        """验证工具参数"""
        parameters = tool.get("parameters", [])
        
        # 检查必填参数
        for param in parameters:
            param_name = param.get("name")
            required = param.get("required", False)
            
            if required and param_name not in arguments:
                self.logger.error(f"Missing required parameter: {param_name}")
                return False
            
            # 检查参数类型(简单验证)
            param_type = param.get("type")
            if param_name in arguments:
                arg_value = arguments[param_name]
                if not self._validate_type(arg_value, param_type):
                    self.logger.error(f"Invalid type for parameter {param_name}: expected {param_type}")
                    return False
        
        return True
    
    def _validate_type(self, value: Any, expected_type: str) -> bool:
        """验证参数类型"""
        if expected_type == "string":
            return isinstance(value, str)
        elif expected_type == "number":
            return isinstance(value, (int, float))
        elif expected_type == "boolean":
            return isinstance(value, bool)
        elif expected_type == "integer":
            return isinstance(value, int)
        elif expected_type == "array":
            return isinstance(value, list)
        elif expected_type == "object":
            return isinstance(value, dict)
        elif expected_type == "null":
            return value is None
        
        # 未知类型,默认为有效
        return True
    
    def invoke_tool_async(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """异步调用工具"""
        # 这里可以实现真正的异步调用
        # 目前使用同步调用作为示例
        return self.invoke_tool(tool_name, arguments)

6. 响应处理

6.1 响应格式

MCP响应格式

json
{
  "jsonrpc": "2.0",
  "result": {
    "tool_name": "file.read",
    "output": "文件内容...",
    "status": "success",
    "error": null
  },
  "id": 1234567890
}

错误响应格式

json
{
  "jsonrpc": "2.0",
  "error": {
    "code": 400,
    "message": "参数错误"
  },
  "id": 1234567890
}

6.2 响应处理实现

响应处理

python
from typing import Dict, Any, Optional
import json

class ResponseHandler:
    """响应处理器"""
    def __init__(self):
        """初始化响应处理器"""
        self.logger = logging.getLogger(__name__)
    
    def handle_response(self, response: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """处理响应"""
        if not response:
            return {
                "success": False,
                "error": "No response received"
            }
        
        # 处理错误响应
        if "error" in response:
            error_data = response["error"]
            error_message = error_data.get("message", "Unknown error")
            error_code = error_data.get("code", 500)
            
            self.logger.error(f"Response error (code {error_code}): {error_message}")
            return {
                "success": False,
                "error": error_message,
                "error_code": error_code
            }
        
        # 处理成功响应
        if "result" in response:
            result = response["result"]
            
            # 处理工具调用结果
            if "tool_name" in result:
                return self._handle_tool_result(result)
            
            # 处理工具列表结果
            if "tools" in result:
                return self._handle_tool_list(result)
            
            # 处理资源相关结果
            if "resources" in result or "resource" in result:
                return self._handle_resource_result(result)
            
            # 处理其他结果
            return {
                "success": True,
                "data": result
            }
        
        # 未知响应格式
        self.logger.warning("Unknown response format")
        return {
            "success": False,
            "error": "Unknown response format"
        }
    
    def _handle_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """处理工具调用结果"""
        status = result.get("status", "error")
        tool_name = result.get("tool_name", "unknown")
        output = result.get("output")
        error = result.get("error")
        
        if status == "success":
            self.logger.info(f"Tool {tool_name} executed successfully")
            return {
                "success": True,
                "tool_name": tool_name,
                "output": output
            }
        else:
            self.logger.error(f"Tool {tool_name} failed: {error}")
            return {
                "success": False,
                "tool_name": tool_name,
                "error": error
            }
    
    def _handle_tool_list(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """处理工具列表结果"""
        tools = result.get("tools", [])
        tool_count = len(tools)
        
        self.logger.info(f"Received {tool_count} tools")
        return {
            "success": True,
            "tools": tools,
            "tool_count": tool_count
        }
    
    def _handle_resource_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """处理资源相关结果"""
        if "resources" in result:
            resources = result.get("resources", [])
            resource_count = len(resources)
            
            self.logger.info(f"Received {resource_count} resources")
            return {
                "success": True,
                "resources": resources,
                "resource_count": resource_count
            }
        elif "resource" in result:
            resource = result.get("resource")
            
            self.logger.info(f"Received resource: {resource.get('id', 'unknown')}")
            return {
                "success": True,
                "resource": resource
            }
        
        self.logger.warning("Unknown resource result format")
        return {
            "success": False,
            "error": "Unknown resource result format"
        }
    
    def parse_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """解析响应文本"""
        try:
            return json.loads(response_text)
        except json.JSONDecodeError as e:
            self.logger.error(f"Failed to parse response: {str(e)}")
            return None

7. 错误处理

7.1 错误类型

MCP Client常见错误

  1. 连接错误:无法连接到Server
  2. 认证错误:认证失败
  3. 授权错误:授权失败
  4. 请求错误:请求格式错误
  5. 响应错误:响应格式错误
  6. 工具错误:工具执行错误
  7. 超时错误:请求超时
  8. 网络错误:网络故障

7.2 错误处理实现

错误处理

python
from typing import Dict, Any, Optional
import logging
import traceback

class ErrorHandler:
    """错误处理器"""
    def __init__(self):
        """初始化错误处理器"""
        self.logger = logging.getLogger(__name__)
        self.error_count = 0
        self.retry_count = 0
    
    def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
        """处理错误"""
        self.error_count += 1
        error_message = str(error)
        error_type = type(error).__name__
        
        # 记录错误
        self.logger.error(f"Error in {context}: {error_message}")
        self.logger.debug(traceback.format_exc())
        
        # 分类错误
        error_info = {
            "error": error_message,
            "error_type": error_type,
            "context": context,
            "timestamp": self._get_timestamp()
        }
        
        # 处理特定类型的错误
        if isinstance(error, ConnectionError):
            error_info["error_category"] = "connection"
            error_info["recommendation"] = "Check network connection and server status"
        elif isinstance(error, TimeoutError):
            error_info["error_category"] = "timeout"
            error_info["recommendation"] = "Increase timeout or check server performance"
        elif isinstance(error, AuthenticationError):
            error_info["error_category"] = "authentication"
            error_info["recommendation"] = "Check authentication credentials"
        elif isinstance(error, AuthorizationError):
            error_info["error_category"] = "authorization"
            error_info["recommendation"] = "Check permissions"
        else:
            error_info["error_category"] = "unknown"
            error_info["recommendation"] = "Contact server administrator"
        
        return error_info
    
    def handle_api_error(self, error_response: Dict[str, Any], context: str = "") -> Dict[str, Any]:
        """处理API错误"""
        self.error_count += 1
        
        error_data = error_response.get("error", {})
        error_code = error_data.get("code", 500)
        error_message = error_data.get("message", "Unknown API error")
        
        # 记录错误
        self.logger.error(f"API error in {context} (code {error_code}): {error_message}")
        
        error_info = {
            "error": error_message,
            "error_code": error_code,
            "context": context,
            "timestamp": self._get_timestamp(),
            "error_category": "api"
        }
        
        # 处理特定错误代码
        if error_code == 400:
            error_info["recommendation"] = "Check request parameters"
        elif error_code == 401:
            error_info["recommendation"] = "Check authentication credentials"
        elif error_code == 403:
            error_info["recommendation"] = "Check permissions"
        elif error_code == 404:
            error_info["recommendation"] = "Check resource existence"
        elif error_code == 500:
            error_info["recommendation"] = "Contact server administrator"
        elif error_code == 503:
            error_info["recommendation"] = "Server is temporarily unavailable"
        else:
            error_info["recommendation"] = "Check request and server status"
        
        return error_info
    
    def should_retry(self, error_info: Dict[str, Any]) -> bool:
        """判断是否应该重试"""
        error_category = error_info.get("error_category")
        error_code = error_info.get("error_code")
        
        # 可以重试的错误类型
        retryable_categories = ["connection", "timeout", "api"]
        retryable_codes = [408, 429, 500, 502, 503, 504]
        
        if error_category in retryable_categories:
            return True
        
        if error_code in retryable_codes:
            return True
        
        return False
    
    def get_retry_delay(self, attempt: int) -> float:
        """获取重试延迟"""
        # 指数退避策略
        base_delay = 1.0  # 基础延迟(秒)
        max_delay = 30.0  # 最大延迟(秒)
        
        delay = min(base_delay * (2 ** attempt), max_delay)
        self.retry_count += 1
        
        return delay
    
    def get_error_stats(self) -> Dict[str, Any]:
        """获取错误统计信息"""
        return {
            "total_errors": self.error_count,
            "total_retries": self.retry_count
        }
    
    def _get_timestamp(self) -> str:
        """获取当前时间戳"""
        import datetime
        return datetime.datetime.utcnow().isoformat()

class AuthenticationError(Exception):
    """认证错误"""
    pass

class AuthorizationError(Exception):
    """授权错误"""
    pass

class MCPError(Exception):
    """MCP错误"""
    def __init__(self, message: str, error_code: int = 500):
        self.message = message
        self.error_code = error_code
        super().__init__(self.message)

8. 认证管理

8.1 认证方式

MCP Client认证方式

  1. API Key:通过HTTP头或请求参数传递API Key
  2. JWT:通过Bearer令牌认证
  3. OAuth2:通过OAuth2令牌认证
  4. 基本认证:通过用户名和密码认证
  5. 自定义认证:自定义认证方式

8.2 认证管理实现

认证管理

python
from typing import Dict, Any, Optional
import time
import jwt
from jose import JWTError

class AuthManager:
    """认证管理器"""
    def __init__(self, auth_type: str = "api_key", **kwargs):
        """初始化认证管理器"""
        self.auth_type = auth_type
        self.auth_config = kwargs
        self.token_cache = None
        self.token_expiry = None
        self.logger = logging.getLogger(__name__)
    
    def get_headers(self) -> Dict[str, str]:
        """获取认证头"""
        if self.auth_type == "api_key":
            return self._get_api_key_headers()
        elif self.auth_type == "jwt":
            return self._get_jwt_headers()
        elif self.auth_type == "oauth2":
            return self._get_oauth2_headers()
        elif self.auth_type == "basic":
            return self._get_basic_headers()
        else:
            self.logger.warning(f"Unknown auth type: {self.auth_type}")
            return {}
    
    def _get_api_key_headers(self) -> Dict[str, str]:
        """获取API Key认证头"""
        api_key = self.auth_config.get("api_key")
        header_name = self.auth_config.get("header_name", "X-API-Key")
        
        if api_key:
            return {header_name: api_key}
        return {}
    
    def _get_jwt_headers(self) -> Dict[str, str]:
        """获取JWT认证头"""
        token = self._get_jwt_token()
        if token:
            return {"Authorization": f"Bearer {token}"}
        return {}
    
    def _get_oauth2_headers(self) -> Dict[str, str]:
        """获取OAuth2认证头"""
        token = self._get_oauth2_token()
        if token:
            return {"Authorization": f"Bearer {token}"}
        return {}
    
    def _get_basic_headers(self) -> Dict[str, str]:
        """获取基本认证头"""
        username = self.auth_config.get("username")
        password = self.auth_config.get("password")
        
        if username and password:
            import base64
            credentials = f"{username}:{password}"
            encoded_credentials = base64.b64encode(credentials.encode()).decode()
            return {"Authorization": f"Basic {encoded_credentials}"}
        return {}
    
    def _get_jwt_token(self) -> Optional[str]:
        """获取JWT令牌"""
        # 检查缓存
        if self._is_token_valid():
            return self.token_cache
        
        # 生成新令牌
        if "secret" in self.auth_config:
            token = self._generate_jwt_token()
            if token:
                self.token_cache = token
                self.token_expiry = time.time() + self.auth_config.get("expiry", 3600)
                return token
        
        # 使用配置的令牌
        return self.auth_config.get("token")
    
    def _generate_jwt_token(self) -> Optional[str]:
        """生成JWT令牌"""
        try:
            secret = self.auth_config.get("secret")
            algorithm = self.auth_config.get("algorithm", "HS256")
            expiry = self.auth_config.get("expiry", 3600)
            payload = self.auth_config.get("payload", {})
            
            # 添加过期时间
            payload["exp"] = time.time() + expiry
            payload["iat"] = time.time()
            
            token = jwt.encode(payload, secret, algorithm=algorithm)
            return token
        except Exception as e:
            self.logger.error(f"Failed to generate JWT token: {str(e)}")
            return None
    
    def _get_oauth2_token(self) -> Optional[str]:
        """获取OAuth2令牌"""
        # 检查缓存
        if self._is_token_valid():
            return self.token_cache
        
        # 这里可以实现OAuth2令牌获取逻辑
        # 目前使用配置的令牌
        return self.auth_config.get("token")
    
    def _is_token_valid(self) -> bool:
        """检查令牌是否有效"""
        if self.token_cache and self.token_expiry:
            return time.time() < self.token_expiry
        return False
    
    def refresh_token(self) -> bool:
        """刷新令牌"""
        try:
            self.token_cache = None
            self.token_expiry = None
            
            # 重新获取令牌
            headers = self.get_headers()
            return "Authorization" in headers
        except Exception as e:
            self.logger.error(f"Failed to refresh token: {str(e)}")
            return False
    
    def validate_auth(self) -> bool:
        """验证认证配置"""
        if self.auth_type == "api_key":
            return "api_key" in self.auth_config
        elif self.auth_type == "jwt":
            return "token" in self.auth_config or "secret" in self.auth_config
        elif self.auth_type == "oauth2":
            return "token" in self.auth_config
        elif self.auth_type == "basic":
            return "username" in self.auth_config and "password" in self.auth_config
        else:
            return True

9. 会话管理

9.1 会话概念

MCP会话

会话是Client与Server之间的交互上下文,包含认证信息、状态信息、历史记录等。会话管理确保Client与Server之间的交互能够正确进行,特别是在需要保持状态的场景中。

会话的重要性

  1. 状态保持:保持交互状态
  2. 认证持久化:保持认证信息
  3. 上下文管理:管理交互上下文
  4. 历史记录:记录交互历史
  5. 性能优化:减少重复认证和初始化

9.2 会话管理实现

会话管理

python
from typing import Dict, Any, Optional
import uuid
import time
import json

class SessionManager:
    """会话管理器"""
    def __init__(self, session_timeout: int = 3600):
        """初始化会话管理器"""
        self.session_timeout = session_timeout
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.logger = logging.getLogger(__name__)
    
    def create_session(self, initial_data: Dict[str, Any] = None) -> str:
        """创建会话"""
        session_id = str(uuid.uuid4())
        session_data = {
            "id": session_id,
            "created_at": time.time(),
            "last_accessed": time.time(),
            "data": initial_data or {},
            "tool_calls": [],
            "responses": []
        }
        
        self.sessions[session_id] = session_data
        self.logger.info(f"Created session: {session_id}")
        return session_id
    
    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """获取会话"""
        if session_id not in self.sessions:
            self.logger.warning(f"Session not found: {session_id}")
            return None
        
        session = self.sessions[session_id]
        
        # 检查会话是否过期
        if self._is_session_expired(session):
            self.delete_session(session_id)
            self.logger.warning(f"Session expired: {session_id}")
            return None
        
        # 更新最后访问时间
        session["last_accessed"] = time.time()
        self.sessions[session_id] = session
        
        return session
    
    def update_session(self, session_id: str, data: Dict[str, Any]) -> bool:
        """更新会话"""
        session = self.get_session(session_id)
        if not session:
            return False
        
        session["data"].update(data)
        session["last_accessed"] = time.time()
        self.sessions[session_id] = session
        
        self.logger.debug(f"Updated session: {session_id}")
        return True
    
    def delete_session(self, session_id: str) -> bool:
        """删除会话"""
        if session_id in self.sessions:
            del self.sessions[session_id]
            self.logger.info(f"Deleted session: {session_id}")
            return True
        return False
    
    def record_tool_call(self, session_id: str, tool_name: str, arguments: Dict[str, Any], result: Dict[str, Any]):
        """记录工具调用"""
        session = self.get_session(session_id)
        if not session:
            return
        
        tool_call_record = {
            "tool_name": tool_name,
            "arguments": arguments,
            "result": result,
            "timestamp": time.time()
        }
        
        session["tool_calls"].append(tool_call_record)
        
        # 限制记录数量
        if len(session["tool_calls"]) > 100:
            session["tool_calls"] = session["tool_calls"][-100:]
        
        self.sessions[session_id] = session
    
    def get_tool_calls(self, session_id: str) -> List[Dict[str, Any]]:
        """获取工具调用记录"""
        session = self.get_session(session_id)
        if not session:
            return []
        
        return session.get("tool_calls", [])
    
    def clear_expired_sessions(self):
        """清理过期会话"""
        expired_sessions = []
        
        for session_id, session in self.sessions.items():
            if self._is_session_expired(session):
                expired_sessions.append(session_id)
        
        for session_id in expired_sessions:
            self.delete_session(session_id)
        
        if expired_sessions:
            self.logger.info(f"Cleared {len(expired_sessions)} expired sessions")
    
    def get_active_sessions(self) -> List[str]:
        """获取活跃会话列表"""
        active_sessions = []
        
        for session_id, session in self.sessions.items():
            if not self._is_session_expired(session):
                active_sessions.append(session_id)
        
        return active_sessions
    
    def _is_session_expired(self, session: Dict[str, Any]) -> bool:
        """检查会话是否过期"""
        last_accessed = session.get("last_accessed", 0)
        return (time.time() - last_accessed) > self.session_timeout
    
    def save_session(self, session_id: str, file_path: str):
        """保存会话到文件"""
        session = self.get_session(session_id)
        if not session:
            return False
        
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(session, f, indent=2, ensure_ascii=False)
            self.logger.info(f"Saved session {session_id} to {file_path}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to save session: {str(e)}")
            return False
    
    def load_session(self, file_path: str) -> Optional[str]:
        """从文件加载会话"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                session_data = json.load(f)
            
            # 创建新会话
            session_id = self.create_session(session_data.get("data", {}))
            
            # 恢复工具调用记录
            tool_calls = session_data.get("tool_calls", [])
            session = self.get_session(session_id)
            if session:
                session["tool_calls"] = tool_calls
                self.sessions[session_id] = session
            
            self.logger.info(f"Loaded session from {file_path}")
            return session_id
        except Exception as e:
            self.logger.error(f"Failed to load session: {str(e)}")
            return None

10. 缓存管理

10.1 缓存策略

MCP Client缓存策略

  1. 工具缓存:缓存工具列表和工具信息
  2. 响应缓存:缓存工具调用的响应结果
  3. 会话缓存:缓存会话信息
  4. 认证缓存:缓存认证令牌

缓存的好处

  1. 减少请求:减少对Server的请求
  2. 提高性能:快速响应,减少延迟
  3. 节省带宽:减少网络传输
  4. 容错性:在网络故障时提供备用数据

10.2 缓存管理实现

缓存管理

python
from typing import Dict, Any, Optional
import time
import hashlib

class CacheManager:
    """缓存管理器"""
    def __init__(self, cache_ttl: int = 3600):
        """初始化缓存管理器"""
        self.cache_ttl = cache_ttl
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.logger = logging.getLogger(__name__)
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        """设置缓存"""
        ttl = ttl or self.cache_ttl
        timestamp = time.time()
        
        self.cache[key] = {
            "value": value,
            "timestamp": timestamp,
            "ttl": ttl
        }
        
        self.logger.debug(f"Set cache: {key} (ttl: {ttl}s)")
    
    def get(self, key: str) -> Optional[Any]:
        """获取缓存"""
        if key not in self.cache:
            return None
        
        cache_item = self.cache[key]
        timestamp = cache_item.get("timestamp", 0)
        ttl = cache_item.get("ttl", self.cache_ttl)
        
        # 检查缓存是否过期
        if (time.time() - timestamp) > ttl:
            self.delete(key)
            return None
        
        self.logger.debug(f"Get cache: {key}")
        return cache_item.get("value")
    
    def delete(self, key: str):
        """删除缓存"""
        if key in self.cache:
            del self.cache[key]
            self.logger.debug(f"Delete cache: {key}")
    
    def clear(self):
        """清空缓存"""
        self.cache.clear()
        self.logger.info("Clear all cache")
    
    def get_cache_key(self, prefix: str, **kwargs) -> str:
        """生成缓存键"""
        # 对参数进行排序,确保相同参数生成相同的键
        sorted_kwargs = sorted(kwargs.items())
        kwargs_str = "_".join([f"{k}={v}" for k, v in sorted_kwargs])
        
        # 使用MD5对参数进行哈希,避免键过长
        hash_obj = hashlib.md5(kwargs_str.encode())
        hash_str = hash_obj.hexdigest()
        
        return f"{prefix}:{hash_str}"
    
    def set_tool_list(self, tools: List[Dict[str, Any]], ttl: Optional[int] = None):
        """设置工具列表缓存"""
        key = self.get_cache_key("tools", action="list")
        self.set(key, tools, ttl)
    
    def get_tool_list(self) -> Optional[List[Dict[str, Any]]]:
        """获取工具列表缓存"""
        key = self.get_cache_key("tools", action="list")
        return self.get(key)
    
    def set_tool_info(self, tool_name: str, tool_info: Dict[str, Any], ttl: Optional[int] = None):
        """设置工具信息缓存"""
        key = self.get_cache_key("tools", action="info", name=tool_name)
        self.set(key, tool_info, ttl)
    
    def get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """获取工具信息缓存"""
        key = self.get_cache_key("tools", action="info", name=tool_name)
        return self.get(key)
    
    def set_tool_result(self, tool_name: str, arguments: Dict[str, Any], result: Dict[str, Any], ttl: Optional[int] = None):
        """设置工具调用结果缓存"""
        key = self.get_cache_key("tools", action="result", name=tool_name, **arguments)
        self.set(key, result, ttl)
    
    def get_tool_result(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """获取工具调用结果缓存"""
        key = self.get_cache_key("tools", action="result", name=tool_name, **arguments)
        return self.get(key)
    
    def clear_tool_cache(self):
        """清空工具相关缓存"""
        tools_keys = [key for key in self.cache if key.startswith("tools:")]
        for key in tools_keys:
            self.delete(key)
        self.logger.info("Clear tool cache")
    
    def get_cache_stats(self) -> Dict[str, Any]:
        """获取缓存统计信息"""
        total_items = len(self.cache)
        expired_items = 0
        tool_items = 0
        
        for key, item in self.cache.items():
            if key.startswith("tools:"):
                tool_items += 1
            
            timestamp = item.get("timestamp", 0)
            ttl = item.get("ttl", self.cache_ttl)
            if (time.time() - timestamp) > ttl:
                expired_items += 1
        
        return {
            "total_items": total_items,
            "expired_items": expired_items,
            "tool_items": tool_items
        }

11. 实战:开发完整的MCP Client

11.1 Client架构

MCP Client完整架构

python
from typing import Dict, Any, Optional, List
import logging

class MCPClient:
    """MCP Client主类"""
    def __init__(self, base_url: str, auth_type: str = "api_key", **kwargs):
        """初始化MCP Client"""
        # 配置日志
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger("MCPClient")
        
        # 初始化组件
        from app.client.connection import ConnectionManager
        from app.client.tool_discoverer import ToolDiscoverer
        from app.client.tool_invoker import ToolInvoker
        from app.client.response_handler import ResponseHandler
        from app.client.error_handler import ErrorHandler
        from app.client.auth_manager import AuthManager
        from app.client.session_manager import SessionManager
        from app.client.cache_manager import CacheManager
        
        # 初始化连接管理器
        self.connection_manager = ConnectionManager(
            base_url=base_url,
            timeout=kwargs.get("timeout", 30),
            max_retries=kwargs.get("max_retries", 3)
        )
        
        # 初始化认证管理器
        self.auth_manager = AuthManager(
            auth_type=auth_type,
            **kwargs.get("auth_config", {})
        )
        
        # 初始化缓存管理器
        self.cache_manager = CacheManager(
            cache_ttl=kwargs.get("cache_ttl", 3600)
        )
        
        # 初始化工具发现器
        self.tool_discoverer = ToolDiscoverer(
            connection_manager=self.connection_manager
        )
        
        # 初始化工具调用器
        self.tool_invoker = ToolInvoker(
            connection_manager=self.connection_manager,
            tool_discoverer=self.tool_discoverer
        )
        
        # 初始化响应处理器
        self.response_handler = ResponseHandler()
        
        # 初始化错误处理器
        self.error_handler = ErrorHandler()
        
        # 初始化会话管理器
        self.session_manager = SessionManager(
            session_timeout=kwargs.get("session_timeout", 3600)
        )
        
        # 验证配置
        if not self.auth_manager.validate_auth():
            self.logger.warning("Authentication configuration may be incomplete")
        
        self.logger.info(f"MCP Client initialized for {base_url}")
    
    def list_tools(self) -> Optional[List[Dict[str, Any]]]:
        """获取工具列表"""
        # 检查缓存
        cached_tools = self.cache_manager.get_tool_list()
        if cached_tools:
            self.logger.info("Using cached tool list")
            return cached_tools
        
        # 发现工具
        tools = self.tool_discoverer.discover_tools()
        
        # 缓存工具列表
        if tools:
            self.cache_manager.set_tool_list(tools)
        
        return tools
    
    def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """获取工具信息"""
        # 检查缓存
        cached_tool = self.cache_manager.get_tool_info(tool_name)
        if cached_tool:
            self.logger.info(f"Using cached tool info for {tool_name}")
            return cached_tool
        
        # 发现工具
        tool = self.tool_discoverer.get_tool(tool_name)
        
        # 缓存工具信息
        if tool:
            self.cache_manager.set_tool_info(tool_name, tool)
        
        return tool
    
    def call_tool(self, tool_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """调用工具"""
        # 检查缓存
        cached_result = self.cache_manager.get_tool_result(tool_name, arguments)
        if cached_result:
            self.logger.info(f"Using cached result for tool {tool_name}")
            return cached_result
        
        # 调用工具
        result = self.tool_invoker.invoke_tool(tool_name, arguments)
        
        # 缓存结果
        if result:
            self.cache_manager.set_tool_result(tool_name, arguments, result)
        
        # 记录工具调用(如果有会话)
        if session_id:
            self.session_manager.record_tool_call(session_id, tool_name, arguments, result)
        
        return result
    
    def create_session(self, initial_data: Dict[str, Any] = None) -> str:
        """创建会话"""
        return self.session_manager.create_session(initial_data)
    
    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """获取会话"""
        return self.session_manager.get_session(session_id)
    
    def update_session(self, session_id: str, data: Dict[str, Any]) -> bool:
        """更新会话"""
        return self.session_manager.update_session(session_id, data)
    
    def delete_session(self, session_id: str) -> bool:
        """删除会话"""
        return self.session_manager.delete_session(session_id)
    
    def clear_cache(self):
        """清空缓存"""
        self.cache_manager.clear()
        self.logger.info("Cache cleared")
    
    def clear_tool_cache(self):
        """清空工具缓存"""
        self.cache_manager.clear_tool_cache()
        self.logger.info("Tool cache cleared")
    
    def clear_expired_sessions(self):
        """清理过期会话"""
        self.session_manager.clear_expired_sessions()
    
    def get_stats(self) -> Dict[str, Any]:
        """获取统计信息"""
        return {
            "cache": self.cache_manager.get_cache_stats(),
            "errors": self.error_handler.get_error_stats(),
            "active_sessions": len(self.session_manager.get_active_sessions())
        }
    
    def close(self):
        """关闭Client"""
        self.connection_manager.disconnect()
        self.clear_cache()
        self.clear_expired_sessions()
        self.logger.info("MCP Client closed")
    
    def __enter__(self):
        """进入上下文管理器"""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """退出上下文管理器"""
        self.close()

11.2 Client使用示例

MCP Client使用

python
# 使用示例
from app.client.mcp_client import MCPClient

# 初始化Client
client = MCPClient(
    base_url="http://localhost:8000",
    auth_type="api_key",
    auth_config={
        "api_key": "your-secret-api-key"
    }
)

try:
    # 获取工具列表
    tools = client.list_tools()
    print(f"Available tools: {[tool['name'] for tool in tools]}")
    
    # 调用工具
    result = client.call_tool(
        "file.read",
        {
            "file_path": "/tmp/test.txt",
            "encoding": "utf-8"
        }
    )
    print(f"Tool result: {result}")
    
    # 创建会话
    session_id = client.create_session({"user": "test"})
    print(f"Created session: {session_id}")
    
    # 在会话中调用工具
    session_result = client.call_tool(
        "file.read",
        {
            "file_path": "/tmp/test.txt"
        },
        session_id=session_id
    )
    print(f"Session tool result: {session_result}")
    
    # 获取会话信息
    session = client.get_session(session_id)
    print(f"Session info: {session}")
    
    # 获取统计信息
    stats = client.get_stats()
    print(f"Client stats: {stats}")
    
finally:
    # 关闭Client
    client.close()

# 使用上下文管理器
with MCPClient(
    base_url="http://localhost:8000",
    auth_type="api_key",
    auth_config={"api_key": "your-secret-api-key"}
) as client:
    tools = client.list_tools()
    print(f"Tools count: {len(tools)}")
    
    # 调用工具
    result = client.call_tool(
        "file.write",
        {
            "file_path": "/tmp/test.txt",
            "content": "Hello, MCP!",
            "encoding": "utf-8"
        }
    )
    print(f"Write result: {result}")

12. 技术选型建议

12.1 Client技术栈

推荐技术栈

组件技术推荐理由
语言Python 3.8+简洁易读,生态丰富
HTTP客户端httpx现代HTTP客户端,支持异步
异步支持asyncio支持异步操作
JSON处理json标准库,性能好
认证python-joseJWT支持
缓存内存缓存简单高效
日志logging标准库,功能强大
测试pytest灵活的测试框架

12.2 替代方案

替代技术

技术适用场景优势
Node.js前端集成与前端技术栈一致
Go高性能场景编译型语言,性能优异
Java企业级应用成熟稳定,生态丰富
C#Windows环境与.NET集成良好
Rust安全场景内存安全,性能优异

13. 常见问题与解决方案

13.1 连接问题

问题:无法连接到MCP Server。

解决方案

  • 检查Server地址是否正确
  • 检查网络连接
  • 检查Server是否运行
  • 检查防火墙设置
  • 检查认证信息

13.2 认证问题

问题:认证失败。

解决方案

  • 检查API Key是否正确
  • 检查JWT令牌是否有效
  • 检查OAuth2配置是否正确
  • 检查用户名和密码是否正确
  • 检查权限设置

13.3 工具调用问题

问题:工具调用失败。

解决方案

  • 检查工具名称是否正确
  • 检查参数是否正确
  • 检查权限是否足够
  • 检查Server日志
  • 检查网络连接

13.4 性能问题

问题:Client响应缓慢。

解决方案

  • 启用缓存
  • 使用异步操作
  • 优化网络设置
  • 减少请求频率
  • 增加超时时间

14. 学习资源

14.1 官方文档

14.2 在线资源

14.3 推荐书籍

  • Python Cookbook:by David Beazley & Brian K. Jones
  • Fluent Python:by Luciano Ramalho
  • HTTP: The Definitive Guide:by David Gourley & Brian Totty
  • Network Programming with Python:by Brandon Rhodes & John Goerzen

15. 总结

本课程深入介绍了MCP Client的开发基础,包括架构设计、连接管理、工具发现、工具调用、响应处理、错误处理、认证管理、会话管理和缓存管理等核心内容。通过本课程的学习,你应该能够:

  • 理解MCP Client的架构和职责
  • 开发完整的MCP Client
  • 与MCP Server进行交互
  • 调用Server上的工具
  • 处理错误和异常
  • 管理会话和缓存

在后续课程中,我们将学习MCP与Claude Desktop集成、MCP生态系统等内容,进一步完善你的MCP开发技能。


课后作业

  1. 实践题

    • 开发完整的MCP Client
    • 实现与MCP Server的交互
    • 测试工具调用功能
    • 实现会话管理和缓存
  2. 思考题

    • 如何优化MCP Client的性能?
    • 如何提高MCP Client的可靠性?
    • 如何处理大规模工具调用?
    • 如何实现MCP Client的负载均衡?

架构师AI杜公众号二维码

扫描二维码关注"架构师AI杜"公众号,获取更多技术内容和最新动态