第67集:数据类
学习目标
- 理解数据类的概念和作用
- 掌握@dataclass装饰器的基本使用方法
- 学会定义数据类字段和字段属性
- 了解数据类的高级特性和应用场景
- 掌握数据类与普通类的区别和优势
数据类概念
什么是数据类
数据类(Data Class)是Python 3.7引入的特性,专门用于定义以存储数据为主的类。使用@dataclass装饰器可以自动生成__init__、__repr__、__eq__等特殊方法,大大减少样板代码。
为什么使用数据类
- 减少样板代码:自动生成常用方法
- 提高代码可读性:明确标识这是一个数据容器
- 类型提示友好:与类型注解完美配合
- 不可变性支持:可以创建不可变的数据对象
- 默认值管理:方便处理可选字段的默认值
dataclasses模块基础
基本数据类定义
最简单的数据类
from dataclasses import dataclass
@dataclass
class Person:
    """Data class holding a person's name, age and e-mail address.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from
    the annotated fields below.
    """

    name: str
    age: int
    email: str
# The generated __init__ accepts the fields in declaration order
person = Person("张三", 25, "zhangsan@example.com")
print(person) # Output: Person(name='张三', age=25, email='zhangsan@example.com')
# The generated __repr__
print(repr(person)) # Output: Person(name='张三', age=25, email='zhangsan@example.com')
# The generated __eq__ compares instances field by field
person2 = Person("张三", 25, "zhangsan@example.com")
print(person == person2) # 输出: True
与传统类的对比
# 传统方式 - 需要手动编写大量代码
class PersonTraditional:
    """Hand-written counterpart of a data class, shown for comparison.

    Stores ``name``, ``age`` and ``email`` and implements ``__repr__`` and
    ``__eq__`` manually.
    """

    def __init__(self, name: str, age: int, email: str):
        self.name = name
        self.age = age
        self.email = email

    def __repr__(self):
        return f"PersonTraditional(name='{self.name}', age={self.age}, email='{self.email}')"

    def __eq__(self, other):
        # Return NotImplemented (not False) for foreign types so Python can
        # try the reflected comparison; the @dataclass-generated __eq__ does
        # the same. ``obj == other`` still evaluates to False when neither
        # side can compare.
        if not isinstance(other, PersonTraditional):
            return NotImplemented
        return (self.name, self.age, self.email) == (
            other.name,
            other.age,
            other.email,
        )
# 数据类方式 - 简洁明了
from dataclasses import dataclass
@dataclass
class PersonDataClass:
    """Data class version of ``PersonTraditional``: same three fields."""

    name: str
    age: int
    email: str
# Equivalent functionality with about 70% less code (the tutorial's estimate)
traditional = PersonTraditional("李四", 30, "lisi@example.com")
dataclass_obj = PersonDataClass("李四", 30, "lisi@example.com")
print(traditional)
print(dataclass_obj)
数据类字段定义
字段类型和默认值
必需字段和可选字段
from dataclasses import dataclass
from typing import Optional
@dataclass
class Student:
    """Student record with two required fields and several defaulted ones."""

    # Required fields - no default value
    name: str
    student_id: str
    # Optional fields - have a default value
    age: int = 18
    grade: str = "大一"
    major: str = "计算机科学"
    gpa: float = 0.0
    # Optional fields - None as the default
    phone: Optional[str] = None
    address: Optional[str] = None
# Different ways to create instances
student1 = Student("王小明", "2023001") # only the required fields
print(student1)
# Output: Student(name='王小明', student_id='2023001', age=18, grade='大一', major='计算机科学', gpa=0.0, phone=None, address=None)
student2 = Student("李小红", "2023002", 20, "大三", "数学", 3.8, "13800138000", "北京市朝阳区")
print(student2)
使用field()函数定制字段
from dataclasses import dataclass, field
from typing import List
import random
@dataclass
class Product:
    """Product record demonstrating ``field()``.

    ``discount_price`` is not an ``__init__`` parameter; ``__post_init__``
    computes it from ``price``. An empty ``sku`` is replaced by a randomly
    generated one.
    """

    name: str
    price: float
    category: str
    # Excluded from __init__; computed in __post_init__
    discount_price: float = field(init=False)
    # default_factory gives every instance its own list
    tags: List[str] = field(default_factory=list)
    # metadata stores extra information about the field
    sku: str = field(default="", metadata={"description": "库存单位编码"})

    def __post_init__(self):
        """Derive ``discount_price`` and fill in a missing ``sku``."""
        # Discounted price: 10% off, rounded to two decimals
        self.discount_price = round(self.price * 0.9, 2)
        # Generate a SKU when none was supplied
        if not self.sku:
            self.sku = f"SKU-{random.randint(100000, 999999)}"
# Create a product instance
product1 = Product("iPhone 15", 5999.0, "手机")
print(product1)
print(f"原价: {product1.price}, 折扣价: {product1.discount_price}")
print(f"SKU: {product1.sku}")
# Add tags
product1.tags.extend(["热销", "新品", "苹果"])
print(f"标签: {product1.tags}")
# Create a second instance to show that each one has its own tags list
product2 = Product("MacBook Pro", 12999.0, "电脑")
product2.tags.append("专业版")
print(f"商品1标签: {product1.tags}")
print(f"商品2标签: {product2.tags}") # 互不影响
字段属性详解
field()函数参数说明
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class Employee:
    """Employee record demonstrating the parameters of ``field()``.

    ``salary`` is not an ``__init__`` parameter; ``__post_init__`` derives it
    from ``level``. The generated ``__init__`` therefore takes, in order:
    ``name, employee_id, department, skills, phone, temp_note, public_id,
    level``.
    """

    name: str
    employee_id: str
    # init: whether the field is a parameter of __init__
    salary: float = field(default=0.0, init=False)  # not in __init__
    # default: the field's default value
    department: str = field(default="未分配", init=True)
    # default_factory: callable that produces the default value
    skills: list = field(default_factory=list)
    # repr: whether the field appears in __repr__
    phone: str = field(default="", repr=False)  # hidden from repr
    # compare: whether the field takes part in comparisons
    temp_note: str = field(default="", compare=False)  # ignored by ==
    # hash: whether the field takes part in a generated __hash__.
    # NOTE: with the default eq=True / frozen=False this class gets
    # __hash__ = None, so hash=True has no visible effect here.
    public_id: str = field(default="", hash=True)
    # metadata: read-only mapping attached to the Field object
    level: str = field(default="初级", metadata={
        "description": "员工级别",
        "levels": ["初级", "中级", "高级", "专家"]
    })

    def __post_init__(self):
        """Set ``salary`` from ``level``; unknown levels fall back to 5000."""
        salary_map = {
            "初级": 5000,
            "中级": 8000,
            "高级": 12000,
            "专家": 20000
        }
        self.salary = salary_map.get(self.level, 5000)
from dataclasses import fields

# Exercise the field() parameters. Keyword arguments are required here:
# the generated __init__ order is (name, employee_id, department, skills,
# phone, temp_note, public_id, level), so passing everything positionally
# would put the phone number into ``skills`` and leave ``level`` at its
# default.
emp = Employee(
    "张工程师",
    "E001",
    department="技术部",
    phone="13800138000",
    temp_note="临时备注",
    public_id="EMP001",
    level="高级",
)
print(emp)  # phone is hidden (repr=False); temp_note is still shown
print(f"薪资: {emp.salary}")  # set by __post_init__
# Metadata lives on the Field object, not on the attribute value (a str)
level_field = next(f for f in fields(emp) if f.name == "level")
print(f"元数据: {dict(level_field.metadata)}")
# Comparison - temp_note is excluded (compare=False)
emp2 = Employee(
    "张工程师",
    "E001",
    department="技术部",
    phone="13800138000",
    temp_note="不同备注",
    public_id="EMP001",
    level="高级",
)
print(f"是否相等: {emp == emp2}") # True - temp_note不影响比较
数据类高级特性
不可变数据类
使用frozen=True创建不可变实例
from dataclasses import dataclass
from typing import Tuple
@dataclass(frozen=True)
class Point3D:
    """Immutable point in three-dimensional space."""

    x: float
    y: float
    z: float

    def distance_from_origin(self) -> float:
        """Return the Euclidean distance from the origin."""
        return (self.x ** 2 + self.y ** 2 + self.z ** 2) ** 0.5

    def move(self, dx: float, dy: float, dz: float) -> 'Point3D':
        """Return a new point offset by (dx, dy, dz); ``self`` is unchanged."""
        return Point3D(self.x + dx, self.y + dy, self.z + dz)
# Create an immutable object
point = Point3D(1.0, 2.0, 3.0)
print(f"原点: {point}")
print(f"到原点距离: {point.distance_from_origin():.2f}")
# Attempting to modify a field raises an error
# point.x = 5.0 # this line would raise FrozenInstanceError
# Create a new instance through the move() method
new_point = point.move(1.0, 1.0, 1.0)
print(f"原位置: {point}")
print(f"新位置: {new_point}")
# Usable as dict keys (frozen=True together with the default eq=True makes instances hashable)
points_dict = {point: "起点", new_point: "终点"}
print(f"字典: {points_dict}")
继承和数据类
数据类的继承
from dataclasses import dataclass
from abc import ABC
@dataclass
class Animal(ABC):
    """Base data class for animals.

    Derives from ``ABC`` but declares no abstract methods, so it can still
    be instantiated directly.
    """

    name: str
    age: int
    species: str
@dataclass
class Dog(Animal):
    """Dog - extends ``Animal``; inherited fields come first in ``__init__``."""

    breed: str  # breed name
    is_trained: bool = False

    def bark(self) -> str:
        """Return the dog's bark message."""
        return f"{self.name} 汪汪叫!"
@dataclass
class Cat(Animal):
    """Cat - extends ``Animal``; inherited fields come first in ``__init__``."""

    fur_color: str  # fur colour
    lives_left: int = 9

    def meow(self) -> str:
        """Return the cat's meow message."""
        return f"{self.name} 喵喵叫!"
# Create instances (base-class fields first, then the subclass fields)
dog = Dog("旺财", 3, "犬科", "金毛寻回犬", True)
cat = Cat("咪咪", 2, "猫科", "橘色", 8)
print(dog)
print(f"行为: {dog.bark()}")
print(cat)
print(f"行为: {cat.meow()}")
# Polymorphism demo: both objects are handled through the fields they share
animals = [dog, cat]
for animal in animals:
    print(f"{animal.name} 是 {animal.species}, {animal.age}岁")
特殊方法和自定义行为
自定义比较和哈希
from dataclasses import dataclass, field
from typing import List
@dataclass(order=True)  # generate __lt__, __le__, __gt__, __ge__
class Score:
    """Score record that sorts from highest to lowest ``points``.

    Ordering compares fields in declaration order, so ``sort_index`` is
    declared first and decides the result before the other fields.
    """

    # Sort key only: not an __init__ parameter and hidden from repr
    sort_index: int = field(init=False, repr=False)
    name: str
    subject: str
    points: float

    def __post_init__(self):
        """Derive ``sort_index`` from ``points``."""
        # Negated so that ascending sort yields descending scores
        self.sort_index = -int(self.points * 100)
@dataclass(eq=False)  # do not generate __eq__; it is defined below
class Team:
    """Team record whose equality and hash depend only on ``name``."""

    name: str
    members: List[str]

    def __eq__(self, other):
        """Two teams are equal when their names are equal."""
        # NotImplemented lets Python try the reflected comparison for
        # foreign types; ``team == other`` still ends up False.
        if not isinstance(other, Team):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        """Hash on ``name`` so that equal teams hash equally."""
        return hash(self.name)
# Exercise sorting
scores = [
    Score("张三", "数学", 95.5),
    Score("李四", "数学", 87.0),
    Score("王五", "数学", 92.3),
    Score("赵六", "数学", 78.9),
]
print("原始分数:")
for score in scores:
    print(f"{score.name}: {score.points}")
# sorted() relies on the generated ordering methods (highest score first)
sorted_scores = sorted(scores)
print("\n排序后:")
for score in sorted_scores:
    print(f"{score.name}: {score.points}")
# Exercise the custom equality
team1 = Team("火箭队", ["姚明", "麦迪"])
team2 = Team("火箭队", ["詹姆斯", "韦德"])  # different members, same name
team3 = Team("湖人队", ["科比", "奥尼尔"])
print(f"\nteam1 == team2: {team1 == team2}")  # True - same team name
print(f"team1 == team3: {team1 == team3}")  # False - different team name
# Use as set elements (requires __hash__)
teams_set = {team1, team2, team3}
print(f"集合中的队伍数量: {len(teams_set)}") # 2个(team1和team2被视为相同)
应用案例
案例1:配置管理系统
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import json
@dataclass
class DatabaseConfig:
    """Database connection settings; every field has a default."""

    host: str = "localhost"
    port: int = 5432
    database: str = "myapp"
    username: str = "admin"
    password: str = field(default="", repr=False)  # kept out of repr
    pool_size: int = 10
    timeout: int = 30
@dataclass
class RedisConfig:
    """Redis connection settings; every field has a default."""

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = None
    max_connections: int = 50
@dataclass
class AppConfig:
    """Application configuration with nested database and Redis settings.

    Can be built from and converted to a plain dict, and saved to or loaded
    from a JSON file.
    """

    app_name: str = "MyApplication"
    debug: bool = False
    version: str = "1.0.0"
    # Nested configuration: default_factory builds a fresh object per instance
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    redis: RedisConfig = field(default_factory=RedisConfig)
    # Dynamic configuration
    allowed_hosts: List[str] = field(default_factory=lambda: ["localhost", "127.0.0.1"])
    custom_settings: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'AppConfig':
        """Build a configuration object from a dict.

        The optional ``'database'`` and ``'redis'`` entries are expanded into
        ``DatabaseConfig`` / ``RedisConfig``; every other key is passed to
        the constructor as a keyword argument.
        """
        # Extract the nested sections (missing sections fall back to defaults)
        db_config = DatabaseConfig(**config_dict.get('database', {}))
        redis_config = RedisConfig(**config_dict.get('redis', {}))
        # Whatever remains is top-level configuration
        top_level_config = {
            key: value
            for key, value in config_dict.items()
            if key not in ('database', 'redis')
        }
        return cls(
            database=db_config,
            redis=redis_config,
            **top_level_config
        )

    def to_dict(self) -> dict:
        """Return the configuration as a plain dict.

        Neither the database password nor the Redis password is included,
        so a ``to_dict``/``from_dict`` round trip resets both to their
        defaults.
        """
        result = {
            'app_name': self.app_name,
            'debug': self.debug,
            'version': self.version,
            'allowed_hosts': self.allowed_hosts,
            'custom_settings': self.custom_settings
        }
        # Nested sections are converted to dicts as well
        result['database'] = {
            'host': self.database.host,
            'port': self.database.port,
            'database': self.database.database,
            'username': self.database.username,
            'pool_size': self.database.pool_size,
            'timeout': self.database.timeout
            # password is deliberately left out
        }
        result['redis'] = {
            'host': self.redis.host,
            'port': self.redis.port,
            'db': self.redis.db,
            'max_connections': self.redis.max_connections
            # password is left out here as well
        }
        return result

    def save_to_file(self, filename: str):
        """Write ``to_dict()`` to ``filename`` as UTF-8 JSON."""
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load_from_file(cls, filename: str) -> 'AppConfig':
        """Read a UTF-8 JSON file and build a configuration via ``from_dict``."""
        with open(filename, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)
# Usage example
config = AppConfig(
    app_name="电商系统",
    debug=True,
    database=DatabaseConfig(
        host="192.168.1.100",
        database="ecommerce",
        username="ecom_user",
        pool_size=20
    ),
    redis=RedisConfig(
        host="192.168.1.101",
        max_connections=100
    ),
    allowed_hosts=["example.com", "api.example.com"],
    custom_settings={"theme": "dark", "language": "zh-CN"}
)
print("=== 配置信息 ===")
print(f"应用名称: {config.app_name}")
print(f"调试模式: {config.debug}")
print(f"数据库主机: {config.database.host}:{config.database.port}")
print(f"Redis最大连接数: {config.redis.max_connections}")
print(f"允许的域名: {config.allowed_hosts}")
# Save to a file
config.save_to_file("app_config.json")
print("\n配置已保存到 app_config.json")
# Load it back from the file
loaded_config = AppConfig.load_from_file("app_config.json")
print(f"\n从文件加载的应用名称: {loaded_config.app_name}")
案例2:电商订单系统
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional
from enum import Enum
class OrderStatus(Enum):
    """Order status; each member's value is its display label."""

    PENDING = "待付款"
    PAID = "已付款"
    SHIPPED = "已发货"
    DELIVERED = "已送达"
    CANCELLED = "已取消"
@dataclass
class ProductItem:
    """One line of an order: a product, its unit price and quantity."""

    product_id: str
    name: str
    price: float
    quantity: int
    image_url: Optional[str] = None

    @property
    def subtotal(self) -> float:
        """Line total: ``price * quantity``."""
        return self.price * self.quantity
@dataclass
class ShippingAddress:
    """Delivery address and recipient contact details."""

    recipient: str
    phone: str
    province: str
    city: str
    district: str
    detail_address: str
    postal_code: str = ""

    def full_address(self) -> str:
        """Return province, city, district and detail address concatenated."""
        return f"{self.province}{self.city}{self.district}{self.detail_address}"
@dataclass
class Order:
    """Order with its items, shipping address, status, timestamps and fees.

    Raises:
        ValueError: from ``__post_init__`` when ``items`` is empty or an
            item's quantity is not greater than 0.
    """

    order_id: str
    customer_id: str
    items: List[ProductItem]
    shipping_address: ShippingAddress
    status: OrderStatus = OrderStatus.PENDING
    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    paid_at: Optional[datetime] = None
    shipped_at: Optional[datetime] = None
    delivered_at: Optional[datetime] = None
    # Fees
    shipping_fee: float = 0.0
    discount_amount: float = 0.0

    def __post_init__(self):
        """Validate the item list."""
        if not self.items:
            raise ValueError("订单必须包含至少一个商品")
        # Every item needs a positive quantity
        for item in self.items:
            if item.quantity <= 0:
                raise ValueError(f"商品 {item.name} 的数量必须大于0")

    @property
    def total_items(self) -> int:
        """Total quantity across all items."""
        return sum(item.quantity for item in self.items)

    @property
    def subtotal(self) -> float:
        """Sum of the item subtotals."""
        return sum(item.subtotal for item in self.items)

    @property
    def total_amount(self) -> float:
        """Amount due: subtotal plus shipping fee minus discount."""
        return self.subtotal + self.shipping_fee - self.discount_amount

    def can_cancel(self) -> bool:
        """Return True while the order is still PENDING or PAID."""
        return self.status in (OrderStatus.PENDING, OrderStatus.PAID)

    def cancel(self) -> bool:
        """Cancel the order if allowed; return whether it was cancelled."""
        if self.can_cancel():
            self.status = OrderStatus.CANCELLED
            return True
        return False

    def pay(self) -> bool:
        """Move PENDING -> PAID and stamp ``paid_at``; return success."""
        if self.status == OrderStatus.PENDING:
            self.status = OrderStatus.PAID
            self.paid_at = datetime.now()
            return True
        return False

    def ship(self) -> bool:
        """Move PAID -> SHIPPED and stamp ``shipped_at``; return success."""
        if self.status == OrderStatus.PAID:
            self.status = OrderStatus.SHIPPED
            self.shipped_at = datetime.now()
            return True
        return False

    def deliver(self) -> bool:
        """Move SHIPPED -> DELIVERED and stamp ``delivered_at``; return success."""
        if self.status == OrderStatus.SHIPPED:
            self.status = OrderStatus.DELIVERED
            self.delivered_at = datetime.now()
            return True
        return False

    def get_processing_days(self) -> Optional[int]:
        """Return whole days between payment and shipping, or None if either is unset."""
        if self.paid_at and self.shipped_at:
            return (self.shipped_at - self.paid_at).days
        return None

    def summary(self) -> str:
        """Return a multi-line, human-readable summary of the order."""
        return (
            f"订单号: {self.order_id}\n"
            f"客户ID: {self.customer_id}\n"
            f"状态: {self.status.value}\n"
            f"商品种类: {len(self.items)}种, 总数量: {self.total_items}件\n"
            f"商品金额: ¥{self.subtotal:.2f}\n"
            f"运费: ¥{self.shipping_fee:.2f}\n"
            f"优惠: -¥{self.discount_amount:.2f}\n"
            f"实付金额: ¥{self.total_amount:.2f}\n"
            f"下单时间: {self.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
        )
# Usage example
# Create the order items
items = [
    ProductItem("P001", "iPhone 15", 5999.0, 1, "iphone.jpg"),
    ProductItem("P002", "AirPods Pro", 1999.0, 1, "airpods.jpg"),
    ProductItem("P003", "Apple Watch", 2999.0, 1, "watch.jpg")
]
# Create the shipping address
address = ShippingAddress(
    recipient="张三",
    phone="13800138000",
    province="北京市",
    city="北京市",
    district="朝阳区",
    detail_address="三里屯街道1号",
    postal_code="100027"
)
# Create the order
order = Order(
    order_id="ORD202312250001",
    customer_id="CUST001",
    items=items,
    shipping_address=address,
    shipping_fee=15.0,
    discount_amount=200.0
)
print("=== 订单创建成功 ===")
print(order.summary())
print(f"\n完整收货地址: {address.full_address()}")
# Walk the order through its lifecycle
print("\n=== 订单处理流程 ===")
print(f"是否可以取消: {order.can_cancel()}")
order.pay()
print(f"付款后状态: {order.status.value}")
order.ship()
print(f"发货后状态: {order.status.value}")
order.deliver()
print(f"送达后状态: {order.status.value}")
processing_days = order.get_processing_days()
if processing_days is not None:
    print(f"处理用时: {processing_days}天")
print(f"\n订单最终摘要:\n{order.summary()}")
常见错误
错误1:混淆可变和不可变默认值
from dataclasses import dataclass, field
from typing import List
# @dataclass rejects list, dict and set defaults: defining the class below
# raises ValueError at class-creation time. Unlike a plain class attribute,
# a data class therefore cannot end up with instances that silently share
# one list - the mistake is reported immediately instead.
problem_error = ""
try:
    @dataclass
    class ProblematicClass:
        """Data class with a mutable default - rejected by @dataclass."""

        tags: List[str] = []  # wrong: mutable default
        # correct: tags: List[str] = field(default_factory=list)
except ValueError as exc:
    problem_error = str(exc)
print(f"定义ProblematicClass失败: {problem_error}")
# 正确的做法
@dataclass
class CorrectClass:
    """Data class whose list field gets a fresh list per instance."""

    tags: List[str] = field(default_factory=list)
# Appending to one instance's list must not affect a later instance
correct1 = CorrectClass()
correct1.tags.append("test")
correct2 = CorrectClass()
print(f"正确实现的另一个实例tags: {correct2.tags}") # 空的[]
错误2:忘记类型注解
from dataclasses import dataclass
@dataclass
class MissingTypes:
    """Shows that only annotated attributes become data class fields."""

    # Wrong: without an annotation the attribute is not a field, so it is
    # not an __init__ parameter and does not appear in repr:
    # name = "default_name"
    # Correct: every field carries a type annotation
    name: str = "default_name"
    age: int = 0
# @dataclass only recognises attributes that carry a type annotation as fields
obj = MissingTypes("test", 25)
print(obj)
错误3:在frozen类中尝试修改字段
from dataclasses import dataclass
@dataclass(frozen=True)
class ImmutablePoint:
    """Immutable 2-D point; assigning to a field raises FrozenInstanceError."""

    x: float
    y: float
point = ImmutablePoint(1.0, 2.0)
# Wrong: fields of a frozen class cannot be reassigned
# point.x = 3.0 # FrozenInstanceError
# Correct: build a new instance instead
new_point = ImmutablePoint(3.0, point.y)
print(f"新点: {new_point}")
课后练习
1. 创建一个Book数据类,包含书名、作者、ISBN、价格、出版日期等字段,并实现图书信息的格式化显示功能。
2. 设计一个UserProfile数据类,管理用户的个人信息、偏好设置和权限列表,支持配置的序列化和反序列化。
3. 实现一个Rectangle数据类,表示矩形,包含计算面积、周长、对角线的方法,以及判断与其他矩形关系的功能。
4. 创建Course数据类来管理课程信息,包括课程名称、教师、学生列表、上课时间等,并实现选课和退课的功能。
5. 设计一个BankAccount数据类,实现账户余额管理、交易记录、冻结/解冻等功能,确保金额的准确性和安全性。
总结
数据类是Python现代编程中的重要特性:
- 大幅减少样板代码,提高开发效率
- 与类型提示完美结合,增强代码可维护性
- 支持灵活的配置选项和高级特性
- 特别适合配置管理、DTO、值对象等场景
- 通过@dataclass装饰器轻松创建功能丰富的数据容器
- 与传统的手写类相比,更加简洁、安全、易读