第66集:枚举类型

学习目标

  1. 理解枚举类型的概念和作用
  2. 掌握enum模块的基本使用方法
  3. 学会创建枚举类和枚举成员
  4. 了解枚举的高级特性和应用场景

枚举类型概念

什么是枚举

枚举(Enumeration)是一种数据类型,它包含一组命名的常量值。枚举让代码更加清晰、可读,并且可以避免使用魔法数字或字符串。

为什么使用枚举

  • 提高代码可读性:用有意义的名称代替数字或字符串
  • 类型安全:避免无效的赋值
  • 便于维护:集中管理相关常量
  • 自动补全支持:IDE可以提供更好的代码提示

enum模块基础

Enum类的基本使用

创建简单枚举

from enum import Enum

class Color(Enum):
    """颜色枚举"""
    RED = 1
    GREEN = 2
    BLUE = 3
    YELLOW = 4

# 使用枚举成员
print(Color.RED)           # 输出: Color.RED
print(Color.RED.value)     # 输出: 1
print(Color.RED.name)      # 输出: RED
print(Color(1))            # 输出: Color.RED
print(Color['RED'])        # 输出: Color.RED

枚举成员的特性

from enum import Enum

class Weekday(Enum):
    """星期枚举"""
    MONDAY = 1
    TUESDAY = 2
    WEDNESDAY = 3
    THURSDAY = 4
    FRIDAY = 5
    SATURDAY = 6
    SUNDAY = 7

# 枚举成员是不可变的
weekday = Weekday.MONDAY
print(weekday.value)        # 输出: 1
print(weekday.name)         # 输出: MONDAY

# 可以通过值获取枚举成员
monday = Weekday(1)
print(monday)              # 输出: Weekday.MONDAY

# 可以通过名称获取枚举成员
sunday = Weekday['SUNDAY']
print(sunday)              # 输出: Weekday.SUNDAY

枚举的遍历

from enum import Enum

class Season(Enum):
    SPRING = 1
    SUMMER = 2
    AUTUMN = 3
    WINTER = 4

# 遍历所有枚举成员
for season in Season:
    print(f"{season.name} = {season.value}")

# 输出:
# SPRING = 1
# SUMMER = 2
# AUTUMN = 3
# WINTER = 4

# 获取所有枚举成员
members = list(Season)
print(f"所有季节: {members}")

# 获取所有枚举值
values = [member.value for member in Season]
print(f"所有值: {values}")

# 获取所有枚举名称
names = [member.name for member in Season]
print(f"所有名称: {names}")

枚举的高级特性

自动赋值

使用auto()自动生成值

from enum import Enum, auto

class Priority(Enum):
    """优先级枚举 - 使用auto自动赋值"""
    LOW = auto()
    MEDIUM = auto()
    HIGH = auto()
    URGENT = auto()

for priority in Priority:
    print(f"{priority.name} = {priority.value}")
# 输出:
# LOW = 1
# MEDIUM = 2
# HIGH = 3
# URGENT = 4

自定义自动赋值函数

from enum import Enum

def calculate_value(name):
    """根据名称计算值"""
    return len(name)

class Size(Enum):
    """尺寸枚举 - 自定义计算值"""
    SMALL = calculate_value('SMALL')
    MEDIUM = calculate_value('MEDIUM')
    LARGE = calculate_value('LARGE')
    EXTRA_LARGE = calculate_value('EXTRA_LARGE')

for size in Size:
    print(f"{size.name} = {size.value}")

唯一性约束

防止重复值

from enum import Enum, unique

@unique  # 装饰器确保值唯一
class Status(Enum):
    """状态枚举 - 确保值唯一"""
    PENDING = 1
    APPROVED = 2
    REJECTED = 3
    CANCELLED = 4
    # DUPLICATE = 1  # 这行会报错: duplicate values found

# 如果没有@unique,可以有重复值
class HttpStatusWithoutUnique(Enum):
    OK = 200
    CREATED = 201
    ACCEPTED = 202
    SUCCESS = 200  # 允许重复值,但访问时会返回第一个匹配的

print(HttpStatusWithoutUnique(200))  # 输出: HttpStatusWithoutUnique.OK

枚举的比较

枚举成员比较

from enum import Enum

class Direction(Enum):
    NORTH = 1
    SOUTH = 2
    EAST = 3
    WEST = 4

# 枚举成员可以比较身份(同一性)
print(Direction.NORTH is Direction.NORTH)    # 输出: True
print(Direction.NORTH is Direction.SOUTH)    # 输出: False

# 枚举成员不能直接比较值(需要显式比较value)
print(Direction.NORTH == Direction.NORTH)   # 输出: True
# print(Direction.NORTH == 1)               # 这会报错
print(Direction.NORTH.value == 1)          # 输出: True

# 枚举成员之间的相等比较
north = Direction.NORTH
same_north = Direction(1)
print(north == same_north)                 # 输出: True

自定义比较行为

from enum import Enum

class Number(Enum):
    ONE = 1
    TWO = 2
    THREE = 3
    
    def __eq__(self, other):
        """自定义相等比较"""
        if isinstance(other, Number):
            return self.value == other.value
        elif isinstance(other, int):
            return self.value == other
        return False
    
    def __lt__(self, other):
        """自定义小于比较"""
        if isinstance(other, Number):
            return self.value < other.value
        elif isinstance(other, int):
            return self.value < other
        return NotImplemented

num = Number.TWO
print(num == 2)      # 输出: True
print(num < 3)       # 输出: True
print(num > Number.ONE)  # 输出: True

枚举的方法

自定义枚举方法

from enum import Enum

class Planet(Enum):
    """行星枚举 - 包含自定义方法"""
    MERCURY = (3.303e+23, 2.4397e6)
    VENUS = (4.869e+24, 6.0518e6)
    EARTH = (5.976e+24, 6.37814e6)
    MARS = (6.421e+23, 3.3972e6)
    JUPITER = (1.9e+27, 7.1492e7)
    SATURN = (5.688e+26, 6.0268e7)
    URANUS = (8.686e+25, 2.5559e7)
    NEPTUNE = (1.024e+26, 2.4746e7)
    
    def __init__(self, mass, radius):
        self.mass = mass      # 质量 (kg)
        self.radius = radius  # 半径 (m)
    
    def surface_gravity(self):
        """计算表面重力"""
        G = 6.67300E-11  # 万有引力常数
        return G * self.mass / (self.radius ** 2)
    
    def surface_weight(self, other_mass):
        """计算在其他行星上的重量"""
        return other_mass * self.surface_gravity()
    
    @classmethod
    def get_by_name(cls, name):
        """通过名称获取行星"""
        try:
            return cls[name.upper()]
        except KeyError:
            return None

# 使用示例
earth = Planet.EARTH
print(f"地球表面重力: {earth.surface_gravity():.2f} m/s²")
print(f"70kg的人在地球上的重量: {earth.surface_weight(70):.2f} N")

mars = Planet.MARS
print(f"70kg的人在火星上的重量: {mars.surface_weight(70):.2f} N")

# 使用类方法
found_planet = Planet.get_by_name("earth")
print(f"找到的行星: {found_planet}")

应用案例

案例1:HTTP状态码管理

from enum import Enum

class HttpStatusCode(Enum):
    """HTTP状态码枚举"""
    # 成功状态码
    OK = 200
    CREATED = 201
    ACCEPTED = 202
    NO_CONTENT = 204
    
    # 重定向状态码
    MOVED_PERMANENTLY = 301
    FOUND = 302
    SEE_OTHER = 303
    
    # 客户端错误状态码
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    METHOD_NOT_ALLOWED = 405
    
    # 服务器错误状态码
    INTERNAL_SERVER_ERROR = 500
    NOT_IMPLEMENTED = 501
    BAD_GATEWAY = 502
    SERVICE_UNAVAILABLE = 503
    
    def is_success(self):
        """是否为成功状态码"""
        return 200 <= self.value < 300
    
    def is_client_error(self):
        """是否为客户端错误状态码"""
        return 400 <= self.value < 500
    
    def is_server_error(self):
        """是否为服务器错误状态码"""
        return 500 <= self.value < 600
    
    def get_description(self):
        """获取状态码描述"""
        descriptions = {
            200: "请求成功",
            201: "创建成功",
            202: "已接受",
            204: "无内容",
            301: "永久重定向",
            302: "临时重定向",
            303: "查看其他",
            400: "请求错误",
            401: "未授权",
            403: "禁止访问",
            404: "资源不存在",
            405: "方法不允许",
            500: "服务器内部错误",
            501: "未实现",
            502: "网关错误",
            503: "服务不可用"
        }
        return descriptions.get(self.value, "未知状态码")

# 使用示例
def handle_response(status_code):
    code = HttpStatusCode(status_code)
    
    print(f"状态码: {code.value} ({code.name})")
    print(f"描述: {code.get_description()}")
    
    if code.is_success():
        print("✓ 请求成功处理")
    elif code.is_client_error():
        print("✗ 客户端错误,请检查请求")
    elif code.is_server_error():
        print("✗ 服务器错误,请稍后重试")
    
    return code

# 测试各种状态码
handle_response(200)
print()
handle_response(404)
print()
handle_response(500)

案例2:订单状态管理系统

from enum import Enum, auto

class OrderState(Enum):
    """订单状态枚举"""
    PENDING = auto()      # 待处理
    CONFIRMED = auto()    # 已确认
    PROCESSING = auto()   # 处理中
    SHIPPED = auto()      # 已发货
    DELIVERED = auto()    # 已送达
    CANCELLED = auto()    # 已取消
    REFUNDED = auto()     # 已退款
    
    def can_transition_to(self, new_state):
        """检查是否可以转换到新状态"""
        transitions = {
            OrderState.PENDING: [OrderState.CONFIRMED, OrderState.CANCELLED],
            OrderState.CONFIRMED: [OrderState.PROCESSING, OrderState.CANCELLED],
            OrderState.PROCESSING: [OrderState.SHIPPED, OrderState.CANCELLED],
            OrderState.SHIPPED: [OrderState.DELIVERED, OrderState.CANCELLED],
            OrderState.DELIVERED: [OrderState.REFUNDED],
            OrderState.CANCELLED: [],  # 取消后不能再转换
            OrderState.REFUNDED: []    # 退款后不能再转换
        }
        
        return new_state in transitions.get(self, [])
    
    def get_next_states(self):
        """获取可能的下一个状态"""
        transitions = {
            OrderState.PENDING: [OrderState.CONFIRMED, OrderState.CANCELLED],
            OrderState.CONFIRMED: [OrderState.PROCESSING, OrderState.CANCELLED],
            OrderState.PROCESSING: [OrderState.SHIPPED, OrderState.CANCELLED],
            OrderState.SHIPPED: [OrderState.DELIVERED, OrderState.CANCELLED],
            OrderState.DELIVERED: [OrderState.REFUNDED],
            OrderState.CANCELLED: [],
            OrderState.REFUNDED: []
        }
        return transitions.get(self, [])
    
    def __str__(self):
        names = {
            OrderState.PENDING: "待处理",
            OrderState.CONFIRMED: "已确认",
            OrderState.PROCESSING: "处理中",
            OrderState.SHIPPED: "已发货",
            OrderState.DELIVERED: "已送达",
            OrderState.CANCELLED: "已取消",
            OrderState.REFUNDED: "已退款"
        }
        return names.get(self, "未知状态")

class Order:
    """订单类"""
    
    def __init__(self, order_id, customer_name):
        self.order_id = order_id
        self.customer_name = customer_name
        self.state = OrderState.PENDING
    
    def change_state(self, new_state):
        """改变订单状态"""
        if self.state.can_transition_to(new_state):
            old_state = self.state
            self.state = new_state
            print(f"订单 {self.order_id} 状态从 {old_state} 变更为 {new_state}")
            return True
        else:
            print(f"❌ 无法从 {self.state} 转换到 {new_state}")
            return False
    
    def show_available_transitions(self):
        """显示可用的状态转换"""
        next_states = self.state.get_next_states()
        if next_states:
            print(f"订单 {self.order_id} 当前状态: {self.state}")
            print("可转换为:")
            for state in next_states:
                print(f"  - {state} ({str(state)})")
        else:
            print(f"订单 {self.order_id} 当前状态: {self.state},无可用转换")

# 使用示例
order = Order("ORD001", "张三")

order.show_available_transitions()
order.change_state(OrderState.CONFIRMED)
print()

order.show_available_transitions()
order.change_state(OrderState.PROCESSING)
print()

order.show_available_transitions()
order.change_state(OrderState.SHIPPED)
print()

# 尝试非法转换
order.change_state(OrderState.PENDING)  # 这会失败

常见错误

错误1:混淆枚举值和成员

from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

# 错误:直接比较值和成员
# if color == 1:  # 这样不会工作

# 正确:比较枚举成员
color = Color.RED
if color == Color.RED:
    print("红色")

# 或者比较值
if color.value == 1:
    print("红色")

错误2:修改枚举成员

from enum import Enum

class Status(Enum):
    PENDING = 1
    APPROVED = 2

status = Status.PENDING
# 错误:枚举成员是只读的,不能修改
# status.value = 3  # 这会报错

# 如果需要不同的行为,创建新的枚举或使用不同的设计

错误3:忽略枚举的类型安全

from enum import Enum

class Permission(Enum):
    READ = 1
    WRITE = 2
    EXECUTE = 4

def check_permission(user_perm, required_perm):
    # 错误:没有验证参数类型
    # return (user_perm & required_perm) == required_perm
    
    # 正确:验证参数类型
    if not isinstance(user_perm, Permission):
        raise TypeError("user_perm 必须是 Permission 枚举")
    if not isinstance(required_perm, Permission):
        raise TypeError("required_perm 必须是 Permission 枚举")
    
    return (user_perm.value & required_perm.value) == required_perm.value

课后练习

  1. 创建一个 Direction 枚举表示四个基本方向,并实现坐标移动功能
  2. 设计一个 FileType 枚举管理不同文件类型及其对应的MIME类型
  3. 实现一个 GameState 枚举用于游戏状态管理(开始、暂停、结束等)
  4. 创建 Permission 枚举实现位运算权限管理
  5. 设计一个 PaymentStatus 枚举管理支付流程的各种状态

总结

枚举类型是Python中管理常量的强大工具:

  • 提供类型安全和代码可读性
  • 支持自动赋值和唯一性约束
  • 可以包含自定义方法和业务逻辑
  • 适用于状态管理、配置管理、协议定义等场景
  • 与面向对象编程完美结合,提升代码质量
« 上一篇 常用魔术方法 下一篇 » 数据类