第67集:数据类

学习目标

  1. 理解数据类的概念和作用
  2. 掌握@dataclass装饰器的基本使用方法
  3. 学会定义数据类字段和字段属性
  4. 了解数据类的高级特性和应用场景
  5. 掌握数据类与普通类的区别和优势

数据类概念

什么是数据类

数据类(Data Class)是Python 3.7引入的一个特殊类,专门用于主要存储数据的类。使用@dataclass装饰器可以自动生成__init____repr____eq__等特殊方法,大大减少样板代码。

为什么使用数据类

  • 减少样板代码:自动生成常用方法
  • 提高代码可读性:明确标识这是一个数据容器
  • 类型提示友好:与类型注解完美配合
  • 不可变性支持:可以创建不可变的数据对象
  • 默认值管理:方便处理可选字段的默认值

dataclasses模块基础

基本数据类定义

最简单的data class

from dataclasses import dataclass

@dataclass
class Person:
    """人员信息数据类"""
    name: str
    age: int
    email: str

# 自动生成的__init__方法
person = Person("张三", 25, "zhangsan@example.com")
print(person)  # 输出: Person(name='张三', age=25, email='zhangsan@example.com')

# 自动生成的__repr__方法
print(repr(person))  # 输出: Person(name='张三', age=25, email='zhangsan@example.com')

# 自动生成的__eq__方法
person2 = Person("张三", 25, "zhangsan@example.com")
print(person == person2)  # 输出: True

与传统类的对比

# 传统方式 - 需要手动编写大量代码
class PersonTraditional:
    def __init__(self, name: str, age: int, email: str):
        self.name = name
        self.age = age
        self.email = email
    
    def __repr__(self):
        return f"PersonTraditional(name='{self.name}', age={self.age}, email='{self.email}')"
    
    def __eq__(self, other):
        if not isinstance(other, PersonTraditional):
            return False
        return (self.name == other.name and 
                self.age == other.age and 
                self.email == other.email)

# 数据类方式 - 简洁明了
from dataclasses import dataclass

@dataclass
class PersonDataClass:
    name: str
    age: int
    email: str

# 功能完全相同,但代码量减少70%
traditional = PersonTraditional("李四", 30, "lisi@example.com")
dataclass_obj = PersonDataClass("李四", 30, "lisi@example.com")
print(traditional)
print(dataclass_obj)

数据类字段定义

字段类型和默认值

必需字段和可选字段

from dataclasses import dataclass
from typing import Optional

@dataclass
class Student:
    """学生信息数据类"""
    # 必需字段 - 没有默认值
    name: str
    student_id: str
    
    # 可选字段 - 有默认值
    age: int = 18
    grade: str = "大一"
    major: str = "计算机科学"
    gpa: float = 0.0
    
    # 可选字段 - 使用None作为默认值
    phone: Optional[str] = None
    address: Optional[str] = None

# 创建实例的不同方式
student1 = Student("王小明", "2023001")  # 只提供必需字段
print(student1)
# 输出: Student(name='王小明', student_id='2023001', age=18, grade='大一', major='计算机科学', gpa=0.0, phone=None, address=None)

student2 = Student("李小红", "2023002", 20, "大三", "数学", 3.8, "13800138000", "北京市朝阳区")
print(student2)

使用field()函数定制字段

from dataclasses import dataclass, field
from typing import List
import random

@dataclass
class Product:
    """商品数据类 - 展示field()函数的使用"""
    name: str
    price: float
    category: str
    
    # 初始化时不设置,运行时计算
    discount_price: float = field(init=False)
    
    # 默认空列表,但每个实例都有独立的列表
    tags: List[str] = field(default_factory=list)
    
    # 元数据 - 用于存储额外信息
    sku: str = field(default="", metadata={"description": "库存单位编码"})
    
    def __post_init__(self):
        """后初始化处理 - 在__init__之后自动调用"""
        # 计算折扣价格(打9折)
        self.discount_price = round(self.price * 0.9, 2)
        
        # 如果没有SKU,自动生成一个
        if not self.sku:
            self.sku = f"SKU-{random.randint(100000, 999999)}"

# 创建商品实例
product1 = Product("iPhone 15", 5999.0, "手机")
print(product1)
print(f"原价: {product1.price}, 折扣价: {product1.discount_price}")
print(f"SKU: {product1.sku}")

# 添加标签
product1.tags.extend(["热销", "新品", "苹果"])
print(f"标签: {product1.tags}")

# 创建另一个实例,验证独立性
product2 = Product("MacBook Pro", 12999.0, "电脑")
product2.tags.append("专业版")
print(f"商品1标签: {product1.tags}")
print(f"商品2标签: {product2.tags}")  # 互不影响

字段属性详解

field()函数参数说明

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Employee:
    """员工数据类 - 展示field()的各种参数"""
    name: str
    employee_id: str
    
    # init: 是否在__init__中包含此字段
    salary: float = field(default=0.0, init=False)  # 不在__init__中
    
    # default: 字段默认值
    department: str = field(default="未分配", init=True)
    
    # default_factory: 使用工厂函数生成默认值
    skills: list = field(default_factory=list)
    
    # repr: 是否在__repr__中显示
    phone: str = field(default="", repr=False)  # 不在repr中显示
    
    # compare: 是否参与比较操作
    temp_note: str = field(default="", compare=False)  # 不参与==比较
    
    # hash: 是否参与哈希计算
    public_id: str = field(default="", hash=True)
    
    # metadata: 元数据字典
    level: str = field(default="初级", metadata={
        "description": "员工级别",
        "levels": ["初级", "中级", "高级", "专家"]
    })
    
    def __post_init__(self):
        """根据级别设置薪资"""
        salary_map = {
            "初级": 5000,
            "中级": 8000, 
            "高级": 12000,
            "专家": 20000
        }
        self.salary = salary_map.get(self.level, 5000)

# 测试各种field参数
emp = Employee("张工程师", "E001", "技术部", "13800138000", "临时备注", "EMP001", "高级")
print(emp)  # phone和temp_note不会显示
print(f"薪资: {emp.salary}")  # 通过__post_init__设置
print(f"元数据: {emp.level.metadata}")

# 比较操作 - temp_note不参与比较
emp2 = Employee("张工程师", "E001", "技术部", "13800138000", "不同备注", "EMP001", "高级")
print(f"是否相等: {emp == emp2}")  # True - temp_note不影响比较

数据类高级特性

不可变数据类

使用frozen=True创建不可变实例

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Point3D:
    """三维坐标点 - 不可变数据类"""
    x: float
    y: float
    z: float
    
    def distance_from_origin(self) -> float:
        """计算到原点的距离"""
        return (self.x ** 2 + self.y ** 2 + self.z ** 2) ** 0.5
    
    def move(self, dx: float, dy: float, dz: float) -> 'Point3D':
        """移动到新位置 - 返回新实例而不是修改当前实例"""
        return Point3D(self.x + dx, self.y + dy, self.z + dz)

# 创建不可变对象
point = Point3D(1.0, 2.0, 3.0)
print(f"原点: {point}")
print(f"到原点距离: {point.distance_from_origin():.2f}")

# 尝试修改会报错
# point.x = 5.0  # 这行会报错: FrozenInstanceError

# 通过方法创建新实例
new_point = point.move(1.0, 1.0, 1.0)
print(f"原位置: {point}")
print(f"新位置: {new_point}")

# 可以作为字典键(因为不可变)
points_dict = {point: "起点", new_point: "终点"}
print(f"字典: {points_dict}")

继承和数据类

数据类的继承

from dataclasses import dataclass
from abc import ABC

@dataclass
class Animal(ABC):
    """动物基类"""
    name: str
    age: int
    species: str

@dataclass
class Dog(Animal):
    """狗类 - 继承自Animal"""
    breed: str  # 品种
    is_trained: bool = False
    
    def bark(self) -> str:
        return f"{self.name} 汪汪叫!"

@dataclass
class Cat(Animal):
    """猫类 - 继承自Animal"""
    fur_color: str  # 毛色
    lives_left: int = 9
    
    def meow(self) -> str:
        return f"{self.name} 喵喵叫!"

# 创建实例
dog = Dog("旺财", 3, "犬科", "金毛寻回犬", True)
cat = Cat("咪咪", 2, "猫科", "橘色", 8)

print(dog)
print(f"行为: {dog.bark()}")
print(cat)
print(f"行为: {cat.meow()}")

# 多态演示
animals = [dog, cat]
for animal in animals:
    print(f"{animal.name} 是 {animal.species}, {animal.age}岁")

特殊方法和自定义行为

自定义比较和哈希

from dataclasses import dataclass
from typing import List

@dataclass(order=True)  # 自动生成比较方法
class Score:
    """分数数据类 - 支持排序"""
    sort_index: int = field(init=False, repr=False)  # 用于控制排序
    name: str
    subject: str
    points: float
    
    def __post_init__(self):
        # 设置排序索引 - 按分数降序排列
        self.sort_index = -int(self.points * 100)  # 负号实现降序

@dataclass(eq=False)  # 不自动生成__eq__
class Team:
    """团队数据类 - 自定义相等性"""
    name: str
    members: List[str]
    
    def __eq__(self, other):
        """自定义相等性:只要团队名相同就认为相等"""
        if not isinstance(other, Team):
            return False
        return self.name == other.name
    
    def __hash__(self):
        """自定义哈希 - 基于团队名"""
        return hash(self.name)

# 测试排序功能
scores = [
    Score("张三", "数学", 95.5),
    Score("李四", "数学", 87.0),
    Score("王五", "数学", 92.3),
    Score("赵六", "数学", 78.9)
]

print("原始分数:")
for score in scores:
    print(f"{score.name}: {score.points}")

# 自动排序(按分数降序)
sorted_scores = sorted(scores)
print("\n排序后:")
for score in sorted_scores:
    print(f"{score.name}: {score.points}")

# 测试自定义相等性
team1 = Team("火箭队", ["姚明", "麦迪"])
team2 = Team("火箭队", ["詹姆斯", "韦德"])  # 成员不同但队名相同
team3 = Team("湖人队", ["科比", "奥尼尔"])

print(f"\nteam1 == team2: {team1 == team2}")  # True - 队名相同
print(f"team1 == team3: {team1 == team3}")  # False - 队名不同

# 用作集合元素(需要__hash__)
teams_set = {team1, team2, team3}
print(f"集合中的队伍数量: {len(teams_set)}")  # 2个(team1和team2被视为相同)

应用案例

案例1:配置管理系统

from dataclasses import dataclass, field
from typing import Dict, List, Optional
import json

@dataclass
class DatabaseConfig:
    """数据库配置"""
    host: str = "localhost"
    port: int = 5432
    database: str = "myapp"
    username: str = "admin"
    password: str = field(default="", repr=False)  # 密码不在repr中显示
    pool_size: int = 10
    timeout: int = 30

@dataclass
class RedisConfig:
    """Redis配置"""
    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = None
    max_connections: int = 50

@dataclass
class AppConfig:
    """应用程序配置"""
    app_name: str = "MyApplication"
    debug: bool = False
    version: str = "1.0.0"
    
    # 嵌套配置
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    redis: RedisConfig = field(default_factory=RedisConfig)
    
    # 动态配置
    allowed_hosts: List[str] = field(default_factory=lambda: ["localhost", "127.0.0.1"])
    custom_settings: Dict[str, str] = field(default_factory=dict)
    
    @classmethod
    def from_dict(cls, config_dict: dict) -> 'AppConfig':
        """从字典创建配置对象"""
        # 提取嵌套配置
        db_config = DatabaseConfig(**config_dict.get('database', {}))
        redis_config = RedisConfig(**config_dict.get('redis', {}))
        
        # 移除嵌套配置,剩余的是顶层配置
        top_level_config = {k: v for k, v in config_dict.items() 
                           if k not in ['database', 'redis']}
        
        return cls(
            database=db_config,
            redis=redis_config,
            **top_level_config
        )
    
    def to_dict(self) -> dict:
        """转换为字典"""
        result = {
            'app_name': self.app_name,
            'debug': self.debug,
            'version': self.version,
            'allowed_hosts': self.allowed_hosts,
            'custom_settings': self.custom_settings
        }
        
        # 嵌套配置也转换为字典
        result['database'] = {
            'host': self.database.host,
            'port': self.database.port,
            'database': self.database.database,
            'username': self.database.username,
            'pool_size': self.database.pool_size,
            'timeout': self.database.timeout
            # password故意不包含在内
        }
        
        result['redis'] = {
            'host': self.redis.host,
            'port': self.redis.port,
            'db': self.redis.db,
            'max_connections': self.redis.max_connections
        }
        
        return result
    
    def save_to_file(self, filename: str):
        """保存配置到文件"""
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
    
    @classmethod
    def load_from_file(cls, filename: str) -> 'AppConfig':
        """从文件加载配置"""
        with open(filename, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

# 使用示例
config = AppConfig(
    app_name="电商系统",
    debug=True,
    database=DatabaseConfig(
        host="192.168.1.100",
        database="ecommerce",
        username="ecom_user",
        pool_size=20
    ),
    redis=RedisConfig(
        host="192.168.1.101",
        max_connections=100
    ),
    allowed_hosts=["example.com", "api.example.com"],
    custom_settings={"theme": "dark", "language": "zh-CN"}
)

print("=== 配置信息 ===")
print(f"应用名称: {config.app_name}")
print(f"调试模式: {config.debug}")
print(f"数据库主机: {config.database.host}:{config.database.port}")
print(f"Redis最大连接数: {config.redis.max_connections}")
print(f"允许的域名: {config.allowed_hosts}")

# 保存到文件
config.save_to_file("app_config.json")
print("\n配置已保存到 app_config.json")

# 从文件加载
loaded_config = AppConfig.load_from_file("app_config.json")
print(f"\n从文件加载的应用名称: {loaded_config.app_name}")

案例2:电商订单系统

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional
from enum import Enum

class OrderStatus(Enum):
    """订单状态枚举"""
    PENDING = "待付款"
    PAID = "已付款"
    SHIPPED = "已发货"
    DELIVERED = "已送达"
    CANCELLED = "已取消"

@dataclass
class ProductItem:
    """订单商品项"""
    product_id: str
    name: str
    price: float
    quantity: int
    image_url: Optional[str] = None
    
    @property
    def subtotal(self) -> float:
        """小计金额"""
        return self.price * self.quantity

@dataclass
class ShippingAddress:
    """收货地址"""
    recipient: str
    phone: str
    province: str
    city: str
    district: str
    detail_address: str
    postal_code: str = ""
    
    def full_address(self) -> str:
        """完整地址"""
        return f"{self.province}{self.city}{self.district}{self.detail_address}"

@dataclass
class Order:
    """订单数据类"""
    order_id: str
    customer_id: str
    items: List[ProductItem]
    shipping_address: ShippingAddress
    status: OrderStatus = OrderStatus.PENDING
    
    # 时间字段
    created_at: datetime = field(default_factory=datetime.now)
    paid_at: Optional[datetime] = None
    shipped_at: Optional[datetime] = None
    delivered_at: Optional[datetime] = None
    
    # 费用字段
    shipping_fee: float = 0.0
    discount_amount: float = 0.0
    
    def __post_init__(self):
        """验证和计算"""
        if not self.items:
            raise ValueError("订单必须包含至少一个商品")
        
        # 验证商品数量
        for item in self.items:
            if item.quantity <= 0:
                raise ValueError(f"商品 {item.name} 的数量必须大于0")
    
    @property
    def total_items(self) -> int:
        """商品总数量"""
        return sum(item.quantity for item in self.items)
    
    @property
    def subtotal(self) -> float:
        """商品小计"""
        return sum(item.subtotal for item in self.items)
    
    @property
    def total_amount(self) -> float:
        """订单总金额"""
        return self.subtotal + self.shipping_fee - self.discount_amount
    
    def can_cancel(self) -> bool:
        """是否可以取消订单"""
        return self.status in [OrderStatus.PENDING, OrderStatus.PAID]
    
    def cancel(self) -> bool:
        """取消订单"""
        if self.can_cancel():
            self.status = OrderStatus.CANCELLED
            return True
        return False
    
    def pay(self) -> bool:
        """付款"""
        if self.status == OrderStatus.PENDING:
            self.status = OrderStatus.PAID
            self.paid_at = datetime.now()
            return True
        return False
    
    def ship(self) -> bool:
        """发货"""
        if self.status == OrderStatus.PAID:
            self.status = OrderStatus.SHIPPED
            self.shipped_at = datetime.now()
            return True
        return False
    
    def deliver(self) -> bool:
        """确认送达"""
        if self.status == OrderStatus.SHIPPED:
            self.status = OrderStatus.DELIVERED
            self.delivered_at = datetime.now()
            return True
        return False
    
    def get_processing_days(self) -> Optional[int]:
        """获取处理天数"""
        if self.paid_at and self.shipped_at:
            return (self.shipped_at - self.paid_at).days
        return None
    
    def summary(self) -> str:
        """订单摘要"""
        return (
            f"订单号: {self.order_id}\n"
            f"客户ID: {self.customer_id}\n"
            f"状态: {self.status.value}\n"
            f"商品种类: {len(self.items)}种, 总数量: {self.total_items}件\n"
            f"商品金额: ¥{self.subtotal:.2f}\n"
            f"运费: ¥{self.shipping_fee:.2f}\n"
            f"优惠: -¥{self.discount_amount:.2f}\n"
            f"实付金额: ¥{self.total_amount:.2f}\n"
            f"下单时间: {self.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
        )

# 使用示例
# 创建商品项
items = [
    ProductItem("P001", "iPhone 15", 5999.0, 1, "iphone.jpg"),
    ProductItem("P002", "AirPods Pro", 1999.0, 1, "airpods.jpg"),
    ProductItem("P003", "Apple Watch", 2999.0, 1, "watch.jpg")
]

# 创建收货地址
address = ShippingAddress(
    recipient="张三",
    phone="13800138000",
    province="北京市",
    city="北京市",
    district="朝阳区",
    detail_address="三里屯街道1号",
    postal_code="100027"
)

# 创建订单
order = Order(
    order_id="ORD202312250001",
    customer_id="CUST001",
    items=items,
    shipping_address=address,
    shipping_fee=15.0,
    discount_amount=200.0
)

print("=== 订单创建成功 ===")
print(order.summary())
print(f"\n完整收货地址: {address.full_address()}")

# 模拟订单流程
print("\n=== 订单处理流程 ===")
print(f"是否可以取消: {order.can_cancel()}")

order.pay()
print(f"付款后状态: {order.status.value}")

order.ship()
print(f"发货后状态: {order.status.value}")

order.deliver()
print(f"送达后状态: {order.status.value}")

processing_days = order.get_processing_days()
if processing_days is not None:
    print(f"处理用时: {processing_days}天")

print(f"\n订单最终摘要:\n{order.summary()}")

常见错误

错误1:混淆可变和不可变默认值

from dataclasses import dataclass, field
from typing import List

@dataclass
class ProblematicClass:
    """有问题的数据类 - 展示了常见的陷阱"""
    # 危险:所有实例共享同一个列表
    tags: List[str] = []  # 错误!
    
    # 正确:使用default_factory
    # tags: List[str] = field(default_factory=list)

problem_class = ProblematicClass()
problem_class.tags.append("test")

another_instance = ProblematicClass()
print(f"另一个实例的tags: {another_instance.tags}")  # 也会包含"test"!

# 正确的做法
@dataclass
class CorrectClass:
    tags: List[str] = field(default_factory=list)

correct1 = CorrectClass()
correct1.tags.append("test")

correct2 = CorrectClass()
print(f"正确实现的另一个实例tags: {correct2.tags}")  # 空的[]

错误2:忘记类型注解

from dataclasses import dataclass

@dataclass
class MissingTypes:
    """缺少类型注解的数据类"""
    # 错误:没有类型注解,dataclass可能无法正常工作
    # name = "default_name"  # 这不会被识别为字段
    
    # 正确:必须有类型注解
    name: str = "default_name"
    age: int = 0

# dataclass需要明确的类型注解才能识别字段
obj = MissingTypes("test", 25)
print(obj)

错误3:在frozen类中尝试修改字段

from dataclasses import dataclass

@dataclass(frozen=True)
class ImmutablePoint:
    x: float
    y: float

point = ImmutablePoint(1.0, 2.0)
# 错误:不能在frozen类中修改字段
# point.x = 3.0  # FrozenInstanceError

# 正确:创建新实例
new_point = ImmutablePoint(3.0, point.y)
print(f"新点: {new_point}")

课后练习

  1. 创建一个Book数据类,包含书名、作者、ISBN、价格、出版日期等字段,并实现图书信息的格式化显示功能。

  2. 设计一个UserProfile数据类,管理用户的个人信息、偏好设置和权限列表,支持配置的序列化和反序列化。

  3. 实现一个Rectangle数据类,表示矩形,包含计算面积、周长、对角线的方法,以及判断与其他矩形关系的功能。

  4. 创建Course数据类来管理课程信息,包括课程名称、教师、学生列表、上课时间等,并实现选课和退课的功能。

  5. 设计一个BankAccount数据类,实现账户余额管理、交易记录、冻结/解冻等功能,确保金额的准确性和安全性。

总结

数据类是Python现代编程中的重要特性:

  • 大幅减少样板代码,提高开发效率
  • 与类型提示完美结合,增强代码可维护性
  • 支持灵活的配置选项和高级特性
  • 特别适合配置管理、DTO、值对象等场景
  • 通过@dataclass装饰器轻松创建功能丰富的数据容器
  • 与传统的手写类相比,更加简洁、安全、易读
« 上一篇 枚举类型 下一篇 » 单例模式