第69集:工厂模式
学习目标
- 理解工厂模式的概念和作用
- 掌握简单工厂、工厂方法和抽象工厂三种实现
- 学会根据不同场景选择合适的工厂模式
- 了解工厂模式与依赖注入的关系
- 能够在实际项目中应用工厂模式解耦对象创建
工厂模式概念
什么是工厂模式
工厂模式(Factory Pattern)是一种创建型设计模式,它提供了一种创建对象的接口,但让子类决定要实例化的类是哪一个。工厂模式将对象的创建与使用分离,提高了代码的灵活性和可扩展性。
为什么要使用工厂模式
- 解耦创建与使用:客户端不需要知道具体类名,只需知道对应的参数
- 提高扩展性:新增产品时,不需要修改现有客户端代码
- 统一管理创建逻辑:集中处理对象的创建、初始化、配置等逻辑
- 支持多态:客户端面向抽象编程,运行时决定具体实现
- 便于测试:可以通过工厂模拟对象创建,便于单元测试
工厂模式的三种形式
- 简单工厂模式:一个工厂类根据参数创建不同产品
- 工厂方法模式:定义创建产品的接口,由子类决定创建哪种产品
- 抽象工厂模式:创建相关或依赖对象的家族,而不需要指定具体类
简单工厂模式
基本概念
简单工厂模式由一个工厂类根据传入的参数,动态决定应该创建哪一个产品类的实例。它是最基本的工厂模式,但违反了开闭原则。
实现示例
from abc import ABC, abstractmethod
from typing import Dict, Any
# 产品接口
class Payment(ABC):
"""支付接口 - 抽象产品"""
@abstractmethod
def pay(self, amount: float) -> str:
"""支付方法"""
pass
@abstractmethod
def get_name(self) -> str:
"""获取支付方式名称"""
pass
# 具体产品类
class Alipay(Payment):
"""支付宝支付"""
def pay(self, amount: float) -> str:
return f"使用支付宝成功支付 ¥{amount:.2f}"
def get_name(self) -> str:
return "支付宝"
class WechatPay(Payment):
"""微信支付"""
def pay(self, amount: float) -> str:
return f"使用微信支付成功支付 ¥{amount:.2f}"
def get_name(self) -> str:
return "微信支付"
class BankCardPay(Payment):
"""银行卡支付"""
def __init__(self, bank_name: str = "工商银行"):
self.bank_name = bank_name
def pay(self, amount: float) -> str:
return f"使用{self.bank_name}银行卡成功支付 ¥{amount:.2f}"
def get_name(self) -> str:
return f"银行卡({self.bank_name})"
# 简单工厂类
class PaymentFactory:
"""支付工厂 - 简单工厂模式"""
# 支持的支付方式映射
_payment_types = {
"alipay": Alipay,
"wechat": WechatPay,
"bank": BankCardPay
}
@classmethod
def create_payment(cls, payment_type: str, **kwargs) -> Payment:
"""创建支付对象
Args:
payment_type: 支付类型(alipay, wechat, bank)
**kwargs: 额外参数,如银行名称等
Returns:
Payment: 支付对象实例
Raises:
ValueError: 当支付类型不支持时
"""
# 标准化支付类型
payment_type = payment_type.lower().strip()
if payment_type not in cls._payment_types:
supported_types = ", ".join(cls._payment_types.keys())
raise ValueError(f"不支持的支付类型: {payment_type}。支持的类型: {supported_types}")
# 获取对应的支付类
payment_class = cls._payment_types[payment_type]
# 创建实例(处理特殊参数)
if payment_type == "bank" and "bank_name" in kwargs:
return payment_class(kwargs["bank_name"])
else:
return payment_class()
@classmethod
def get_supported_types(cls) -> list:
"""获取支持的支付类型"""
return list(cls._payment_types.keys())
# 使用示例
print("=== 简单工厂模式演示 ===")
# 客户端代码 - 不需要知道具体类名
print("创建不同的支付对象:")
try:
alipay = PaymentFactory.create_payment("alipay")
print(f"创建成功: {alipay.get_name()}")
wechat_pay = PaymentFactory.create_payment("wechat")
print(f"创建成功: {wechat_pay.get_name()}")
icbc_pay = PaymentFactory.create_payment("bank", bank_name="工商银行")
print(f"创建成功: {icbc_pay.get_name()}")
cmb_pay = PaymentFactory.create_payment("bank", bank_name="招商银行")
print(f"创建成功: {cmb_pay.get_name()}")
except ValueError as e:
print(f"创建失败: {e}")
# 使用支付功能
print("\n执行支付操作:")
payments = [
PaymentFactory.create_payment("alipay"),
PaymentFactory.create_payment("wechat"),
PaymentFactory.create_payment("bank", bank_name="建设银行")
]
for payment in payments:
result = payment.pay(100.0)
print(f" {result}")
# 查看支持的支付类型
print(f"\n支持的支付类型: {PaymentFactory.get_supported_types()}")
# 测试错误处理
print("\n测试错误处理:")
try:
invalid_pay = PaymentFactory.create_payment("bitcoin")
except ValueError as e:
print(f"预期的错误: {e}")简单工厂模式的优缺点
优点:
- 客户端无需知道具体产品类名,只需知道参数
- 实现了创建逻辑与使用逻辑的分离
- 代码结构简单,易于理解和维护
缺点:
- 违反开闭原则:新增产品时需要修改工厂类
- 工厂类职责过重,不易于扩展
- 当产品较多时,工厂类会变得庞大
工厂方法模式
基本概念
工厂方法模式定义了一个创建对象的接口,但让子类决定实例化哪一个类。工厂方法使一个类的实例化延迟到其子类。
实现示例
from abc import ABC, abstractmethod
from typing import List
# 产品接口
class LoggerFactory(ABC):
"""日志工厂接口 - 抽象工厂"""
@abstractmethod
def create_logger(self):
"""创建日志记录器"""
pass
def log_info(self, message: str):
"""记录信息日志 - 通用方法"""
logger = self.create_logger()
logger.info(message)
def log_error(self, message: str):
"""记录错误日志 - 通用方法"""
logger = self.create_logger()
logger.error(message)
# 具体产品类
class FileLogger:
"""文件日志记录器"""
def info(self, message: str):
print(f"[文件日志] INFO: {message}")
def error(self, message: str):
print(f"[文件日志] ERROR: {message}")
def get_type(self) -> str:
return "文件日志"
class DatabaseLogger:
"""数据库日志记录器"""
def info(self, message: str):
print(f"[数据库日志] INFO: {message}")
def error(self, message: str):
print(f"[数据库日志] ERROR: {message}")
def get_type(self) -> str:
return "数据库日志"
class ConsoleLogger:
"""控制台日志记录器"""
def info(self, message: str):
print(f"[控制台日志] INFO: {message}")
def error(self, message: str):
print(f"[控制台日志] ERROR: {message}")
def get_type(self) -> str:
return "控制台日志"
# 具体工厂类
class FileLoggerFactory(LoggerFactory):
"""文件日志工厂"""
def create_logger(self):
print("创建文件日志记录器...")
# 这里可以添加文件日志的初始化逻辑
return FileLogger()
class DatabaseLoggerFactory(LoggerFactory):
"""数据库日志工厂"""
def __init__(self, connection_string: str = "default_db"):
self.connection_string = connection_string
def create_logger(self):
print(f"创建数据库日志记录器,连接: {self.connection_string}...")
# 这里可以添加数据库连接的初始化逻辑
return DatabaseLogger()
class ConsoleLoggerFactory(LoggerFactory):
"""控制台日志工厂"""
def __init__(self, colored_output: bool = True):
self.colored_output = colored_output
def create_logger(self):
print(f"创建控制台日志记录器,彩色输出: {self.colored_output}...")
return ConsoleLogger()
# 日志管理器(客户端)
class LogManager:
"""日志管理器 - 演示如何使用工厂方法模式"""
def __init__(self, factory: LoggerFactory):
self.factory = factory
def set_factory(self, factory: LoggerFactory):
"""切换日志工厂"""
print(f"切换日志工厂从 {self.factory.create_logger().get_type()} 到 {factory.create_logger().get_type()}")
self.factory = factory
def process_business_operation(self, operation: str):
"""模拟业务操作,记录日志"""
try:
print(f"执行操作: {operation}")
# 模拟业务逻辑
if "error" in operation.lower():
raise ValueError("模拟业务错误")
# 记录成功日志
self.factory.log_info(f"操作 '{operation}' 执行成功")
except Exception as e:
# 记录错误日志
self.factory.log_error(f"操作 '{operation}' 执行失败: {str(e)}")
# 使用示例
print("\n=== 工厂方法模式演示 ===")
# 创建不同的工厂
file_factory = FileLoggerFactory()
database_factory = DatabaseLoggerFactory("mysql://localhost/logs")
console_factory = ConsoleLoggerFactory(colored_output=True)
# 使用文件日志
print("--- 使用文件日志 ---")
log_manager = LogManager(file_factory)
log_manager.process_business_operation("用户登录")
log_manager.process_business_operation("数据保存")
# 切换到数据库日志
print("\n--- 切换到数据库日志 ---")
log_manager.set_factory(database_factory)
log_manager.process_business_operation("订单处理")
log_manager.process_business_operation("支付操作")
# 切换到控制台日志
print("\n--- 切换到控制台日志 ---")
log_manager.set_factory(console_factory)
log_manager.process_business_operation("系统启动")
log_manager.process_business_operation("模拟错误操作")
# 演示扩展性 - 新增日志类型不需要修改现有代码
class RemoteLogger:
"""远程日志服务(新增的产品)"""
def info(self, message: str):
print(f"[远程日志服务] INFO: {message}")
def error(self, message: str):
print(f"[远程日志服务] ERROR: {message}")
def get_type(self) -> str:
return "远程日志服务"
class RemoteLoggerFactory(LoggerFactory):
"""远程日志工厂(新增的工厂)"""
def __init__(self, endpoint: str = "https://logs.example.com"):
self.endpoint = endpoint
def create_logger(self):
print(f"创建远程日志服务,端点: {self.endpoint}...")
return RemoteLogger()
print("\n--- 使用新增的远程日志 ---")
remote_factory = RemoteLoggerFactory("https://myapp-logs.company.com")
log_manager.set_factory(remote_factory)
log_manager.process_business_operation("API调用")
print("✅ 新增日志类型无需修改现有代码,符合开闭原则!")工厂方法模式的优缺点
优点:
- 符合开闭原则:新增产品时只需添加新的工厂类
- 符合单一职责原则:每个工厂只负责创建一种产品
- 支持多态:客户端面向抽象编程
- 易于扩展:新增产品不影响现有代码
缺点:
- 类的数量增多:每个产品都需要对应的工厂类
- 增加了系统的抽象性和理解难度
- 增加了系统的复杂度
抽象工厂模式
基本概念
抽象工厂模式提供一个创建一系列相关或相互依赖对象的接口,而无需指定它们具体的类。它针对的是产品族的创建。
实现示例
from abc import ABC, abstractmethod
from typing import List
# 抽象产品族 - 按钮
class Button(ABC):
"""按钮接口"""
@abstractmethod
def render(self) -> str:
"""渲染按钮"""
pass
@abstractmethod
def on_click(self) -> str:
"""点击事件"""
pass
# 抽象产品族 - 文本框
class TextBox(ABC):
"""文本框接口"""
@abstractmethod
def render(self) -> str:
"""渲染文本框"""
pass
@abstractmethod
def get_text(self) -> str:
"""获取文本"""
pass
# 抽象产品族 - 对话框
class Dialog(ABC):
"""对话框接口"""
@abstractmethod
def render(self) -> str:
"""渲染对话框"""
pass
@abstractmethod
def show(self) -> str:
"""显示对话框"""
pass
# 具体产品 - Windows风格
class WindowsButton(Button):
def render(self) -> str:
return "🟦 渲染Windows风格按钮(蓝色直角)"
def on_click(self) -> str:
return "Windows按钮被点击,执行.exe命令"
class WindowsTextBox(TextBox):
def __init__(self):
self._text = ""
def render(self) -> str:
return "📝 渲染Windows风格文本框(白色背景,灰色边框)"
def get_text(self) -> str:
return self._text
def set_text(self, text: str):
self._text = text
class WindowsDialog(Dialog):
def render(self) -> str:
return "🪟 渲染Windows风格对话框(标题栏+最小化/最大化/关闭按钮)"
def show(self) -> str:
return "显示Windows对话框,支持Alt+Tab切换"
# 具体产品 - Mac风格
class MacButton(Button):
def render(self) -> str:
return "🟢 渲染Mac风格按钮(绿色圆角)"
def on_click(self) -> str:
return "Mac按钮被点击,执行.App命令"
class MacTextBox(TextBox):
def __init__(self):
self._text = ""
def render(self) -> str:
return "📝 渲染Mac风格文本框(浅灰背景,细边框)"
def get_text(self) -> str:
return self._text
def set_text(self, text: str):
self._text = text
class MacDialog(Dialog):
def render(self) -> str:
return "🍎 渲染Mac风格对话框(无标题栏按钮,红绿灯控制)"
def show(self) -> str:
return "显示Mac对话框,支持Mission Control"
# 具体产品 - Linux风格
class LinuxButton(Button):
def render(self) -> str:
return "🟡 渲染Linux风格按钮(黄色,可自定义主题)"
def on_click(self) -> str:
return "Linux按钮被点击,执行.sh脚本"
class LinuxTextBox(TextBox):
def __init__(self):
self._text = ""
def render(self) -> str:
return "📝 渲染Linux风格文本框(终端风格,黑色背景可选)"
def get_text(self) -> str:
return self._text
def set_text(self, text: str):
self._text = text
class LinuxDialog(Dialog):
def render(self) -> str:
return "🐧 渲染Linux风格对话框(GNOME/KDE样式可选)"
def show(self) -> str:
return "显示Linux对话框,支持窗口管理器特效"
# 抽象工厂接口
class GUIFactory(ABC):
"""GUI工厂接口 - 抽象工厂"""
@abstractmethod
def create_button(self) -> Button:
"""创建按钮"""
pass
@abstractmethod
def create_textbox(self) -> TextBox:
"""创建文本框"""
pass
@abstractmethod
def create_dialog(self) -> Dialog:
"""创建对话框"""
pass
# 具体工厂 - Windows
class WindowsFactory(GUIFactory):
"""Windows GUI工厂"""
def create_button(self) -> Button:
print("创建Windows按钮...")
return WindowsButton()
def create_textbox(self) -> TextBox:
print("创建Windows文本框...")
return WindowsTextBox()
def create_dialog(self) -> Dialog:
print("创建Windows对话框...")
return WindowsDialog()
# 具体工厂 - Mac
class MacFactory(GUIFactory):
"""Mac GUI工厂"""
def create_button(self) -> Button:
print("创建Mac按钮...")
return MacButton()
def create_textbox(self) -> TextBox:
print("创建Mac文本框...")
return MacTextBox()
def create_dialog(self) -> Dialog:
print("创建Mac对话框...")
return MacDialog()
# 具体工厂 - Linux
class LinuxFactory(GUIFactory):
"""Linux GUI工厂"""
def create_button(self) -> Button:
print("创建Linux按钮...")
return LinuxButton()
def create_textbox(self) -> TextBox:
print("创建Linux文本框...")
return LinuxTextBox()
def create_dialog(self) -> Dialog:
print("创建Linux对话框...")
return LinuxDialog()
# 应用配置器
class Application:
"""应用程序 - 演示抽象工厂的使用"""
def __init__(self, factory: GUIFactory):
self.factory = factory
self.button = None
self.textbox = None
self.dialog = None
def initialize_ui(self):
"""初始化用户界面组件"""
print("初始化用户界面...")
self.button = self.factory.create_button()
self.textbox = self.factory.create_textbox()
self.dialog = self.factory.create_dialog()
def render_ui(self):
"""渲染整个界面"""
print("\n=== 渲染用户界面 ===")
print(self.dialog.render())
print(self.button.render())
print(self.textbox.render())
def simulate_user_interaction(self):
"""模拟用户交互"""
print("\n=== 模拟用户交互 ===")
# 设置文本框内容
if hasattr(self.textbox, 'set_text'):
self.textbox.set_text("Hello, Factory Pattern!")
print(f"文本框内容: '{self.textbox.get_text()}'")
# 点击按钮
print(self.button.on_click())
# 显示对话框
print(self.dialog.show())
def switch_theme(self, new_factory: GUIFactory):
"""切换主题(更换工厂)"""
print(f"\n🔄 从 {self.factory.__class__.__name__} 切换到 {new_factory.__class__.__name__}")
self.factory = new_factory
self.initialize_ui()
# 使用示例
print("\n=== 抽象工厂模式演示 ===")
# 创建Windows应用
print("--- Windows风格应用 ---")
windows_factory = WindowsFactory()
windows_app = Application(windows_factory)
windows_app.initialize_ui()
windows_app.render_ui()
windows_app.simulate_user_interaction()
# 切换到Mac主题
print("\n--- 切换到Mac主题 ---")
mac_factory = MacFactory()
windows_app.switch_theme(mac_factory)
windows_app.render_ui()
windows_app.simulate_user_interaction()
# 创建Linux应用
print("\n--- Linux风格应用 ---")
linux_factory = LinuxFactory()
linux_app = Application(linux_factory)
linux_app.initialize_ui()
linux_app.render_ui()
linux_app.simulate_user_interaction()
# 演示产品族一致性
print("\n=== 产品族一致性演示 ===")
print("抽象工厂确保同一工厂创建的组件风格一致")
print("Windows工厂创建的所有组件都是Windows风格")
print("Mac工厂创建的所有组件都是Mac风格")
print("Linux工厂创建的所有组件都是Linux风格")抽象工厂模式的优缺点
优点:
- 确保产品族的一致性:同一工厂创建的产品相互兼容
- 符合开闭原则:新增产品族只需添加新工厂
- 分离客户端和具体实现:客户端面向抽象编程
- 支持多态:运行时可以切换不同的产品族
缺点:
- 新增产品困难:需要修改抽象工厂和所有具体工厂
- 系统复杂度高:抽象层次较多
- 增加了系统的抽象性和理解难度
工厂模式应用场景
1. 数据库连接工厂
from abc import ABC, abstractmethod
from typing import Dict, Any, List
# 数据库连接接口
class DatabaseConnection(ABC):
@abstractmethod
def connect(self) -> bool:
pass
@abstractmethod
def execute(self, query: str) -> List[Dict]:
pass
@abstractmethod
def close(self) -> bool:
pass
# 具体数据库实现
class MySQLConnection(DatabaseConnection):
def __init__(self, host: str, port: int, database: str, username: str, password: str):
self.host = host
self.port = port
self.database = database
self.username = username
self.password = password
self.is_connected = False
def connect(self) -> bool:
print(f"连接MySQL数据库: {self.host}:{self.port}/{self.database}")
self.is_connected = True
return True
def execute(self, query: str) -> List[Dict]:
if not self.is_connected:
print("错误:数据库未连接")
return []
print(f"执行MySQL查询: {query}")
return [{"id": 1, "name": "MySQL查询结果"}]
def close(self) -> bool:
if self.is_connected:
print("关闭MySQL连接")
self.is_connected = False
return True
class PostgreSQLConnection(DatabaseConnection):
def __init__(self, host: str, port: int, database: str, username: str, password: str):
self.host = host
self.port = port
self.database = database
self.username = username
self.password = password
self.is_connected = False
def connect(self) -> bool:
print(f"连接PostgreSQL数据库: {self.host}:{self.port}/{self.database}")
self.is_connected = True
return True
def execute(self, query: str) -> List[Dict]:
if not self.is_connected:
print("错误:数据库未连接")
return []
print(f"执行PostgreSQL查询: {query}")
return [{"id": 1, "name": "PostgreSQL查询结果"}]
def close(self) -> bool:
if self.is_connected:
print("关闭PostgreSQL连接")
self.is_connected = False
return True
# 简单工厂
class DatabaseConnectionFactory:
"""数据库连接工厂"""
# 支持的数据库类型
_database_types = {
"mysql": MySQLConnection,
"postgresql": PostgreSQLConnection
}
@classmethod
def create_connection(cls, db_type: str, **config) -> DatabaseConnection:
"""创建数据库连接"""
db_type = db_type.lower().strip()
if db_type not in cls._database_types:
supported_types = ", ".join(cls._database_types.keys())
raise ValueError(f"不支持的数据库类型: {db_type}。支持的类型: {supported_types}")
db_class = cls._database_types[db_type]
# 提取必要的参数
required_params = ["host", "port", "database", "username", "password"]
missing_params = [p for p in required_params if p not in config]
if missing_params:
raise ValueError(f"缺少必要参数: {missing_params}")
return db_class(
host=config["host"],
port=config["port"],
database=config["database"],
username=config["username"],
password=config["password"]
)
@classmethod
def get_supported_databases(cls) -> List[str]:
"""获取支持的数据库类型"""
return list(cls._database_types.keys())
# 使用示例
print("\n=== 数据库连接工厂应用 ===")
# 数据库配置
mysql_config = {
"host": "localhost",
"port": 3306,
"database": "myapp",
"username": "root",
"password": "password123"
}
postgresql_config = {
"host": "localhost",
"port": 5432,
"database": "myapp_pg",
"username": "postgres",
"password": "password456"
}
# 创建数据库连接
print("创建数据库连接:")
try:
mysql_conn = DatabaseConnectionFactory.create_connection("mysql", **mysql_config)
mysql_conn.connect()
result = mysql_conn.execute("SELECT * FROM users")
mysql_conn.close()
print("\n")
pg_conn = DatabaseConnectionFactory.create_connection("postgresql", **postgresql_config)
pg_conn.connect()
result = pg_conn.execute("SELECT * FROM products")
pg_conn.close()
except ValueError as e:
print(f"错误: {e}")2. API请求工厂
import json
import urllib.request
import urllib.parse
from typing import Dict, Any, Optional
# API客户端接口
class APIClient(ABC):
@abstractmethod
def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""发送GET请求"""
pass
@abstractmethod
def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
"""发送POST请求"""
pass
@abstractmethod
def set_auth(self, auth_info: Dict[str, str]):
"""设置认证信息"""
pass
# 具体API实现
class GitHubAPIClient(APIClient):
def __init__(self):
self.base_url = "https://api.github.com"
self.auth_info = {}
def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
url = f"{self.base_url}/{endpoint.lstrip('/')}"
if params:
query_string = urllib.parse.urlencode(params)
url = f"{url}?{query_string}"
# 模拟API请求
print(f"GitHub API GET: {url}")
return {
"status": "success",
"data": {"message": "GitHub API响应"}
}
def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
url = f"{self.base_url}/{endpoint.lstrip('/')}"
# 模拟API请求
print(f"GitHub API POST: {url}")
if data:
print(f"请求体: {json.dumps(data)}")
return {
"status": "success",
"data": {"message": "GitHub创建成功"}
}
def set_auth(self, auth_info: Dict[str, str]):
self.auth_info = auth_info
print(f"设置GitHub认证: token={auth_info.get('token', '')[:4]}****")
class TwitterAPIClient(APIClient):
def __init__(self):
self.base_url = "https://api.twitter.com/1.1"
self.auth_info = {}
def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
url = f"{self.base_url}/{endpoint.lstrip('/')}"
if params:
query_string = urllib.parse.urlencode(params)
url = f"{url}?{query_string}"
# 模拟API请求
print(f"Twitter API GET: {url}")
return {
"status": "success",
"data": {"tweets": ["推文1", "推文2"]}
}
def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
url = f"{self.base_url}/{endpoint.lstrip('/')}"
# 模拟API请求
print(f"Twitter API POST: {url}")
if data:
print(f"请求体: {json.dumps(data)}")
return {
"status": "success",
"data": {"tweet_id": "123456789"}
}
def set_auth(self, auth_info: Dict[str, str]):
self.auth_info = auth_info
print(f"设置Twitter认证: key={auth_info.get('api_key', '')[:4]}****")
# 工厂模式实现
class APIFactory:
"""API客户端工厂 - 简单工厂"""
_supported_apis = {
"github": GitHubAPIClient,
"twitter": TwitterAPIClient
}
@classmethod
def create_client(cls, api_type: str) -> APIClient:
"""创建API客户端"""
api_type = api_type.lower().strip()
if api_type not in cls._supported_apis:
supported_types = ", ".join(cls._supported_apis.keys())
raise ValueError(f"不支持的API类型: {api_type}。支持的类型: {supported_types}")
client_class = cls._supported_apis[api_type]
return client_class()
# 服务类 - 使用工厂创建API客户端
class SocialMediaService:
"""社交媒体服务 - 使用API工厂"""
def __init__(self, api_client: APIClient):
self.client = api_client
def get_user_profile(self, user_id: str) -> Dict[str, Any]:
"""获取用户资料"""
endpoint = f"users/{user_id}"
return self.client.get(endpoint)
def post_message(self, message: str) -> Dict[str, Any]:
"""发布消息"""
return self.client.post("posts", {"content": message})
def authenticate(self, credentials: Dict[str, str]):
"""进行身份认证"""
self.client.set_auth(credentials)
# 使用示例
print("\n=== API请求工厂应用 ===")
# GitHub API
print("--- 使用GitHub API ---")
github_client = APIFactory.create_client("github")
github_service = SocialMediaService(github_client)
# 认证
github_auth = {"token": "ghp_1234567890abcdef"}
github_service.authenticate(github_auth)
# API调用
profile = github_service.get_user_profile("octocat")
post_result = github_service.post_message("Hello from GitHub API!")
print("\n--- 使用Twitter API ---")
twitter_client = APIFactory.create_client("twitter")
twitter_service = SocialMediaService(twitter_client)
# 认证
twitter_auth = {"api_key": "abcd1234efgh5678", "api_secret": "secret123"}
twitter_service.authenticate(twitter_auth)
# API调用
profile = twitter_service.get_user_profile("user123")
post_result = twitter_service.post_message("Hello from Twitter API!")常见错误
错误1:简单工厂中违反开闭原则
# 错误示例:每次新增产品都需修改工厂
class BadPaymentFactory:
# 违反开闭原则:新增支付类型需要修改此类
@classmethod
def create_payment(cls, payment_type):
if payment_type == "alipay":
return Alipay()
elif payment_type == "wechat":
return WechatPay()
# 新增支付类型需要添加新的elif分支
elif payment_type == "bitcoin":
return BitcoinPay() # 每次添加新类型都需要修改这里
else:
raise ValueError("不支持的支付类型")
# 正确做法:使用注册机制
class GoodPaymentFactory:
_payment_types = {}
@classmethod
def register_payment(cls, name, payment_class):
"""注册新的支付类型"""
cls._payment_types[name] = payment_class
print(f"注册支付类型: {name}")
@classmethod
def create_payment(cls, payment_type, **kwargs):
if payment_type not in cls._payment_types:
raise ValueError(f"不支持的支付类型: {payment_type}")
payment_class = cls._payment_types[payment_type]
return payment_class(**kwargs)
# 扩展时只需注册,不需要修改工厂类
GoodPaymentFactory.register_payment("paypal", PayPalPay)
paypal = GoodPaymentFactory.create_payment("paypal")错误2:工厂职责过重
# 错误示例:工厂包含过多业务逻辑
class BadFactory:
@classmethod
def create_user(cls, user_data):
# 工厂职责过重,包含验证、初始化等业务逻辑
if not user_data.get("username"):
raise ValueError("用户名不能为空")
if not user_data.get("email") or "@" not in user_data["email"]:
raise ValueError("邮箱格式不正确")
user = User()
user.username = user_data["username"]
user.email = user_data["email"]
user.password_hash = cls._hash_password(user_data["password"])
user.avatar = cls._generate_avatar(user_data["username"])
user.settings = cls._default_settings()
return user
# 太多的辅助方法
def _hash_password(self, password):
# ...
pass
def _generate_avatar(self, username):
# ...
pass
def _default_settings(self):
# ...
pass
# 正确做法:工厂只负责创建,业务逻辑在其他地方
class GoodFactory:
@classmethod
def create_user(cls, user_data):
"""工厂只负责创建对象"""
return User(user_data)
# 验证逻辑放在专门的类中
class UserValidator:
@staticmethod
def validate(user_data):
if not user_data.get("username"):
raise ValueError("用户名不能为空")
if not user_data.get("email") or "@" not in user_data["email"]:
raise ValueError("邮箱格式不正确")
# 初始化逻辑放在User类或Builder模式中
class UserBuilder:
def __init__(self):
self.user_data = {}
def set_username(self, username):
self.user_data["username"] = username
return self
def set_email(self, email):
self.user_data["email"] = email
return self
def build(self):
UserValidator.validate(self.user_data)
return GoodFactory.create_user(self.user_data)课后练习
创建一个
ShapeFactory简单工厂,能够创建不同形状对象(圆形、矩形、三角形),并实现计算面积和周长的方法。设计一个
DocumentCreator工厂方法模式,支持创建不同类型的文档(PDF、Word、HTML),每个文档类型有不同的格式和导出方式。实现一个抽象工厂模式,创建不同风格的UI组件族(现代风格、经典风格),包括按钮、文本框和滚动条。
创建一个
DataSourceFactory,支持不同类型的数据源连接(文件、数据库、API),并提供统一的CRUD操作接口。设计一个
MessageFormatterFactory,能够创建不同格式的消息格式化器(JSON、XML、YAML),用于API响应的数据序列化。
总结
工厂模式是重要的创建型设计模式:
- 简单工厂:适合产品种类不多且不会频繁变化的场景
- 工厂方法:符合开闭原则,适合需要频繁扩展的场景
- 抽象工厂:适合需要创建产品族、保持产品一致性的场景
- 核心思想是解耦对象的创建与使用,提高系统的灵活性和可扩展性
- 选择合适的工厂模式取决于具体业务需求和未来扩展预期
- 工厂模式与依赖注入常常结合使用,进一步提高系统的解耦程度