# dalunwen/GNN_GRU.py

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv
import numpy as np

# Graph neural network module: two GCN layers extract spatial features from keypoints
class GNNModule(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNModule, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return x
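
# Illustrative usage of GNNModule (a minimal sketch, not part of the training
# pipeline; the 3-node chain graph below is a toy stand-in for a real skeleton):
#
#   gnn = GNNModule(input_dim=2, hidden_dim=64, output_dim=64)
#   x = torch.rand(3, 2)                             # 3 nodes with (x, y) features
#   edge_index = torch.tensor([[0, 1], [1, 2]]).t()  # chain 0-1-2, shape (2, 2)
#   out = gnn(x, edge_index)                         # shape: (3, 64)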

# GRU temporal network module: classifies a sequence of per-frame feature vectors
class GRUNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(GRUNetwork, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 3)  # 3 behavior classes

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.fc(out[:, -1, :])  # classify from the last time step's output
        return out
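
# Illustrative usage of GRUNetwork (a minimal sketch; 1088 = 17 keypoints * 64
# GNN channels, matching the dimensions used in main() below):
#
#   gru_net = GRUNetwork(input_dim=1088, hidden_dim=128, num_layers=2)
#   seq = torch.rand(4, 5, 1088)  # (batch, seq_len, feature_dim)
#   logits = gru_net(seq)         # shape: (4, 3), one logit per behavior class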

# Main network combining the GNN (spatial) and the GRU (temporal)
class GNN_GRU_Network(nn.Module):
    def __init__(self, input_dim, hidden_dim_gnn, hidden_dim_gru, num_layers_gru, num_keypoints=17):
        super(GNN_GRU_Network, self).__init__()
        self.hidden_dim_gnn = hidden_dim_gnn
        self.num_keypoints = num_keypoints
        # input_dim is the per-keypoint feature size: the (x, y) coordinates
        self.gnn = GNNModule(input_dim=input_dim, hidden_dim=hidden_dim_gnn, output_dim=hidden_dim_gnn)
        self.gru = GRUNetwork(input_dim=num_keypoints * hidden_dim_gnn, hidden_dim=hidden_dim_gru,
                              num_layers=num_layers_gru)

    def forward(self, keypoints, edge_index, seq_length):
        # keypoints: (batch_size, seq_length, num_keypoints * 2)
        batch_size = keypoints.shape[0]
        num_keypoints = self.num_keypoints
        # Replicate the skeleton edges for every sample in the batch, offsetting
        # node indices so each sample forms an independent graph
        num_edges = edge_index.size(1)
        offsets = (torch.arange(batch_size, device=keypoints.device) * num_keypoints).repeat_interleave(num_edges)
        batched_edge_index = edge_index.repeat(1, batch_size) + offsets
        # Extract spatial features from each frame's keypoints with the GNN
        spatial_features = []
        for t in range(seq_length):
            x = keypoints[:, t, :].reshape(batch_size * num_keypoints, 2)
            spatial_feature = self.gnn(x, batched_edge_index)  # (batch_size * num_keypoints, hidden_dim_gnn)
            # Regroup per sample so each frame becomes one flat feature vector
            spatial_features.append(spatial_feature.view(batch_size, num_keypoints * self.hidden_dim_gnn))
        gru_input = torch.stack(spatial_features, dim=1)  # (batch_size, seq_length, num_keypoints * hidden_dim_gnn)
        # Temporal modeling with the GRU
        output = self.gru(gru_input)
        return output
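
# Minimal end-to-end shape check (illustrative sketch; the edges here form a
# hypothetical chain 0-1-...-16 rather than a real pose skeleton):
#
#   model = GNN_GRU_Network(input_dim=2, hidden_dim_gnn=64, hidden_dim_gru=128, num_layers_gru=2)
#   keypoints = torch.rand(4, 5, 17 * 2)  # (batch, seq_len, num_keypoints * 2)
#   edge_index = torch.tensor([[i, i + 1] for i in range(16)]).t()
#   out = model(keypoints, edge_index, seq_length=5)  # shape: (4, 3)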

# Dataset class
class KeypointDataset(Dataset):
    def __init__(self, root_dir, behavior_labels, seq_length=5):
        """
        :param root_dir: root directory containing one folder per behavior
        :param behavior_labels: behavior label (folder name) for each class
        :param seq_length: length of each video in seconds (20 frames per second)
        """
        self.root_dir = root_dir
        self.behavior_labels = behavior_labels
        self.seq_length = seq_length  # each video is 5 seconds, assuming 20 fps
        self.videos = []
        self.labels = []
        # Walk the root directory and collect every video's frame paths and label
        for idx, behavior in enumerate(behavior_labels):
            behavior_dir = os.path.join(root_dir, behavior)
            for video_name in os.listdir(behavior_dir):
                video_path = os.path.join(behavior_dir, video_name)
                if os.path.isdir(video_path):
                    # Sort the frame files by name and keep their full paths
                    video_files = sorted(os.listdir(video_path))
                    self.videos.append([os.path.join(video_path, f) for f in video_files])
                    self.labels.append(idx)  # label is the behavior class index

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video_files = self.videos[idx]
        label = self.labels[idx]
        # Load the keypoint data for every frame of the video
        keypoints = []
        for frame_path in video_files:
            keypoints.append(self.load_keypoints(frame_path))
        keypoints = np.array(keypoints)  # shape: (num_frames, num_persons, 34), i.e. 17 * (x, y)
        return torch.tensor(keypoints, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
    def load_keypoints(self, frame_path):
        """
        Read the keypoint data of one frame; each frame records 17 keypoints per person.
        :param frame_path: path to the frame's txt file
        :return: the 17 keypoint coordinates of every person in the frame
        """
        keypoints = []
        with open(frame_path, 'r') as f:
            for line in f.readlines():
                data = list(map(float, line.strip().split()))
                person_id = int(data[0])
                frame_id = int(data[1])
                coordinates = data[2:]  # the remaining 34 values are the 17 (x, y) keypoint coordinates
                keypoints.append(coordinates)
        return np.array(keypoints)  # shape: (num_persons, 17*2)
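
    # Expected line format in each frame txt file, inferred from the parsing
    # above (the sample values below are made up for illustration):
    #   <person_id> <frame_id> <x1> <y1> <x2> <y2> ... <x17> <y17>
    #   e.g. "0 12 153.2 88.1 150.7 85.4 ... 201.9 310.5"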

# Training loop
def train(model, train_loader, criterion, optimizer, device, edge_index):
    model.train()
    total_loss = 0
    edge_index = edge_index.to(device)
    for data, label in train_loader:
        # data is expected as (batch, seq_len, num_keypoints * 2)
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(data, edge_index, data.shape[1])  # forward pass
        loss = criterion(output, label)  # compute loss
        loss.backward()  # backpropagate
        optimizer.step()  # update parameters
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Validation loop
def evaluate(model, val_loader, criterion, device, edge_index):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    edge_index = edge_index.to(device)
    with torch.no_grad():
        for data, label in val_loader:
            data = data.to(device)
            label = label.to(device)
            output = model(data, edge_index, data.shape[1])
            loss = criterion(output, label)
            total_loss += loss.item()
            _, predicted = torch.max(output, 1)
            correct += (predicted == label).sum().item()
            total += label.size(0)
    accuracy = correct / total
    return total_loss / len(val_loader), accuracy

# Main function
def main():
    root_dir = "/path/to/your/data"  # dataset root directory
    behavior_labels = ["behavior1", "behavior2", "behavior3"]
    # Build the dataset and loader
    train_dataset = KeypointDataset(root_dir, behavior_labels)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    # Skeleton connectivity for the 17 keypoints (assumes the COCO keypoint ordering)
    skeleton = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11), (6, 12),
                (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2), (0, 1), (0, 2),
                (1, 3), (2, 4), (3, 5), (4, 6)]
    edge_index = torch.tensor(skeleton, dtype=torch.long).t().contiguous()
    # Model setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GNN_GRU_Network(input_dim=2, hidden_dim_gnn=64, hidden_dim_gru=128, num_layers_gru=2)
    model.to(device)
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Training (and optional validation)
    for epoch in range(1, 101):
        train_loss = train(model, train_loader, criterion, optimizer, device, edge_index)
        print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}")
        # Optionally validate at the end of each epoch:
        # val_loss, val_accuracy = evaluate(model, val_loader, criterion, device, edge_index)
        # print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

if __name__ == '__main__':
    main()

'''
Code notes:

Dataset handling (KeypointDataset)
Assumed dataset folder structure:
data/
├── behavior1/
│   ├── video1/
│   │   ├── frame1.txt
│   │   ├── frame2.txt
│   │   └── ...
│   ├── video2/
│   └── ...
├── behavior2/
└── behavior3/
Each video folder contains a number of .txt files; each file represents one frame
and records the keypoint positions of every person (17 keypoints, each with an
(x, y) coordinate).

Model (GNN_GRU_Network)
A GNN extracts spatial features from the keypoints of each frame.
The extracted spatial features are then fed into a GRU for temporal modeling,
and a fully connected layer finally outputs the behavior class.
'''