第69集:工厂模式

学习目标

  1. 理解工厂模式的概念和作用
  2. 掌握简单工厂、工厂方法和抽象工厂三种实现
  3. 学会根据不同场景选择合适的工厂模式
  4. 了解工厂模式与依赖注入的关系
  5. 能够在实际项目中应用工厂模式解耦对象创建

工厂模式概念

什么是工厂模式

工厂模式(Factory Pattern)是一种创建型设计模式,它提供了一种创建对象的接口,但让子类决定要实例化的类是哪一个。工厂模式将对象的创建与使用分离,提高了代码的灵活性和可扩展性。

为什么要使用工厂模式

  • 解耦创建与使用:客户端不需要知道具体类名,只需知道对应的参数
  • 提高扩展性:新增产品时,不需要修改现有客户端代码
  • 统一管理创建逻辑:集中处理对象的创建、初始化、配置等逻辑
  • 支持多态:客户端面向抽象编程,运行时决定具体实现
  • 便于测试:可以通过工厂模拟对象创建,便于单元测试

工厂模式的三种形式

  1. 简单工厂模式:一个工厂类根据参数创建不同产品
  2. 工厂方法模式:定义创建产品的接口,由子类决定创建哪种产品
  3. 抽象工厂模式:创建相关或依赖对象的家族,而不需要指定具体类

简单工厂模式

基本概念

简单工厂模式由一个工厂类根据传入的参数,动态决定应该创建哪一个产品类的实例。它是最基本的工厂模式,但违反了开闭原则。

实现示例

from abc import ABC, abstractmethod
from typing import Dict, Any

# 产品接口
class Payment(ABC):
    """支付接口 - 抽象产品"""
    
    @abstractmethod
    def pay(self, amount: float) -> str:
        """支付方法"""
        pass
    
    @abstractmethod
    def get_name(self) -> str:
        """获取支付方式名称"""
        pass

# 具体产品类
class Alipay(Payment):
    """支付宝支付"""
    
    def pay(self, amount: float) -> str:
        return f"使用支付宝成功支付 ¥{amount:.2f}"
    
    def get_name(self) -> str:
        return "支付宝"

class WechatPay(Payment):
    """微信支付"""
    
    def pay(self, amount: float) -> str:
        return f"使用微信支付成功支付 ¥{amount:.2f}"
    
    def get_name(self) -> str:
        return "微信支付"

class BankCardPay(Payment):
    """银行卡支付"""
    
    def __init__(self, bank_name: str = "工商银行"):
        self.bank_name = bank_name
    
    def pay(self, amount: float) -> str:
        return f"使用{self.bank_name}银行卡成功支付 ¥{amount:.2f}"
    
    def get_name(self) -> str:
        return f"银行卡({self.bank_name})"

# 简单工厂类
class PaymentFactory:
    """支付工厂 - 简单工厂模式"""
    
    # 支持的支付方式映射
    _payment_types = {
        "alipay": Alipay,
        "wechat": WechatPay,
        "bank": BankCardPay
    }
    
    @classmethod
    def create_payment(cls, payment_type: str, **kwargs) -> Payment:
        """创建支付对象
        
        Args:
            payment_type: 支付类型(alipay, wechat, bank)
            **kwargs: 额外参数,如银行名称等
            
        Returns:
            Payment: 支付对象实例
            
        Raises:
            ValueError: 当支付类型不支持时
        """
        # 标准化支付类型
        payment_type = payment_type.lower().strip()
        
        if payment_type not in cls._payment_types:
            supported_types = ", ".join(cls._payment_types.keys())
            raise ValueError(f"不支持的支付类型: {payment_type}。支持的类型: {supported_types}")
        
        # 获取对应的支付类
        payment_class = cls._payment_types[payment_type]
        
        # 创建实例(处理特殊参数)
        if payment_type == "bank" and "bank_name" in kwargs:
            return payment_class(kwargs["bank_name"])
        else:
            return payment_class()
    
    @classmethod
    def get_supported_types(cls) -> list:
        """获取支持的支付类型"""
        return list(cls._payment_types.keys())

# 使用示例
print("=== 简单工厂模式演示 ===")

# 客户端代码 - 不需要知道具体类名
print("创建不同的支付对象:")
try:
    alipay = PaymentFactory.create_payment("alipay")
    print(f"创建成功: {alipay.get_name()}")
    
    wechat_pay = PaymentFactory.create_payment("wechat")
    print(f"创建成功: {wechat_pay.get_name()}")
    
    icbc_pay = PaymentFactory.create_payment("bank", bank_name="工商银行")
    print(f"创建成功: {icbc_pay.get_name()}")
    
    cmb_pay = PaymentFactory.create_payment("bank", bank_name="招商银行")
    print(f"创建成功: {cmb_pay.get_name()}")
    
except ValueError as e:
    print(f"创建失败: {e}")

# 使用支付功能
print("\n执行支付操作:")
payments = [
    PaymentFactory.create_payment("alipay"),
    PaymentFactory.create_payment("wechat"),
    PaymentFactory.create_payment("bank", bank_name="建设银行")
]

for payment in payments:
    result = payment.pay(100.0)
    print(f"  {result}")

# 查看支持的支付类型
print(f"\n支持的支付类型: {PaymentFactory.get_supported_types()}")

# 测试错误处理
print("\n测试错误处理:")
try:
    invalid_pay = PaymentFactory.create_payment("bitcoin")
except ValueError as e:
    print(f"预期的错误: {e}")

简单工厂模式的优缺点

优点:

  • 客户端无需知道具体产品类名,只需知道参数
  • 实现了创建逻辑与使用逻辑的分离
  • 代码结构简单,易于理解和维护

缺点:

  • 违反开闭原则:新增产品时需要修改工厂类
  • 工厂类职责过重,不易于扩展
  • 当产品较多时,工厂类会变得庞大

工厂方法模式

基本概念

工厂方法模式定义了一个创建对象的接口,但让子类决定实例化哪一个类。工厂方法使一个类的实例化延迟到其子类。

实现示例

from abc import ABC, abstractmethod
from typing import List

# 产品接口
class LoggerFactory(ABC):
    """日志工厂接口 - 抽象工厂"""
    
    @abstractmethod
    def create_logger(self):
        """创建日志记录器"""
        pass
    
    def log_info(self, message: str):
        """记录信息日志 - 通用方法"""
        logger = self.create_logger()
        logger.info(message)
    
    def log_error(self, message: str):
        """记录错误日志 - 通用方法"""
        logger = self.create_logger()
        logger.error(message)

# 具体产品类
class FileLogger:
    """文件日志记录器"""
    
    def info(self, message: str):
        print(f"[文件日志] INFO: {message}")
    
    def error(self, message: str):
        print(f"[文件日志] ERROR: {message}")
    
    def get_type(self) -> str:
        return "文件日志"

class DatabaseLogger:
    """数据库日志记录器"""
    
    def info(self, message: str):
        print(f"[数据库日志] INFO: {message}")
    
    def error(self, message: str):
        print(f"[数据库日志] ERROR: {message}")
    
    def get_type(self) -> str:
        return "数据库日志"

class ConsoleLogger:
    """控制台日志记录器"""
    
    def info(self, message: str):
        print(f"[控制台日志] INFO: {message}")
    
    def error(self, message: str):
        print(f"[控制台日志] ERROR: {message}")
    
    def get_type(self) -> str:
        return "控制台日志"

# 具体工厂类
class FileLoggerFactory(LoggerFactory):
    """文件日志工厂"""
    
    def create_logger(self):
        print("创建文件日志记录器...")
        # 这里可以添加文件日志的初始化逻辑
        return FileLogger()

class DatabaseLoggerFactory(LoggerFactory):
    """数据库日志工厂"""
    
    def __init__(self, connection_string: str = "default_db"):
        self.connection_string = connection_string
    
    def create_logger(self):
        print(f"创建数据库日志记录器,连接: {self.connection_string}...")
        # 这里可以添加数据库连接的初始化逻辑
        return DatabaseLogger()

class ConsoleLoggerFactory(LoggerFactory):
    """控制台日志工厂"""
    
    def __init__(self, colored_output: bool = True):
        self.colored_output = colored_output
    
    def create_logger(self):
        print(f"创建控制台日志记录器,彩色输出: {self.colored_output}...")
        return ConsoleLogger()

# 日志管理器(客户端)
class LogManager:
    """日志管理器 - 演示如何使用工厂方法模式"""
    
    def __init__(self, factory: LoggerFactory):
        self.factory = factory
    
    def set_factory(self, factory: LoggerFactory):
        """切换日志工厂"""
        print(f"切换日志工厂从 {self.factory.create_logger().get_type()} 到 {factory.create_logger().get_type()}")
        self.factory = factory
    
    def process_business_operation(self, operation: str):
        """模拟业务操作,记录日志"""
        try:
            print(f"执行操作: {operation}")
            # 模拟业务逻辑
            if "error" in operation.lower():
                raise ValueError("模拟业务错误")
            
            # 记录成功日志
            self.factory.log_info(f"操作 '{operation}' 执行成功")
            
        except Exception as e:
            # 记录错误日志
            self.factory.log_error(f"操作 '{operation}' 执行失败: {str(e)}")

# 使用示例
print("\n=== 工厂方法模式演示 ===")

# 创建不同的工厂
file_factory = FileLoggerFactory()
database_factory = DatabaseLoggerFactory("mysql://localhost/logs")
console_factory = ConsoleLoggerFactory(colored_output=True)

# 使用文件日志
print("--- 使用文件日志 ---")
log_manager = LogManager(file_factory)
log_manager.process_business_operation("用户登录")
log_manager.process_business_operation("数据保存")

# 切换到数据库日志
print("\n--- 切换到数据库日志 ---")
log_manager.set_factory(database_factory)
log_manager.process_business_operation("订单处理")
log_manager.process_business_operation("支付操作")

# 切换到控制台日志
print("\n--- 切换到控制台日志 ---")
log_manager.set_factory(console_factory)
log_manager.process_business_operation("系统启动")
log_manager.process_business_operation("模拟错误操作")

# 演示扩展性 - 新增日志类型不需要修改现有代码
class RemoteLogger:
    """远程日志服务(新增的产品)"""
    
    def info(self, message: str):
        print(f"[远程日志服务] INFO: {message}")
    
    def error(self, message: str):
        print(f"[远程日志服务] ERROR: {message}")
    
    def get_type(self) -> str:
        return "远程日志服务"

class RemoteLoggerFactory(LoggerFactory):
    """远程日志工厂(新增的工厂)"""
    
    def __init__(self, endpoint: str = "https://logs.example.com"):
        self.endpoint = endpoint
    
    def create_logger(self):
        print(f"创建远程日志服务,端点: {self.endpoint}...")
        return RemoteLogger()

print("\n--- 使用新增的远程日志 ---")
remote_factory = RemoteLoggerFactory("https://myapp-logs.company.com")
log_manager.set_factory(remote_factory)
log_manager.process_business_operation("API调用")
print("✅ 新增日志类型无需修改现有代码,符合开闭原则!")

工厂方法模式的优缺点

优点:

  • 符合开闭原则:新增产品时只需添加新的工厂类
  • 符合单一职责原则:每个工厂只负责创建一种产品
  • 支持多态:客户端面向抽象编程
  • 易于扩展:新增产品不影响现有代码

缺点:

  • 类的数量增多:每个产品都需要对应的工厂类
  • 增加了系统的抽象性和理解难度
  • 增加了系统的复杂度

抽象工厂模式

基本概念

抽象工厂模式提供一个创建一系列相关或相互依赖对象的接口,而无需指定它们具体的类。它针对的是产品族的创建。

实现示例

from abc import ABC, abstractmethod
from typing import List

# 抽象产品族 - 按钮
class Button(ABC):
    """按钮接口"""
    
    @abstractmethod
    def render(self) -> str:
        """渲染按钮"""
        pass
    
    @abstractmethod
    def on_click(self) -> str:
        """点击事件"""
        pass

# 抽象产品族 - 文本框
class TextBox(ABC):
    """文本框接口"""
    
    @abstractmethod
    def render(self) -> str:
        """渲染文本框"""
        pass
    
    @abstractmethod
    def get_text(self) -> str:
        """获取文本"""
        pass

# 抽象产品族 - 对话框
class Dialog(ABC):
    """对话框接口"""
    
    @abstractmethod
    def render(self) -> str:
        """渲染对话框"""
        pass
    
    @abstractmethod
    def show(self) -> str:
        """显示对话框"""
        pass

# 具体产品 - Windows风格
class WindowsButton(Button):
    def render(self) -> str:
        return "🟦 渲染Windows风格按钮(蓝色直角)"
    
    def on_click(self) -> str:
        return "Windows按钮被点击,执行.exe命令"

class WindowsTextBox(TextBox):
    def __init__(self):
        self._text = ""
    
    def render(self) -> str:
        return "📝 渲染Windows风格文本框(白色背景,灰色边框)"
    
    def get_text(self) -> str:
        return self._text
    
    def set_text(self, text: str):
        self._text = text

class WindowsDialog(Dialog):
    def render(self) -> str:
        return "🪟 渲染Windows风格对话框(标题栏+最小化/最大化/关闭按钮)"
    
    def show(self) -> str:
        return "显示Windows对话框,支持Alt+Tab切换"

# 具体产品 - Mac风格
class MacButton(Button):
    def render(self) -> str:
        return "🟢 渲染Mac风格按钮(绿色圆角)"
    
    def on_click(self) -> str:
        return "Mac按钮被点击,执行.App命令"

class MacTextBox(TextBox):
    def __init__(self):
        self._text = ""
    
    def render(self) -> str:
        return "📝 渲染Mac风格文本框(浅灰背景,细边框)"
    
    def get_text(self) -> str:
        return self._text
    
    def set_text(self, text: str):
        self._text = text

class MacDialog(Dialog):
    def render(self) -> str:
        return "🍎 渲染Mac风格对话框(无标题栏按钮,红绿灯控制)"
    
    def show(self) -> str:
        return "显示Mac对话框,支持Mission Control"

# 具体产品 - Linux风格
class LinuxButton(Button):
    def render(self) -> str:
        return "🟡 渲染Linux风格按钮(黄色,可自定义主题)"
    
    def on_click(self) -> str:
        return "Linux按钮被点击,执行.sh脚本"

class LinuxTextBox(TextBox):
    def __init__(self):
        self._text = ""
    
    def render(self) -> str:
        return "📝 渲染Linux风格文本框(终端风格,黑色背景可选)"
    
    def get_text(self) -> str:
        return self._text
    
    def set_text(self, text: str):
        self._text = text

class LinuxDialog(Dialog):
    def render(self) -> str:
        return "🐧 渲染Linux风格对话框(GNOME/KDE样式可选)"
    
    def show(self) -> str:
        return "显示Linux对话框,支持窗口管理器特效"

# 抽象工厂接口
class GUIFactory(ABC):
    """GUI工厂接口 - 抽象工厂"""
    
    @abstractmethod
    def create_button(self) -> Button:
        """创建按钮"""
        pass
    
    @abstractmethod
    def create_textbox(self) -> TextBox:
        """创建文本框"""
        pass
    
    @abstractmethod
    def create_dialog(self) -> Dialog:
        """创建对话框"""
        pass

# 具体工厂 - Windows
class WindowsFactory(GUIFactory):
    """Windows GUI工厂"""
    
    def create_button(self) -> Button:
        print("创建Windows按钮...")
        return WindowsButton()
    
    def create_textbox(self) -> TextBox:
        print("创建Windows文本框...")
        return WindowsTextBox()
    
    def create_dialog(self) -> Dialog:
        print("创建Windows对话框...")
        return WindowsDialog()

# 具体工厂 - Mac
class MacFactory(GUIFactory):
    """Mac GUI工厂"""
    
    def create_button(self) -> Button:
        print("创建Mac按钮...")
        return MacButton()
    
    def create_textbox(self) -> TextBox:
        print("创建Mac文本框...")
        return MacTextBox()
    
    def create_dialog(self) -> Dialog:
        print("创建Mac对话框...")
        return MacDialog()

# 具体工厂 - Linux
class LinuxFactory(GUIFactory):
    """Linux GUI工厂"""
    
    def create_button(self) -> Button:
        print("创建Linux按钮...")
        return LinuxButton()
    
    def create_textbox(self) -> TextBox:
        print("创建Linux文本框...")
        return LinuxTextBox()
    
    def create_dialog(self) -> Dialog:
        print("创建Linux对话框...")
        return LinuxDialog()

# 应用配置器
class Application:
    """应用程序 - 演示抽象工厂的使用"""
    
    def __init__(self, factory: GUIFactory):
        self.factory = factory
        self.button = None
        self.textbox = None
        self.dialog = None
    
    def initialize_ui(self):
        """初始化用户界面组件"""
        print("初始化用户界面...")
        self.button = self.factory.create_button()
        self.textbox = self.factory.create_textbox()
        self.dialog = self.factory.create_dialog()
    
    def render_ui(self):
        """渲染整个界面"""
        print("\n=== 渲染用户界面 ===")
        print(self.dialog.render())
        print(self.button.render())
        print(self.textbox.render())
    
    def simulate_user_interaction(self):
        """模拟用户交互"""
        print("\n=== 模拟用户交互 ===")
        
        # 设置文本框内容
        if hasattr(self.textbox, 'set_text'):
            self.textbox.set_text("Hello, Factory Pattern!")
            print(f"文本框内容: '{self.textbox.get_text()}'")
        
        # 点击按钮
        print(self.button.on_click())
        
        # 显示对话框
        print(self.dialog.show())
    
    def switch_theme(self, new_factory: GUIFactory):
        """切换主题(更换工厂)"""
        print(f"\n🔄 从 {self.factory.__class__.__name__} 切换到 {new_factory.__class__.__name__}")
        self.factory = new_factory
        self.initialize_ui()

# 使用示例
print("\n=== 抽象工厂模式演示 ===")

# 创建Windows应用
print("--- Windows风格应用 ---")
windows_factory = WindowsFactory()
windows_app = Application(windows_factory)
windows_app.initialize_ui()
windows_app.render_ui()
windows_app.simulate_user_interaction()

# 切换到Mac主题
print("\n--- 切换到Mac主题 ---")
mac_factory = MacFactory()
windows_app.switch_theme(mac_factory)
windows_app.render_ui()
windows_app.simulate_user_interaction()

# 创建Linux应用
print("\n--- Linux风格应用 ---")
linux_factory = LinuxFactory()
linux_app = Application(linux_factory)
linux_app.initialize_ui()
linux_app.render_ui()
linux_app.simulate_user_interaction()

# 演示产品族一致性
print("\n=== 产品族一致性演示 ===")
print("抽象工厂确保同一工厂创建的组件风格一致")
print("Windows工厂创建的所有组件都是Windows风格")
print("Mac工厂创建的所有组件都是Mac风格")
print("Linux工厂创建的所有组件都是Linux风格")

抽象工厂模式的优缺点

优点:

  • 确保产品族的一致性:同一工厂创建的产品相互兼容
  • 符合开闭原则:新增产品族只需添加新工厂
  • 分离客户端和具体实现:客户端面向抽象编程
  • 支持多态:运行时可以切换不同的产品族

缺点:

  • 新增产品困难:需要修改抽象工厂和所有具体工厂
  • 系统复杂度高:抽象层次较多
  • 增加了系统的抽象性和理解难度

工厂模式应用场景

1. 数据库连接工厂

from abc import ABC, abstractmethod
from typing import Dict, Any, List

# 数据库连接接口
class DatabaseConnection(ABC):
    @abstractmethod
    def connect(self) -> bool:
        pass
    
    @abstractmethod
    def execute(self, query: str) -> List[Dict]:
        pass
    
    @abstractmethod
    def close(self) -> bool:
        pass

# 具体数据库实现
class MySQLConnection(DatabaseConnection):
    def __init__(self, host: str, port: int, database: str, username: str, password: str):
        self.host = host
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.is_connected = False
    
    def connect(self) -> bool:
        print(f"连接MySQL数据库: {self.host}:{self.port}/{self.database}")
        self.is_connected = True
        return True
    
    def execute(self, query: str) -> List[Dict]:
        if not self.is_connected:
            print("错误:数据库未连接")
            return []
        print(f"执行MySQL查询: {query}")
        return [{"id": 1, "name": "MySQL查询结果"}]
    
    def close(self) -> bool:
        if self.is_connected:
            print("关闭MySQL连接")
            self.is_connected = False
        return True

class PostgreSQLConnection(DatabaseConnection):
    def __init__(self, host: str, port: int, database: str, username: str, password: str):
        self.host = host
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.is_connected = False
    
    def connect(self) -> bool:
        print(f"连接PostgreSQL数据库: {self.host}:{self.port}/{self.database}")
        self.is_connected = True
        return True
    
    def execute(self, query: str) -> List[Dict]:
        if not self.is_connected:
            print("错误:数据库未连接")
            return []
        print(f"执行PostgreSQL查询: {query}")
        return [{"id": 1, "name": "PostgreSQL查询结果"}]
    
    def close(self) -> bool:
        if self.is_connected:
            print("关闭PostgreSQL连接")
            self.is_connected = False
        return True

# 简单工厂
class DatabaseConnectionFactory:
    """数据库连接工厂"""
    
    # 支持的数据库类型
    _database_types = {
        "mysql": MySQLConnection,
        "postgresql": PostgreSQLConnection
    }
    
    @classmethod
    def create_connection(cls, db_type: str, **config) -> DatabaseConnection:
        """创建数据库连接"""
        db_type = db_type.lower().strip()
        
        if db_type not in cls._database_types:
            supported_types = ", ".join(cls._database_types.keys())
            raise ValueError(f"不支持的数据库类型: {db_type}。支持的类型: {supported_types}")
        
        db_class = cls._database_types[db_type]
        
        # 提取必要的参数
        required_params = ["host", "port", "database", "username", "password"]
        missing_params = [p for p in required_params if p not in config]
        if missing_params:
            raise ValueError(f"缺少必要参数: {missing_params}")
        
        return db_class(
            host=config["host"],
            port=config["port"],
            database=config["database"],
            username=config["username"],
            password=config["password"]
        )
    
    @classmethod
    def get_supported_databases(cls) -> List[str]:
        """获取支持的数据库类型"""
        return list(cls._database_types.keys())

# 使用示例
print("\n=== 数据库连接工厂应用 ===")

# 数据库配置
mysql_config = {
    "host": "localhost",
    "port": 3306,
    "database": "myapp",
    "username": "root",
    "password": "password123"
}

postgresql_config = {
    "host": "localhost",
    "port": 5432,
    "database": "myapp_pg",
    "username": "postgres",
    "password": "password456"
}

# 创建数据库连接
print("创建数据库连接:")
try:
    mysql_conn = DatabaseConnectionFactory.create_connection("mysql", **mysql_config)
    mysql_conn.connect()
    result = mysql_conn.execute("SELECT * FROM users")
    mysql_conn.close()
    
    print("\n")
    pg_conn = DatabaseConnectionFactory.create_connection("postgresql", **postgresql_config)
    pg_conn.connect()
    result = pg_conn.execute("SELECT * FROM products")
    pg_conn.close()
    
except ValueError as e:
    print(f"错误: {e}")

2. API请求工厂

import json
import urllib.request
import urllib.parse
from typing import Dict, Any, Optional

# API客户端接口
class APIClient(ABC):
    @abstractmethod
    def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        """发送GET请求"""
        pass
    
    @abstractmethod
    def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
        """发送POST请求"""
        pass
    
    @abstractmethod
    def set_auth(self, auth_info: Dict[str, str]):
        """设置认证信息"""
        pass

# 具体API实现
class GitHubAPIClient(APIClient):
    def __init__(self):
        self.base_url = "https://api.github.com"
        self.auth_info = {}
    
    def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        
        if params:
            query_string = urllib.parse.urlencode(params)
            url = f"{url}?{query_string}"
        
        # 模拟API请求
        print(f"GitHub API GET: {url}")
        return {
            "status": "success",
            "data": {"message": "GitHub API响应"}
        }
    
    def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        
        # 模拟API请求
        print(f"GitHub API POST: {url}")
        if data:
            print(f"请求体: {json.dumps(data)}")
        
        return {
            "status": "success",
            "data": {"message": "GitHub创建成功"}
        }
    
    def set_auth(self, auth_info: Dict[str, str]):
        self.auth_info = auth_info
        print(f"设置GitHub认证: token={auth_info.get('token', '')[:4]}****")

class TwitterAPIClient(APIClient):
    def __init__(self):
        self.base_url = "https://api.twitter.com/1.1"
        self.auth_info = {}
    
    def get(self, endpoint: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        
        if params:
            query_string = urllib.parse.urlencode(params)
            url = f"{url}?{query_string}"
        
        # 模拟API请求
        print(f"Twitter API GET: {url}")
        return {
            "status": "success",
            "data": {"tweets": ["推文1", "推文2"]}
        }
    
    def post(self, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        
        # 模拟API请求
        print(f"Twitter API POST: {url}")
        if data:
            print(f"请求体: {json.dumps(data)}")
        
        return {
            "status": "success",
            "data": {"tweet_id": "123456789"}
        }
    
    def set_auth(self, auth_info: Dict[str, str]):
        self.auth_info = auth_info
        print(f"设置Twitter认证: key={auth_info.get('api_key', '')[:4]}****")

# 工厂模式实现
class APIFactory:
    """API客户端工厂 - 简单工厂"""
    
    _supported_apis = {
        "github": GitHubAPIClient,
        "twitter": TwitterAPIClient
    }
    
    @classmethod
    def create_client(cls, api_type: str) -> APIClient:
        """创建API客户端"""
        api_type = api_type.lower().strip()
        
        if api_type not in cls._supported_apis:
            supported_types = ", ".join(cls._supported_apis.keys())
            raise ValueError(f"不支持的API类型: {api_type}。支持的类型: {supported_types}")
        
        client_class = cls._supported_apis[api_type]
        return client_class()

# 服务类 - 使用工厂创建API客户端
class SocialMediaService:
    """社交媒体服务 - 使用API工厂"""
    
    def __init__(self, api_client: APIClient):
        self.client = api_client
    
    def get_user_profile(self, user_id: str) -> Dict[str, Any]:
        """获取用户资料"""
        endpoint = f"users/{user_id}"
        return self.client.get(endpoint)
    
    def post_message(self, message: str) -> Dict[str, Any]:
        """发布消息"""
        return self.client.post("posts", {"content": message})
    
    def authenticate(self, credentials: Dict[str, str]):
        """进行身份认证"""
        self.client.set_auth(credentials)

# 使用示例
print("\n=== API请求工厂应用 ===")

# GitHub API
print("--- 使用GitHub API ---")
github_client = APIFactory.create_client("github")
github_service = SocialMediaService(github_client)

# 认证
github_auth = {"token": "ghp_1234567890abcdef"}
github_service.authenticate(github_auth)

# API调用
profile = github_service.get_user_profile("octocat")
post_result = github_service.post_message("Hello from GitHub API!")

print("\n--- 使用Twitter API ---")
twitter_client = APIFactory.create_client("twitter")
twitter_service = SocialMediaService(twitter_client)

# 认证
twitter_auth = {"api_key": "abcd1234efgh5678", "api_secret": "secret123"}
twitter_service.authenticate(twitter_auth)

# API调用
profile = twitter_service.get_user_profile("user123")
post_result = twitter_service.post_message("Hello from Twitter API!")

常见错误

错误1:简单工厂中违反开闭原则

# 错误示例:每次新增产品都需修改工厂
class BadPaymentFactory:
    # 违反开闭原则:新增支付类型需要修改此类
    @classmethod
    def create_payment(cls, payment_type):
        if payment_type == "alipay":
            return Alipay()
        elif payment_type == "wechat":
            return WechatPay()
        # 新增支付类型需要添加新的elif分支
        elif payment_type == "bitcoin":
            return BitcoinPay()  # 每次添加新类型都需要修改这里
        else:
            raise ValueError("不支持的支付类型")

# 正确做法:使用注册机制
class GoodPaymentFactory:
    _payment_types = {}
    
    @classmethod
    def register_payment(cls, name, payment_class):
        """注册新的支付类型"""
        cls._payment_types[name] = payment_class
        print(f"注册支付类型: {name}")
    
    @classmethod
    def create_payment(cls, payment_type, **kwargs):
        if payment_type not in cls._payment_types:
            raise ValueError(f"不支持的支付类型: {payment_type}")
        
        payment_class = cls._payment_types[payment_type]
        return payment_class(**kwargs)

# 扩展时只需注册,不需要修改工厂类
GoodPaymentFactory.register_payment("paypal", PayPalPay)
paypal = GoodPaymentFactory.create_payment("paypal")

错误2:工厂职责过重

# 错误示例:工厂包含过多业务逻辑
class BadFactory:
    @classmethod
    def create_user(cls, user_data):
        # 工厂职责过重,包含验证、初始化等业务逻辑
        if not user_data.get("username"):
            raise ValueError("用户名不能为空")
        if not user_data.get("email") or "@" not in user_data["email"]:
            raise ValueError("邮箱格式不正确")
        
        user = User()
        user.username = user_data["username"]
        user.email = user_data["email"]
        user.password_hash = cls._hash_password(user_data["password"])
        user.avatar = cls._generate_avatar(user_data["username"])
        user.settings = cls._default_settings()
        return user
    
    # 太多的辅助方法
    def _hash_password(self, password):
        # ...
        pass
    
    def _generate_avatar(self, username):
        # ...
        pass
    
    def _default_settings(self):
        # ...
        pass

# 正确做法:工厂只负责创建,业务逻辑在其他地方
class GoodFactory:
    @classmethod
    def create_user(cls, user_data):
        """工厂只负责创建对象"""
        return User(user_data)

# 验证逻辑放在专门的类中
class UserValidator:
    @staticmethod
    def validate(user_data):
        if not user_data.get("username"):
            raise ValueError("用户名不能为空")
        if not user_data.get("email") or "@" not in user_data["email"]:
            raise ValueError("邮箱格式不正确")

# 初始化逻辑放在User类或Builder模式中
class UserBuilder:
    def __init__(self):
        self.user_data = {}
    
    def set_username(self, username):
        self.user_data["username"] = username
        return self
    
    def set_email(self, email):
        self.user_data["email"] = email
        return self
    
    def build(self):
        UserValidator.validate(self.user_data)
        return GoodFactory.create_user(self.user_data)

课后练习

  1. 创建一个ShapeFactory简单工厂,能够创建不同形状对象(圆形、矩形、三角形),并实现计算面积和周长的方法。

  2. 设计一个DocumentCreator工厂方法模式,支持创建不同类型的文档(PDF、Word、HTML),每个文档类型有不同的格式和导出方式。

  3. 实现一个抽象工厂模式,创建不同风格的UI组件族(现代风格、经典风格),包括按钮、文本框和滚动条。

  4. 创建一个DataSourceFactory,支持不同类型的数据源连接(文件、数据库、API),并提供统一的CRUD操作接口。

  5. 设计一个MessageFormatterFactory,能够创建不同格式的消息格式化器(JSON、XML、YAML),用于API响应的数据序列化。

总结

工厂模式是重要的创建型设计模式:

  • 简单工厂:适合产品种类不多且不会频繁变化的场景
  • 工厂方法:符合开闭原则,适合需要频繁扩展的场景
  • 抽象工厂:适合需要创建产品族、保持产品一致性的场景
  • 核心思想是解耦对象的创建与使用,提高系统的灵活性和可扩展性
  • 选择合适的工厂模式取决于具体业务需求和未来扩展预期
  • 工厂模式与依赖注入常常结合使用,进一步提高系统的解耦程度
« 上一篇 单例模式 下一篇 » 面向对象设计原则