12.29晚大量修改
This commit is contained in:
parent
b12a206d9e
commit
b4a3ad5a6c
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
8
.idea/dalunwen11.iml
Normal file
8
.idea/dalunwen11.iml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
28
.idea/deployment.xml
Normal file
28
.idea/deployment.xml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
|
<serverData>
|
||||||
|
<paths name="root@connect.beijinga.seetacloud.com:21749 password">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="root@connect.beijinga.seetacloud.com:21749 password (2)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="root@connect.beijinga.seetacloud.com:21749 password (3)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
</serverData>
|
||||||
|
</component>
|
||||||
|
</project>
|
16
.idea/inspectionProfiles/Project_Default.xml
Normal file
16
.idea/inspectionProfiles/Project_Default.xml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ourVersions">
|
||||||
|
<value>
|
||||||
|
<list size="3">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="2.7" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="3.12" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="3.5" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
7
.idea/misc.xml
Normal file
7
.idea/misc.xml
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.9 (project)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (project)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/dalunwen11.iml" filepath="$PROJECT_DIR$/.idea/dalunwen11.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
BIN
GNN_GRU.docx
Normal file
BIN
GNN_GRU.docx
Normal file
Binary file not shown.
221
GNN_GRU.py
Normal file
221
GNN_GRU.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch_geometric.nn import GCNConv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# 图神经网络模块
|
||||||
|
class GNNModule(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim, output_dim):
|
||||||
|
super(GNNModule, self).__init__()
|
||||||
|
self.conv1 = GCNConv(input_dim, hidden_dim)
|
||||||
|
self.conv2 = GCNConv(hidden_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(self, x, edge_index):
|
||||||
|
x = F.relu(self.conv1(x, edge_index))
|
||||||
|
x = F.relu(self.conv2(x, edge_index))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# GRU时序网络模块
|
||||||
|
class GRUNetwork(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim, num_layers):
|
||||||
|
super(GRUNetwork, self).__init__()
|
||||||
|
self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
|
||||||
|
self.fc = nn.Linear(hidden_dim, 3) # 3个行为类别
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out, _ = self.gru(x)
|
||||||
|
out = self.fc(out[:, -1, :]) # 使用最后时刻的输出进行分类
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# 主网络(结合GNN和GRU)
|
||||||
|
class GNN_GRU_Network(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim_gnn, hidden_dim_gru, num_layers_gru, num_keypoints=17):
|
||||||
|
super(GNN_GRU_Network, self).__init__()
|
||||||
|
self.gnn = GNNModule(input_dim=2, hidden_dim=hidden_dim_gnn, output_dim=hidden_dim_gnn) # 17个关键点的(x, y)
|
||||||
|
self.gru = GRUNetwork(input_dim=num_keypoints * hidden_dim_gnn, hidden_dim=hidden_dim_gru,
|
||||||
|
num_layers=num_layers_gru)
|
||||||
|
|
||||||
|
def forward(self, keypoints, edge_index, seq_length):
|
||||||
|
batch_size = keypoints.shape[0]
|
||||||
|
num_keypoints = keypoints.shape[1]
|
||||||
|
|
||||||
|
# 对每个时刻的关键点使用GNN提取空间特征
|
||||||
|
spatial_features = []
|
||||||
|
for t in range(seq_length):
|
||||||
|
x = keypoints[:, t, :].view(batch_size * num_keypoints, 2)
|
||||||
|
spatial_feature = self.gnn(x, edge_index) # (batch_size * num_keypoints, hidden_dim_gnn)
|
||||||
|
spatial_features.append(spatial_feature)
|
||||||
|
|
||||||
|
spatial_features = torch.stack(spatial_features, dim=1) # (batch_size * num_keypoints, seq_len, hidden_dim_gnn)
|
||||||
|
|
||||||
|
gru_input = spatial_features.view(batch_size, seq_length, num_keypoints * hidden_dim_gnn)
|
||||||
|
|
||||||
|
# 传递给GRU进行时序建模
|
||||||
|
output = self.gru(gru_input)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# 数据集类
|
||||||
|
class KeypointDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, behavior_labels, seq_length=5):
|
||||||
|
"""
|
||||||
|
:param root_dir: 包含三个行为文件夹的根目录
|
||||||
|
:param behavior_labels: 每个文件夹对应的行为标签
|
||||||
|
:param seq_length: 每个视频的时长(秒),每秒20帧
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.behavior_labels = behavior_labels
|
||||||
|
self.seq_length = seq_length # 每个视频5秒,假设每秒20帧
|
||||||
|
|
||||||
|
self.videos = []
|
||||||
|
self.labels = []
|
||||||
|
|
||||||
|
# 遍历根目录,获取每个视频路径和标签
|
||||||
|
for idx, behavior in enumerate(behavior_labels):
|
||||||
|
behavior_dir = os.path.join(root_dir, behavior)
|
||||||
|
for video_name in os.listdir(behavior_dir):
|
||||||
|
video_path = os.path.join(behavior_dir, video_name)
|
||||||
|
if os.path.isdir(video_path):
|
||||||
|
video_files = sorted(os.listdir(video_path)) # 按帧数排序文件
|
||||||
|
self.videos.append(video_files)
|
||||||
|
self.labels.append(idx) # 标签为行为类别索引
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.videos)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
video_files = self.videos[idx]
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
# 每个视频加载时,读取关键点数据
|
||||||
|
keypoints = []
|
||||||
|
for frame_file in video_files:
|
||||||
|
frame_path = os.path.join(self.root_dir, frame_file)
|
||||||
|
keypoints.append(self.load_keypoints(frame_path))
|
||||||
|
|
||||||
|
keypoints = np.array(keypoints) # shape: (num_frames, num_persons, 34) -> (num_frames, num_persons, 17*2)
|
||||||
|
return torch.tensor(keypoints, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
|
||||||
|
|
||||||
|
def load_keypoints(self, frame_path):
|
||||||
|
"""
|
||||||
|
读取每一帧的关键点数据,每帧数据包含每个人的17个关键点
|
||||||
|
:param frame_path: 每帧的txt文件路径
|
||||||
|
:return: 该帧的所有人物的17个关键点坐标
|
||||||
|
"""
|
||||||
|
keypoints = []
|
||||||
|
with open(frame_path, 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
data = list(map(float, line.strip().split()))
|
||||||
|
person_id = int(data[0])
|
||||||
|
frame_id = int(data[1])
|
||||||
|
coordinates = data[2:] # 后面的34个数值是17个关键点的(x, y)坐标
|
||||||
|
keypoints.append(coordinates)
|
||||||
|
|
||||||
|
return np.array(keypoints) # shape: (num_persons, 17*2)
|
||||||
|
|
||||||
|
|
||||||
|
# 训练过程
|
||||||
|
def train(model, train_loader, criterion, optimizer, device):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0
|
||||||
|
for data, label in train_loader:
|
||||||
|
data = data.to(device)
|
||||||
|
label = label.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(data) # 前向传播
|
||||||
|
loss = criterion(output, label) # 计算损失
|
||||||
|
loss.backward() # 反向传播
|
||||||
|
optimizer.step() # 更新参数
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
return total_loss / len(train_loader)
|
||||||
|
|
||||||
|
|
||||||
|
# 验证过程
|
||||||
|
def evaluate(model, val_loader, criterion, device):
|
||||||
|
model.eval()
|
||||||
|
total_loss = 0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, label in val_loader:
|
||||||
|
data = data.to(device)
|
||||||
|
label = label.to(device)
|
||||||
|
|
||||||
|
output = model(data)
|
||||||
|
loss = criterion(output, label)
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
_, predicted = torch.max(output, 1)
|
||||||
|
correct += (predicted == label).sum().item()
|
||||||
|
total += label.size(0)
|
||||||
|
|
||||||
|
accuracy = correct / total
|
||||||
|
return total_loss / len(val_loader), accuracy
|
||||||
|
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
def main():
|
||||||
|
root_dir = "/path/to/your/data" # 数据集根目录
|
||||||
|
behavior_labels = ["behavior1", "behavior2", "behavior3"]
|
||||||
|
|
||||||
|
# 创建数据集与加载器
|
||||||
|
train_dataset = KeypointDataset(root_dir, behavior_labels)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
|
||||||
|
|
||||||
|
# 模型设置
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = GNN_GRU_Network(input_dim=2, hidden_dim_gnn=64, hidden_dim_gru=128, num_layers_gru=2)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# 损失函数与优化器
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# 训练与验证
|
||||||
|
for epoch in range(1, 101):
|
||||||
|
train_loss = train(model, train_loader, criterion, optimizer, device)
|
||||||
|
print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}")
|
||||||
|
|
||||||
|
# 可选:每个epoch结束时进行验证
|
||||||
|
# val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
|
||||||
|
# print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
'''
|
||||||
|
代码说明:
|
||||||
|
|
||||||
|
数据集处理(KeypointDataset):
|
||||||
|
假设你的数据集文件夹结构如下:
|
||||||
|
data /
|
||||||
|
├── behavior1 /
|
||||||
|
│ ├── video1 /
|
||||||
|
│ │ ├── frame1.txt
|
||||||
|
│ │ ├── frame2.txt
|
||||||
|
│ │ └── ...
|
||||||
|
│ ├── video2 /
|
||||||
|
│ └── ...
|
||||||
|
├── behavior2 /
|
||||||
|
└── behavior3 /
|
||||||
|
每个视频文件夹包含若干个.txt
|
||||||
|
文件,每个文件代表一帧,记录了每个人的关键点位置(17
|
||||||
|
个关键点,每个关键点的(x, y)
|
||||||
|
坐标)。
|
||||||
|
模型(GNN_GRU_Network):
|
||||||
|
使用
|
||||||
|
GNN
|
||||||
|
对每一帧的关键点进行空间特征提取。
|
||||||
|
然后将提取的空间特征传入
|
||||||
|
GRU
|
||||||
|
网络进行时序建模,最后通过
|
||||||
|
'''
|
BIN
dlo7qcz3q.jpg
Normal file
BIN
dlo7qcz3q.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 67 KiB |
BIN
lls开题报告.docx
BIN
lls开题报告.docx
Binary file not shown.
0
新建 文本文档.txt
Normal file
0
新建 文本文档.txt
Normal file
Loading…
Reference in New Issue
Block a user