python使用dataset快速使用SQLite
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import Dataset
import sqlite3
class SQLiteDataset(Dataset):
def __init__(self, database, query):
self.database = database
self.query = query
self.conn = sqlite3.connect(self.database, isolation_level=None)
self.cur = self.conn.cursor()
self.cur.execute(self.query)
self.data = self.cur.fetchall()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def __iter__(self):
return IterableWrapper(self.data)
def __del__(self):
self.conn.close()
# 使用示例
database_path = 'path_to_your_sqlite_database.db'
query = 'SELECT * FROM your_table_name'
dataset = SQLiteDataset(database_path, query)
# 现在可以像使用其他PyTorch Dataset一样使用 `dataset`
这个示例代码定义了一个名为SQLiteDataset
的类,它允许用户使用SQL查询从SQLite数据库中创建一个可迭代的数据集。这个类实现了PyTorch Dataset
的基本方法,包括初始化连接数据库、执行查询、获取长度和数据项。在实例化SQLiteDataset
时,只需传入数据库路径和要执行的SQL查询字符串。这个类在实例化后可以像其他PyTorch数据集一样使用,例如用于模型的数据提供。
评论已关闭