Files
fish_monitor/Zero2YoloYard-main/ultralytics_sam_tasks.py
T
2025-12-25 17:41:18 +08:00

291 lines
12 KiB
Python

import logging
import os
import torch
import numpy as np
import cv2
import uuid
try:
from ultralytics import SAM
from ultralytics.models.sam import SAM2VideoPredictor
from ultralytics.engine.results import Results
except ImportError:
logging.critical("FATAL: ultralytics library is not installed. Please run 'pip install ultralytics'.")
SAM = None
SAM2VideoPredictor = None
Results = None
import config
import database
import file_storage
from bbox_writer import convert_text_to_rects_and_labels
import settings_manager
# Process-wide cache of the loaded SAM model and the checkpoint path it came from.
# Both entries are set together on a successful load and cleared together.
_sam_model_cache = {"model": None, "path": None}


def _devices_match(current, wanted):
    """Return True when two device specs refer to the same torch device.

    Plain string comparison treats "cuda" and "cuda:0" as different, which
    would invalidate the cache on every call. Both specs are parsed with
    ``torch.device`` and an absent index is treated as index 0.
    Specs that ``torch.device`` cannot parse fall back to string equality.
    """
    try:
        current_dev = torch.device(str(current))
        wanted_dev = torch.device(str(wanted))
    except (RuntimeError, TypeError, ValueError):
        return str(current) == str(wanted)
    return (current_dev.type == wanted_dev.type
            and (current_dev.index or 0) == (wanted_dev.index or 0))


def get_sam_model():
    """Return the cached Ultralytics SAM model, loading it when necessary.

    The checkpoint file name is read from the 'sam_model_checkpoint' setting
    (default 'sam2.1_t.pt') and resolved under ``config.BASE_DIR/checkpoints``.
    The cached model is reused as long as both the checkpoint path and the
    device returned by ``settings_manager.get_device()`` are unchanged.

    Returns:
        The loaded model, or None when the ultralytics import failed, the
        checkpoint file does not exist, or loading raised an exception
        (each case is logged).
    """
    device = settings_manager.get_device()
    settings = settings_manager.load_settings()
    checkpoint_filename = settings.get('sam_model_checkpoint', 'sam2.1_t.pt')
    sam_checkpoint_path = os.path.join(config.BASE_DIR, "checkpoints", checkpoint_filename)

    cached_model = _sam_model_cache["model"]
    if cached_model is not None:
        if (_sam_model_cache["path"] == sam_checkpoint_path
                and _devices_match(cached_model.device, device)):
            return cached_model
        # Only report a change when a previously loaded model is really discarded.
        logging.info(f"Model/device change detected. Reloading SAM model to {device}. New model: {checkpoint_filename}")
        _sam_model_cache["model"] = None
        _sam_model_cache["path"] = None

    if SAM is None:
        logging.error("Ultralytics SAM class is not available due to import error.")
        return None
    if not os.path.exists(sam_checkpoint_path):
        logging.error(f"Ultralytics SAM checkpoint not found at {sam_checkpoint_path}. All SAM features are disabled.")
        return None

    try:
        logging.info(f"Loading Ultralytics SAM model ('{checkpoint_filename}') to device '{device}'...")
        model = SAM(sam_checkpoint_path)
        model.to(device)
    except Exception as e:
        logging.error(f"Failed to load Ultralytics SAM model: {e}", exc_info=True)
        return None

    # Publish to the cache only after the model is fully loaded and moved.
    _sam_model_cache["model"] = model
    _sam_model_cache["path"] = sam_checkpoint_path
    logging.info("Ultralytics SAM model loaded successfully.")
    return model
def predict_box_from_point_ultralytics(image_path, point_coords):
    """Prompt SAM with one positive point and return the first predicted box.

    Args:
        image_path: Image passed straight to the SAM model.
        point_coords: The prompt point; it is wrapped in a list and sent with
            label 1.

    Returns:
        A dict with integer 'x1', 'y1', 'x2', 'y2' taken from the first box
        in the first result, or None when the model returned no boxes.

    Raises:
        RuntimeError: If the SAM model could not be obtained.
    """
    sam = get_sam_model()
    if sam is None:
        raise RuntimeError("Ultralytics SAM model is not available.")

    predictions = sam(image_path, points=[point_coords], labels=[1])
    if not predictions:
        return None

    boxes = predictions[0].boxes
    if not boxes or boxes.xyxy.numel() <= 0:
        return None

    first_box = boxes.xyxy[0].cpu().numpy()
    x1, y1, x2, y2 = (int(value) for value in first_box)
    return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2}
def _get_bbox_from_mask(mask_data, original_width, original_height):
if mask_data is None:
return None
if isinstance(mask_data, torch.Tensor):
mask_data = mask_data.cpu().numpy()
mask = (mask_data * 255).astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours:
largest_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(largest_contour)
x1 = max(0, x)
y1 = max(0, y)
x2 = min(original_width, x + w)
y2 = min(original_height, y + h)
if x2 > x1 and y2 > y1:
return [x1, y1, x2, y2]
return None
def track_video_ultralytics(video_uuid, start_frame, end_frame, init_bboxes_text, session):
    """Track boxes frame by frame, prompting SAM with the previous frame's boxes.

    For every frame after ``start_frame`` the boxes currently being tracked are
    passed to the SAM model as box prompts; each returned mask is reduced to a
    bounding box, which becomes the prompt for the following frame.

    Args:
        video_uuid: Identifier used to look up the video entity and frame files.
        start_frame: Frame the initial boxes belong to.
        end_frame: Last frame to process (inclusive).
        init_bboxes_text: Initial boxes, parsed by
            ``convert_text_to_rects_and_labels``.
        session: Mutable mapping updated in place. ``session['results']`` gets
            one text entry per processed frame
            (lines of ``x1,y1,x2,y2,label,id``), ``session['progress']`` is
            advanced, ``session['stop_requested']`` is polled, and
            ``session['status']`` is set to 'STOPPED' or 'COMPLETED'.

    Raises:
        RuntimeError: If the SAM model is not available.
        ValueError: If no initial boxes could be parsed.
    """
    sam = get_sam_model()
    if sam is None:
        raise RuntimeError("Ultralytics SAM model is not available for tracking.")

    video_info = database.get_video_entity(video_uuid)
    frame_width = video_info['width']
    frame_height = video_info['height']

    init_rects, init_labels, init_ids = convert_text_to_rects_and_labels(init_bboxes_text)
    if not init_rects:
        raise ValueError("No initial bounding boxes provided for tracking.")

    # Objects without an id get a positional fallback id ("obj_<index>").
    tracked = {}
    for idx, rect in enumerate(init_rects):
        tracked[init_ids[idx] or f"obj_{idx}"] = {"label": init_labels[idx], "bbox": rect}

    session['results'][start_frame] = init_bboxes_text
    session['progress'] = 1

    for frame_no in range(start_frame + 1, end_frame + 1):
        if session.get('stop_requested', False):
            logging.info(f"Tracking for {video_uuid} stopped by user request.")
            session['status'] = 'STOPPED'
            break

        frame_path = file_storage.get_frame_path(video_uuid, frame_no)
        if not os.path.exists(frame_path):
            logging.warning(f"Frame {frame_no} not found, skipping.")
            continue

        object_ids = list(tracked)
        prompt_boxes = [tracked[object_id]['bbox'] for object_id in object_ids]
        if not prompt_boxes:
            logging.warning(f"Lost all objects at frame {frame_no}. Stopping tracking.")
            break

        results = sam(frame_path, bboxes=np.array(prompt_boxes))
        if isinstance(results, Results):
            results = [results]

        next_tracked = {}
        frame_lines = []
        if results and results[0].masks:
            masks = results[0].masks.data
            if len(masks) == len(prompt_boxes):
                # Masks are matched to objects by position in the prompt list.
                for object_id, mask in zip(object_ids, masks):
                    bbox = _get_bbox_from_mask(mask, frame_width, frame_height)
                    if not bbox:
                        continue
                    label = tracked[object_id]['label']
                    x1, y1, x2, y2 = bbox
                    frame_lines.append(f"{x1},{y1},{x2},{y2},{label},{object_id}")
                    next_tracked[object_id] = {"label": label, "bbox": bbox}
            else:
                logging.warning(
                    f"Frame {frame_no}: Mismatch between prompted boxes ({len(prompt_boxes)}) and returned masks ({len(masks)}). Using previous frame's boxes.")
                # Masks cannot be matched to objects, so carry the old boxes forward.
                for object_id, entry in tracked.items():
                    x1, y1, x2, y2 = entry['bbox']
                    frame_lines.append(f"{x1},{y1},{x2},{y2},{entry['label']},{object_id}")
                next_tracked = tracked

        tracked = next_tracked
        session['results'][frame_no] = "\n".join(frame_lines)
        session['progress'] = (frame_no - start_frame) + 1

    # Leave an explicit terminal status (e.g. 'STOPPED') untouched.
    if session.get('status', 'PROCESSING') == 'PROCESSING':
        session['status'] = 'COMPLETED'
def _write_chunk_video(video_uuid, frame_numbers, video_path, fps, width, height):
    """Encode the readable frames of one chunk into a temporary MP4 file.

    Frames whose file is missing or cannot be decoded are skipped with a
    warning. The writer is always released, even when writing fails.

    Returns:
        The frame numbers actually written, in order, so that the i-th frame
        of the video can be mapped back to its real frame number.

    Raises:
        IOError: If the video writer could not be opened.
    """
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not writer.isOpened():
        raise IOError(
            f"Failed to create temporary video writer for chunk {frame_numbers[0]}-{frame_numbers[-1]}.")

    written = []
    try:
        for frame_num in frame_numbers:
            frame_path = file_storage.get_frame_path(video_uuid, frame_num)
            if not os.path.exists(frame_path):
                logging.warning(f"Frame {frame_num} not found, leaving it out of the chunk video.")
                continue
            img = cv2.imread(frame_path)
            if img is None:
                # cv2.imread returns None instead of raising on unreadable files.
                logging.warning(f"Frame {frame_num} could not be read, leaving it out of the chunk video.")
                continue
            if img.shape[1] != width or img.shape[0] != height:
                img = cv2.resize(img, (width, height))
            writer.write(img)
            written.append(frame_num)
    finally:
        writer.release()
    return written


def run_batch_tracking_with_predictor(video_uuid, start_frame, end_frame, init_bboxes_text, session):
    """Track the initial boxes through a frame range with SAM2VideoPredictor.

    The range ``start_frame..end_frame`` (inclusive) is processed in chunks.
    Each chunk's frames are encoded into a temporary MP4, the predictor is
    prompted with the centre points of the last known boxes, and every
    returned mask is reduced to a bounding box.

    Args:
        video_uuid: Identifier used to look up the video entity and frame files.
        start_frame: First frame of the range; the initial boxes belong to it.
        end_frame: Last frame of the range (inclusive).
        init_bboxes_text: Initial boxes, parsed by
            ``convert_text_to_rects_and_labels``.
        session: Mutable mapping; ``session['progress']`` and
            ``session['message']`` are updated while frames are processed.

    Returns:
        A dict mapping frame number to a text block with one
        ``x1,y1,x2,y2,label`` line per detected object (empty string when the
        frame produced no masks).

    Raises:
        ImportError: If SAM2VideoPredictor could not be imported.
        RuntimeError: If the SAM model is not available.
        FileNotFoundError: If the configured checkpoint file is missing.
        ValueError: If no initial boxes could be parsed.
        IOError: If a temporary video writer cannot be opened.
    """
    if SAM2VideoPredictor is None:
        raise ImportError("SAM2VideoPredictor could not be imported. Please check your ultralytics installation.")
    if get_sam_model() is None:
        raise RuntimeError("SAM model is not available for batch tracking.")

    settings = settings_manager.load_settings()
    model_checkpoint_filename = settings.get('sam_model_checkpoint', 'sam2.1_t.pt')
    model_absolute_path = os.path.join(config.BASE_DIR, "checkpoints", model_checkpoint_filename)
    if not os.path.exists(model_absolute_path):
        raise FileNotFoundError(f"Batch tracking model not found at path: {model_absolute_path}")

    video_info = database.get_video_entity(video_uuid)
    width, height, fps = video_info['width'], video_info['height'], video_info['fps']
    if not fps or fps <= 0:
        fps = 30
        logging.warning(f"Video {video_uuid} has invalid FPS, falling back to {fps}.")

    init_rects, init_labels, _ = convert_text_to_rects_and_labels(init_bboxes_text)
    if not init_rects:
        raise ValueError("No initial bounding boxes provided for tracking.")

    all_frame_results = {start_frame: init_bboxes_text}
    total_frames_to_process = end_frame - start_frame
    if total_frames_to_process <= 0:
        # Nothing beyond the annotated start frame to track.
        return all_frame_results

    imgsz = settings.get('batch_tracking_imgsz', 1024)
    conf = settings.get('batch_tracking_conf', 0.30)
    # A non-positive chunk size would make range() raise or never advance.
    chunk_size = max(1, int(settings.get('batch_tracking_chunk_size', 10)))
    device = str(settings_manager.get_device())

    # Rects and labels are kept index-aligned: when an object is lost, its
    # label is dropped too, so later chunks do not shift labels onto the
    # wrong objects.
    last_known_rects = list(init_rects)
    active_labels = list(init_labels)

    temp_video_dir = os.path.join(config.STORAGE_DIR, 'videos')
    os.makedirs(temp_video_dir, exist_ok=True)

    # The range holds total_frames_to_process + 1 frames; iterating up to
    # "+ 1" ensures end_frame is covered when the span is an exact multiple
    # of chunk_size.
    # NOTE(review): the first chunk starts at start_frame, so its prediction
    # replaces the init text stored above — confirm this is intended.
    for offset in range(0, total_frames_to_process + 1, chunk_size):
        chunk_start_frame = start_frame + offset
        chunk_end_frame = min(chunk_start_frame + chunk_size - 1, end_frame)
        logging.info(f"Processing chunk: frames {chunk_start_frame} to {chunk_end_frame}")

        temp_video_path = os.path.join(temp_video_dir, f"temp_chunk_{uuid.uuid4().hex}.mp4")
        predictor = None
        results_generator = None
        try:
            written_frames = _write_chunk_video(
                video_uuid, range(chunk_start_frame, chunk_end_frame + 1),
                temp_video_path, fps, width, height)
            if not written_frames:
                logging.warning(
                    f"No readable frames in chunk {chunk_start_frame}-{chunk_end_frame}, skipping it.")
                continue

            # Prompt each object with the centre point of its last known box.
            current_prompts = [[int((r[0] + r[2]) / 2), int((r[1] + r[3]) / 2)] for r in last_known_rects]
            labels_prompt = [1] * len(current_prompts)
            chunk_labels = active_labels

            overrides = dict(
                conf=conf,
                task="segment",
                mode="predict",
                imgsz=imgsz,
                model=model_absolute_path,
                device=device
            )
            predictor = SAM2VideoPredictor(overrides=overrides)
            results_generator = predictor(source=temp_video_path, points=current_prompts, labels=labels_prompt,
                                          stream=True)

            latest_rects_in_chunk = None
            latest_labels_in_chunk = None
            for frame_idx, results in enumerate(results_generator):
                if frame_idx >= len(written_frames):
                    logging.warning("Predictor returned more frames than were written; ignoring the extras.")
                    break
                # Map the video frame index back to the real frame number;
                # skipped frames would otherwise shift every later result.
                actual_frame_num = written_frames[frame_idx]
                session['progress'] = (actual_frame_num - start_frame)
                session['message'] = f'Processing frame {actual_frame_num}'

                if not results.masks:
                    all_frame_results[actual_frame_num] = ""
                    latest_rects_in_chunk = []
                    latest_labels_in_chunk = []
                    continue

                bboxes_text_lines = []
                current_frame_rects = []
                current_frame_labels = []
                # Mask i is taken to belong to prompt i of this chunk.
                for mask_data, label in zip(results.masks.data, chunk_labels):
                    bbox = _get_bbox_from_mask(mask_data, width, height)
                    if bbox:
                        x1, y1, x2, y2 = bbox
                        bboxes_text_lines.append(f"{x1},{y1},{x2},{y2},{label}")
                        current_frame_rects.append(bbox)
                        current_frame_labels.append(label)
                all_frame_results[actual_frame_num] = "\n".join(bboxes_text_lines)
                latest_rects_in_chunk = current_frame_rects
                latest_labels_in_chunk = current_frame_labels

            if latest_rects_in_chunk:
                last_known_rects = latest_rects_in_chunk
                active_labels = latest_labels_in_chunk
            else:
                logging.warning("Lost all objects in chunk. Stopping batch processing.")
                break
        finally:
            # Drop predictor references before asking CUDA to free cached memory.
            results_generator = None
            predictor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            try:
                if os.path.exists(temp_video_path):
                    os.remove(temp_video_path)
            except OSError as exc:
                # A cleanup failure must not mask an error raised in the try body.
                logging.warning(f"Could not remove temporary chunk video {temp_video_path}: {exc}")
            logging.info("Finished chunk, cleared predictor, cache, and temp file.")

    return all_frame_results