Files
fish_monitor/fish_tracker.py
T
panda 1a135cdda7 v1.1
支持本地视频测试
2025-12-23 17:59:18 +08:00

87 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# fish_tracker.py
import cv2
import numpy as np
from collections import defaultdict, deque
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
class FishTracker:
    """Detect fish with YOLO, track them with DeepSORT and flag suspicious behaviour.

    A confirmed track raises an alert when its mean speed over the recent
    trajectory window is below ``slow_speed_thresh`` *and* its centre lies in
    the bottom band of the frame (below ``bottom_ratio`` of the height).
    """

    def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5, fish_class_id=17,
                 slow_speed_thresh=15, bottom_ratio=0.8, trajectory_len=30):
        """Load the detector and create the tracker.

        Args:
            model_path: Weights path/name passed to ``YOLO``.
            conf_thresh: Detections whose confidence is <= this value are dropped.
            fish_class_id: Detector class index treated as "fish". NOTE: in the
                stock COCO-trained ``yolov8n.pt`` index 17 is "horse" and COCO
                has no fish class, so a fish-trained model (and its fish class
                index) is required for meaningful results.
            slow_speed_thresh: Mean speed (pixels/second) below which a fish
                counts as slow.
            bottom_ratio: A fish whose centre y exceeds
                ``frame_height * bottom_ratio`` counts as lying on the bottom.
            trajectory_len: Number of recent centre samples kept per track.
        """
        self.model = YOLO(model_path)
        self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
        self.conf_thresh = conf_thresh
        self.fish_class_id = fish_class_id
        self.slow_speed_thresh = slow_speed_thresh
        self.bottom_ratio = bottom_ratio
        # track_id -> deque of (cx, cy, ticks) for the most recent samples.
        self.trajectories = defaultdict(lambda: deque(maxlen=trajectory_len))
        # Alerts produced by the most recent process_frame() call.
        self.alerts = []

    def _detect(self, frame):
        """Run YOLO on *frame* and return DeepSORT-style detections of the fish class."""
        results = self.model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            conf = float(box.conf.item())
            if int(box.cls.item()) != self.fish_class_id or conf <= self.conf_thresh:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # DeepSORT input format: ([left, top, width, height], confidence, class).
            detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
        return detections

    def _prune_trajectories(self, live_ids):
        """Forget the trajectory of every track ID that is not in *live_ids*."""
        stale_ids = [tid for tid in self.trajectories if tid not in live_ids]
        for tid in stale_ids:
            del self.trajectories[tid]

    @staticmethod
    def _trajectory_speed(traj, tick_freq):
        """Return ``(mean_speed, last_step)`` for a trajectory.

        *traj* is a sequence of ``(cx, cy, ticks)`` samples, oldest first.
        Tick differences are converted to seconds by dividing by *tick_freq*.
        ``mean_speed`` is total path length / total elapsed time (px/s) and
        ``last_step`` is the distance (px) between the last two samples.
        With fewer than two samples both are 0; a non-positive elapsed time
        gives a speed of 0.
        """
        pts = list(traj)
        if len(pts) < 2:
            return 0, 0
        total_dist = 0.0
        total_time = 0.0
        step_dist = 0
        for (x0, y0, t0), (x1, y1, t1) in zip(pts, pts[1:]):
            step_dist = float(np.hypot(x1 - x0, y1 - y0))
            total_dist += step_dist
            total_time += (t1 - t0) / tick_freq
        speed = total_dist / total_time if total_time > 0 else 0
        return speed, step_dist

    def process_frame(self, frame, timestamp=None):
        """Detect and track fish in one frame and evaluate the alert rules.

        Args:
            frame: Image array accepted by YOLO/OpenCV; ``frame.shape[0]`` is
                taken as the frame height.
            timestamp: Optional time of this frame in seconds (for a local
                video, e.g. ``cap.get(cv2.CAP_PROP_POS_MSEC) / 1000``). When
                omitted, the wall-clock ``cv2.getTickCount()`` is used, so
                speeds then reflect processing time, not video time. Use one
                convention consistently for a given tracker instance.

        Returns:
            ``(output_frame, alerts, track_stats)``:

            - ``output_frame``: a copy of *frame* with boxes and IDs drawn.
            - ``alerts``: the alert strings for this frame. A fresh list is
              returned on every call.
            - ``track_stats``: one dict per confirmed track, with keys ``id``,
              ``speed`` (px/s), ``is_bottom``, ``is_slow`` and
              ``distance_delta`` (px moved since the previous sample).
        """
        # New list each frame so alerts returned earlier are never mutated later.
        self.alerts = []
        frame_h = frame.shape[0]
        tick_freq = cv2.getTickFrequency()
        if timestamp is None:
            now_ticks = cv2.getTickCount()
        else:
            # Stored in tick units so every trajectory uses a single time unit.
            now_ticks = timestamp * tick_freq

        # 1. Detection + 2. DeepSORT tracking
        detections = self._detect(frame)
        tracks = self.tracker.update_tracks(detections, frame=frame)
        # Drop histories of tracks DeepSORT no longer reports; otherwise the
        # dict grows without bound over a long monitoring session.
        self._prune_trajectories({t.track_id for t in tracks})

        # 3. Update trajectories, evaluate rules and draw
        output_frame = frame.copy()
        track_stats = []
        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            x1, y1, x2, y2 = map(int, track.to_ltrb())
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            traj = self.trajectories[track_id]
            traj.append((cx, cy, now_ticks))
            speed, step_dist = self._trajectory_speed(traj, tick_freq)

            is_bottom = cy > frame_h * self.bottom_ratio
            is_slow = speed < self.slow_speed_thresh
            if is_slow and is_bottom:
                self.alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")
            track_stats.append({
                "id": track_id,
                "speed": speed,
                "is_bottom": is_bottom,
                "is_slow": is_slow,
                "distance_delta": step_dist,
            })

            cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label inside the image when the box touches the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(output_frame, f"Fish-{track_id}", (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        return output_frame, self.alerts, track_stats