观赏鱼健康系统v1.0

This commit is contained in:
2025-12-23 09:30:56 +08:00
parent ff6f40c972
commit ae9d252255
+73
View File
@@ -0,0 +1,73 @@
# fish_tracker.py
import cv2
import numpy as np
from collections import defaultdict, deque
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
class FishTracker:
def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5):
self.model = YOLO(model_path)
self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
self.conf_thresh = conf_thresh
self.trajectories = defaultdict(lambda: deque(maxlen=30)) # 存最近30帧轨迹
self.alerts = []
def process_frame(self, frame):
self.alerts.clear()
h, w = frame.shape[:2]
# 1. YOLO 检测(只检测 "fish" 类别,COCO 中类别 ID=17
results = self.model(frame, verbose=False)[0]
detections = []
for box in results.boxes:
if int(box.cls.item()) == 17 and box.conf.item() > self.conf_thresh: # COCO: fish=17
x1, y1, x2, y2 = map(int, box.xyxy[0])
conf = float(box.conf.item())
detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
# 2. DeepSORT 追踪
tracks = self.tracker.update_tracks(detections, frame=frame)
# 3. 更新轨迹 & 绘制
output_frame = frame.copy()
for track in tracks:
if not track.is_confirmed():
continue
track_id = track.track_id
ltrb = track.to_ltrb()
x1, y1, x2, y2 = map(int, ltrb)
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
# 记录轨迹
self.trajectories[track_id].append((cx, cy, cv2.getTickCount()))
# 计算速度(像素/秒)
traj = self.trajectories[track_id]
if len(traj) >= 2:
total_dist = 0
total_time = 0
for i in range(1, len(traj)):
dx = traj[i][0] - traj[i-1][0]
dy = traj[i][1] - traj[i-1][1]
dist = np.sqrt(dx*dx + dy*dy)
dt = (traj[i][2] - traj[i-1][2]) / cv2.getTickFrequency()
total_dist += dist
total_time += dt
speed = total_dist / total_time if total_time > 0 else 0
else:
speed = 0
# 判断是否沉底(底部 20% 区域)
is_bottom = cy > h * 0.8
# 异常规则
if speed < 15 and is_bottom:
self.alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")
# 绘制框和ID
cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(output_frame, f"Fish-{track_id}", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
return output_frame, self.alerts