12.29晚大量修改
This commit is contained in:
		
							parent
							
								
									b12a206d9e
								
							
						
					
					
						commit
						b4a3ad5a6c
					
				
							
								
								
									
										8
									
								
								.idea/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								.idea/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
# 默认忽略的文件
 | 
			
		||||
/shelf/
 | 
			
		||||
/workspace.xml
 | 
			
		||||
# 基于编辑器的 HTTP 客户端请求
 | 
			
		||||
/httpRequests/
 | 
			
		||||
# Datasource local storage ignored files
 | 
			
		||||
/dataSources/
 | 
			
		||||
/dataSources.local.xml
 | 
			
		||||
							
								
								
									
										8
									
								
								.idea/dalunwen11.iml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								.idea/dalunwen11.iml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<module type="PYTHON_MODULE" version="4">
 | 
			
		||||
  <component name="NewModuleRootManager">
 | 
			
		||||
    <content url="file://$MODULE_DIR$" />
 | 
			
		||||
    <orderEntry type="inheritedJdk" />
 | 
			
		||||
    <orderEntry type="sourceFolder" forTests="false" />
 | 
			
		||||
  </component>
 | 
			
		||||
</module>
 | 
			
		||||
							
								
								
									
										28
									
								
								.idea/deployment.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								.idea/deployment.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,28 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
 | 
			
		||||
    <serverData>
 | 
			
		||||
      <paths name="root@connect.beijinga.seetacloud.com:21749 password">
 | 
			
		||||
        <serverdata>
 | 
			
		||||
          <mappings>
 | 
			
		||||
            <mapping local="$PROJECT_DIR$" web="/" />
 | 
			
		||||
          </mappings>
 | 
			
		||||
        </serverdata>
 | 
			
		||||
      </paths>
 | 
			
		||||
      <paths name="root@connect.beijinga.seetacloud.com:21749 password (2)">
 | 
			
		||||
        <serverdata>
 | 
			
		||||
          <mappings>
 | 
			
		||||
            <mapping local="$PROJECT_DIR$" web="/" />
 | 
			
		||||
          </mappings>
 | 
			
		||||
        </serverdata>
 | 
			
		||||
      </paths>
 | 
			
		||||
      <paths name="root@connect.beijinga.seetacloud.com:21749 password (3)">
 | 
			
		||||
        <serverdata>
 | 
			
		||||
          <mappings>
 | 
			
		||||
            <mapping local="$PROJECT_DIR$" web="/" />
 | 
			
		||||
          </mappings>
 | 
			
		||||
        </serverdata>
 | 
			
		||||
      </paths>
 | 
			
		||||
    </serverData>
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										16
									
								
								.idea/inspectionProfiles/Project_Default.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								.idea/inspectionProfiles/Project_Default.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
			
		||||
<component name="InspectionProjectProfileManager">
 | 
			
		||||
  <profile version="1.0">
 | 
			
		||||
    <option name="myName" value="Project Default" />
 | 
			
		||||
    <inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
 | 
			
		||||
      <option name="ourVersions">
 | 
			
		||||
        <value>
 | 
			
		||||
          <list size="3">
 | 
			
		||||
            <item index="0" class="java.lang.String" itemvalue="2.7" />
 | 
			
		||||
            <item index="1" class="java.lang.String" itemvalue="3.12" />
 | 
			
		||||
            <item index="2" class="java.lang.String" itemvalue="3.5" />
 | 
			
		||||
          </list>
 | 
			
		||||
        </value>
 | 
			
		||||
      </option>
 | 
			
		||||
    </inspection_tool>
 | 
			
		||||
  </profile>
 | 
			
		||||
</component>
 | 
			
		||||
							
								
								
									
										6
									
								
								.idea/inspectionProfiles/profiles_settings.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.idea/inspectionProfiles/profiles_settings.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,6 @@
 | 
			
		||||
<component name="InspectionProjectProfileManager">
 | 
			
		||||
  <settings>
 | 
			
		||||
    <option name="USE_PROJECT_PROFILE" value="false" />
 | 
			
		||||
    <version value="1.0" />
 | 
			
		||||
  </settings>
 | 
			
		||||
</component>
 | 
			
		||||
							
								
								
									
										7
									
								
								.idea/misc.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								.idea/misc.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="Black">
 | 
			
		||||
    <option name="sdkName" value="Python 3.9 (project)" />
 | 
			
		||||
  </component>
 | 
			
		||||
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (project)" project-jdk-type="Python SDK" />
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										8
									
								
								.idea/modules.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								.idea/modules.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="ProjectModuleManager">
 | 
			
		||||
    <modules>
 | 
			
		||||
      <module fileurl="file://$PROJECT_DIR$/.idea/dalunwen11.iml" filepath="$PROJECT_DIR$/.idea/dalunwen11.iml" />
 | 
			
		||||
    </modules>
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										6
									
								
								.idea/vcs.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.idea/vcs.xml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,6 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="VcsDirectoryMappings">
 | 
			
		||||
    <mapping directory="" vcs="Git" />
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								GNN_GRU.docx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								GNN_GRU.docx
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										221
									
								
								GNN_GRU.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								GNN_GRU.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,221 @@
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.utils.data import Dataset, DataLoader
 | 
			
		||||
from torch_geometric.nn import GCNConv
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 图神经网络模块
 | 
			
		||||
class GNNModule(nn.Module):
 | 
			
		||||
    def __init__(self, input_dim, hidden_dim, output_dim):
 | 
			
		||||
        super(GNNModule, self).__init__()
 | 
			
		||||
        self.conv1 = GCNConv(input_dim, hidden_dim)
 | 
			
		||||
        self.conv2 = GCNConv(hidden_dim, output_dim)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, edge_index):
 | 
			
		||||
        x = F.relu(self.conv1(x, edge_index))
 | 
			
		||||
        x = F.relu(self.conv2(x, edge_index))
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# GRU时序网络模块
 | 
			
		||||
class GRUNetwork(nn.Module):
 | 
			
		||||
    def __init__(self, input_dim, hidden_dim, num_layers):
 | 
			
		||||
        super(GRUNetwork, self).__init__()
 | 
			
		||||
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
 | 
			
		||||
        self.fc = nn.Linear(hidden_dim, 3)  # 3个行为类别
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        out, _ = self.gru(x)
 | 
			
		||||
        out = self.fc(out[:, -1, :])  # 使用最后时刻的输出进行分类
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 主网络(结合GNN和GRU)
 | 
			
		||||
class GNN_GRU_Network(nn.Module):
 | 
			
		||||
    def __init__(self, input_dim, hidden_dim_gnn, hidden_dim_gru, num_layers_gru, num_keypoints=17):
 | 
			
		||||
        super(GNN_GRU_Network, self).__init__()
 | 
			
		||||
        self.gnn = GNNModule(input_dim=2, hidden_dim=hidden_dim_gnn, output_dim=hidden_dim_gnn)  # 17个关键点的(x, y)
 | 
			
		||||
        self.gru = GRUNetwork(input_dim=num_keypoints * hidden_dim_gnn, hidden_dim=hidden_dim_gru,
 | 
			
		||||
                              num_layers=num_layers_gru)
 | 
			
		||||
 | 
			
		||||
    def forward(self, keypoints, edge_index, seq_length):
 | 
			
		||||
        batch_size = keypoints.shape[0]
 | 
			
		||||
        num_keypoints = keypoints.shape[1]
 | 
			
		||||
 | 
			
		||||
        # 对每个时刻的关键点使用GNN提取空间特征
 | 
			
		||||
        spatial_features = []
 | 
			
		||||
        for t in range(seq_length):
 | 
			
		||||
            x = keypoints[:, t, :].view(batch_size * num_keypoints, 2)
 | 
			
		||||
            spatial_feature = self.gnn(x, edge_index)  # (batch_size * num_keypoints, hidden_dim_gnn)
 | 
			
		||||
            spatial_features.append(spatial_feature)
 | 
			
		||||
 | 
			
		||||
        spatial_features = torch.stack(spatial_features, dim=1)  # (batch_size * num_keypoints, seq_len, hidden_dim_gnn)
 | 
			
		||||
 | 
			
		||||
        gru_input = spatial_features.view(batch_size, seq_length, num_keypoints * hidden_dim_gnn)
 | 
			
		||||
 | 
			
		||||
        # 传递给GRU进行时序建模
 | 
			
		||||
        output = self.gru(gru_input)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 数据集类
 | 
			
		||||
class KeypointDataset(Dataset):
 | 
			
		||||
    def __init__(self, root_dir, behavior_labels, seq_length=5):
 | 
			
		||||
        """
 | 
			
		||||
        :param root_dir: 包含三个行为文件夹的根目录
 | 
			
		||||
        :param behavior_labels: 每个文件夹对应的行为标签
 | 
			
		||||
        :param seq_length: 每个视频的时长(秒),每秒20帧
 | 
			
		||||
        """
 | 
			
		||||
        self.root_dir = root_dir
 | 
			
		||||
        self.behavior_labels = behavior_labels
 | 
			
		||||
        self.seq_length = seq_length  # 每个视频5秒,假设每秒20帧
 | 
			
		||||
 | 
			
		||||
        self.videos = []
 | 
			
		||||
        self.labels = []
 | 
			
		||||
 | 
			
		||||
        # 遍历根目录,获取每个视频路径和标签
 | 
			
		||||
        for idx, behavior in enumerate(behavior_labels):
 | 
			
		||||
            behavior_dir = os.path.join(root_dir, behavior)
 | 
			
		||||
            for video_name in os.listdir(behavior_dir):
 | 
			
		||||
                video_path = os.path.join(behavior_dir, video_name)
 | 
			
		||||
                if os.path.isdir(video_path):
 | 
			
		||||
                    video_files = sorted(os.listdir(video_path))  # 按帧数排序文件
 | 
			
		||||
                    self.videos.append(video_files)
 | 
			
		||||
                    self.labels.append(idx)  # 标签为行为类别索引
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.videos)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        video_files = self.videos[idx]
 | 
			
		||||
        label = self.labels[idx]
 | 
			
		||||
 | 
			
		||||
        # 每个视频加载时,读取关键点数据
 | 
			
		||||
        keypoints = []
 | 
			
		||||
        for frame_file in video_files:
 | 
			
		||||
            frame_path = os.path.join(self.root_dir, frame_file)
 | 
			
		||||
            keypoints.append(self.load_keypoints(frame_path))
 | 
			
		||||
 | 
			
		||||
        keypoints = np.array(keypoints)  # shape: (num_frames, num_persons, 34) -> (num_frames, num_persons, 17*2)
 | 
			
		||||
        return torch.tensor(keypoints, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
 | 
			
		||||
 | 
			
		||||
    def load_keypoints(self, frame_path):
 | 
			
		||||
        """
 | 
			
		||||
        读取每一帧的关键点数据,每帧数据包含每个人的17个关键点
 | 
			
		||||
        :param frame_path: 每帧的txt文件路径
 | 
			
		||||
        :return: 该帧的所有人物的17个关键点坐标
 | 
			
		||||
        """
 | 
			
		||||
        keypoints = []
 | 
			
		||||
        with open(frame_path, 'r') as f:
 | 
			
		||||
            for line in f.readlines():
 | 
			
		||||
                data = list(map(float, line.strip().split()))
 | 
			
		||||
                person_id = int(data[0])
 | 
			
		||||
                frame_id = int(data[1])
 | 
			
		||||
                coordinates = data[2:]  # 后面的34个数值是17个关键点的(x, y)坐标
 | 
			
		||||
                keypoints.append(coordinates)
 | 
			
		||||
 | 
			
		||||
        return np.array(keypoints)  # shape: (num_persons, 17*2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 训练过程
 | 
			
		||||
def train(model, train_loader, criterion, optimizer, device):
 | 
			
		||||
    model.train()
 | 
			
		||||
    total_loss = 0
 | 
			
		||||
    for data, label in train_loader:
 | 
			
		||||
        data = data.to(device)
 | 
			
		||||
        label = label.to(device)
 | 
			
		||||
 | 
			
		||||
        optimizer.zero_grad()
 | 
			
		||||
        output = model(data)  # 前向传播
 | 
			
		||||
        loss = criterion(output, label)  # 计算损失
 | 
			
		||||
        loss.backward()  # 反向传播
 | 
			
		||||
        optimizer.step()  # 更新参数
 | 
			
		||||
 | 
			
		||||
        total_loss += loss.item()
 | 
			
		||||
 | 
			
		||||
    return total_loss / len(train_loader)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 验证过程
 | 
			
		||||
def evaluate(model, val_loader, criterion, device):
 | 
			
		||||
    model.eval()
 | 
			
		||||
    total_loss = 0
 | 
			
		||||
    correct = 0
 | 
			
		||||
    total = 0
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        for data, label in val_loader:
 | 
			
		||||
            data = data.to(device)
 | 
			
		||||
            label = label.to(device)
 | 
			
		||||
 | 
			
		||||
            output = model(data)
 | 
			
		||||
            loss = criterion(output, label)
 | 
			
		||||
            total_loss += loss.item()
 | 
			
		||||
 | 
			
		||||
            _, predicted = torch.max(output, 1)
 | 
			
		||||
            correct += (predicted == label).sum().item()
 | 
			
		||||
            total += label.size(0)
 | 
			
		||||
 | 
			
		||||
    accuracy = correct / total
 | 
			
		||||
    return total_loss / len(val_loader), accuracy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 主函数
 | 
			
		||||
def main():
 | 
			
		||||
    root_dir = "/path/to/your/data"  # 数据集根目录
 | 
			
		||||
    behavior_labels = ["behavior1", "behavior2", "behavior3"]
 | 
			
		||||
 | 
			
		||||
    # 创建数据集与加载器
 | 
			
		||||
    train_dataset = KeypointDataset(root_dir, behavior_labels)
 | 
			
		||||
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
 | 
			
		||||
 | 
			
		||||
    # 模型设置
 | 
			
		||||
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 | 
			
		||||
    model = GNN_GRU_Network(input_dim=2, hidden_dim_gnn=64, hidden_dim_gru=128, num_layers_gru=2)
 | 
			
		||||
    model.to(device)
 | 
			
		||||
 | 
			
		||||
    # 损失函数与优化器
 | 
			
		||||
    criterion = nn.CrossEntropyLoss()
 | 
			
		||||
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 | 
			
		||||
 | 
			
		||||
    # 训练与验证
 | 
			
		||||
    for epoch in range(1, 101):
 | 
			
		||||
        train_loss = train(model, train_loader, criterion, optimizer, device)
 | 
			
		||||
        print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}")
 | 
			
		||||
 | 
			
		||||
        # 可选:每个epoch结束时进行验证
 | 
			
		||||
        # val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
 | 
			
		||||
        # print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
'''
 | 
			
		||||
代码说明:
 | 
			
		||||
 | 
			
		||||
数据集处理(KeypointDataset):
 | 
			
		||||
假设你的数据集文件夹结构如下:
 | 
			
		||||
data /
 | 
			
		||||
├── behavior1 /
 | 
			
		||||
│    ├── video1 /
 | 
			
		||||
│    │    ├── frame1.txt
 | 
			
		||||
│    │    ├── frame2.txt
 | 
			
		||||
│    │    └── ...
 | 
			
		||||
│    ├── video2 /
 | 
			
		||||
│    └── ...
 | 
			
		||||
├── behavior2 /
 | 
			
		||||
└── behavior3 /
 | 
			
		||||
每个视频文件夹包含若干个.txt
 | 
			
		||||
文件,每个文件代表一帧,记录了每个人的关键点位置(17
 | 
			
		||||
个关键点,每个关键点的(x, y)
 | 
			
		||||
坐标)。
 | 
			
		||||
模型(GNN_GRU_Network):
 | 
			
		||||
使用
 | 
			
		||||
GNN
 | 
			
		||||
对每一帧的关键点进行空间特征提取。
 | 
			
		||||
然后将提取的空间特征传入
 | 
			
		||||
GRU
 | 
			
		||||
网络进行时序建模,最后通过
 | 
			
		||||
'''
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								dlo7qcz3q.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								dlo7qcz3q.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 67 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								lls开题报告.docx
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								lls开题报告.docx
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										0
									
								
								新建 文本文档.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								新建 文本文档.txt
									
									
									
									
									
										Normal file
									
								
							
		Loading…
	
		Reference in New Issue
	
	Block a user