import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv
import numpy as np


# 图神经网络模块
class GNNModule(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNModule, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return x


# GRU时序网络模块
class GRUNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(GRUNetwork, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 3)  # 3个行为类别

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.fc(out[:, -1, :])  # 使用最后时刻的输出进行分类
        return out


# 主网络(结合GNN和GRU)
class GNN_GRU_Network(nn.Module):
    def __init__(self, input_dim, hidden_dim_gnn, hidden_dim_gru, num_layers_gru, num_keypoints=17):
        super(GNN_GRU_Network, self).__init__()
        self.gnn = GNNModule(input_dim=2, hidden_dim=hidden_dim_gnn, output_dim=hidden_dim_gnn)  # 17个关键点的(x, y)
        self.gru = GRUNetwork(input_dim=num_keypoints * hidden_dim_gnn, hidden_dim=hidden_dim_gru,
                              num_layers=num_layers_gru)

    def forward(self, keypoints, edge_index, seq_length):
        batch_size = keypoints.shape[0]
        num_keypoints = keypoints.shape[1]

        # 对每个时刻的关键点使用GNN提取空间特征
        spatial_features = []
        for t in range(seq_length):
            x = keypoints[:, t, :].view(batch_size * num_keypoints, 2)
            spatial_feature = self.gnn(x, edge_index)  # (batch_size * num_keypoints, hidden_dim_gnn)
            spatial_features.append(spatial_feature)

        spatial_features = torch.stack(spatial_features, dim=1)  # (batch_size * num_keypoints, seq_len, hidden_dim_gnn)

        gru_input = spatial_features.view(batch_size, seq_length, num_keypoints * hidden_dim_gnn)

        # 传递给GRU进行时序建模
        output = self.gru(gru_input)
        return output


# 数据集类
class KeypointDataset(Dataset):
    def __init__(self, root_dir, behavior_labels, seq_length=5):
        """
        :param root_dir: 包含三个行为文件夹的根目录
        :param behavior_labels: 每个文件夹对应的行为标签
        :param seq_length: 每个视频的时长(秒),每秒20帧
        """
        self.root_dir = root_dir
        self.behavior_labels = behavior_labels
        self.seq_length = seq_length  # 每个视频5秒,假设每秒20帧

        self.videos = []
        self.labels = []

        # 遍历根目录,获取每个视频路径和标签
        for idx, behavior in enumerate(behavior_labels):
            behavior_dir = os.path.join(root_dir, behavior)
            for video_name in os.listdir(behavior_dir):
                video_path = os.path.join(behavior_dir, video_name)
                if os.path.isdir(video_path):
                    video_files = sorted(os.listdir(video_path))  # 按帧数排序文件
                    self.videos.append(video_files)
                    self.labels.append(idx)  # 标签为行为类别索引

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video_files = self.videos[idx]
        label = self.labels[idx]

        # 每个视频加载时,读取关键点数据
        keypoints = []
        for frame_file in video_files:
            frame_path = os.path.join(self.root_dir, frame_file)
            keypoints.append(self.load_keypoints(frame_path))

        keypoints = np.array(keypoints)  # shape: (num_frames, num_persons, 34) -> (num_frames, num_persons, 17*2)
        return torch.tensor(keypoints, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

    def load_keypoints(self, frame_path):
        """
        读取每一帧的关键点数据,每帧数据包含每个人的17个关键点
        :param frame_path: 每帧的txt文件路径
        :return: 该帧的所有人物的17个关键点坐标
        """
        keypoints = []
        with open(frame_path, 'r') as f:
            for line in f.readlines():
                data = list(map(float, line.strip().split()))
                person_id = int(data[0])
                frame_id = int(data[1])
                coordinates = data[2:]  # 后面的34个数值是17个关键点的(x, y)坐标
                keypoints.append(coordinates)

        return np.array(keypoints)  # shape: (num_persons, 17*2)


# 训练过程
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(data)  # 前向传播
        loss = criterion(output, label)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        total_loss += loss.item()

    return total_loss / len(train_loader)


# 验证过程
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, label in val_loader:
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)
            total_loss += loss.item()

            _, predicted = torch.max(output, 1)
            correct += (predicted == label).sum().item()
            total += label.size(0)

    accuracy = correct / total
    return total_loss / len(val_loader), accuracy


# 主函数
def main():
    root_dir = "/path/to/your/data"  # 数据集根目录
    behavior_labels = ["behavior1", "behavior2", "behavior3"]

    # 创建数据集与加载器
    train_dataset = KeypointDataset(root_dir, behavior_labels)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    # 模型设置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GNN_GRU_Network(input_dim=2, hidden_dim_gnn=64, hidden_dim_gru=128, num_layers_gru=2)
    model.to(device)

    # 损失函数与优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # 训练与验证
    for epoch in range(1, 101):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}")

        # 可选:每个epoch结束时进行验证
        # val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
        # print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


if __name__ == '__main__':
    main()
'''
代码说明:

数据集处理(KeypointDataset):
假设你的数据集文件夹结构如下:
data /
├── behavior1 /
│    ├── video1 /
│    │    ├── frame1.txt
│    │    ├── frame2.txt
│    │    └── ...
│    ├── video2 /
│    └── ...
├── behavior2 /
└── behavior3 /
每个视频文件夹包含若干个.txt
文件,每个文件代表一帧,记录了每个人的关键点位置(17
个关键点,每个关键点的(x, y)
坐标)。
模型(GNN_GRU_Network):
使用
GNN
对每一帧的关键点进行空间特征提取。
然后将提取的空间特征传入
GRU
网络进行时序建模,最后通过
'''