# 692 lines · 32 KiB · Python
import logging
|
|
import os
|
|
import random
|
|
import shutil
|
|
import time
|
|
import traceback
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import torch
|
|
import yaml
|
|
|
|
import ai_models
|
|
import config
|
|
import database
|
|
import file_storage
|
|
from bbox_writer import extract_labels
|
|
from multiprocessing import Pool, cpu_count
|
|
|
|
try:
|
|
import ultralytics_sam_tasks
|
|
except ImportError:
|
|
ultralytics_sam_tasks = None
|
|
|
|
# albumentations is optional: when it cannot be imported, ``A`` is set to None
# and ``build_augmentation_pipeline`` returns None instead of a pipeline.
try:
    import albumentations as A

    class BboxSafeCoarseDropout(A.CoarseDropout):
        """CoarseDropout whose bounding-box hook returns the box unchanged.

        NOTE(review): whether ``apply_to_bbox`` is the hook actually invoked
        depends on the installed albumentations version -- confirm.
        """

        def apply_to_bbox(self, bbox, **params):
            # Identity mapping: the incoming box is handed back as-is.
            return bbox

except ImportError:
    A = None
    logging.warning(
        "albumentations library not found. Data augmentation will be disabled. Run 'pip install albumentations opencv-python-headless'")
|
|
|
|
|
|
# video_uuid -> marker of the task currently holding that video. The tasks in
# this module store either a task-name string (e.g. 'EXTRACTING') or the
# tracker_uuid of a running tracking session.
active_tasks: dict = {}

# tracker_uuid -> mutable session dict ('status', 'message', ...) written by
# the tracking tasks in this module.
tracking_sessions: dict = {}
|
|
|
|
|
|
def apply_prototypes_to_video_task(video_uuid, class_name, negative_samples, confidence_threshold, app_context):
    """Generate box suggestions for every unlabeled frame of a video.

    Builds positive prototypes for ``class_name`` (and, if ``negative_samples``
    is truthy, negative prototypes from them), runs
    ``ai_models.predict_with_prototypes`` on each frame whose ``bboxes_text``
    is empty, and stores the predictions via ``database.save_frame_suggestions``.

    Args:
        video_uuid: Video to process; also the key used in ``active_tasks``.
        class_name: Class whose prototypes are built and whose name is written
            into each suggestion line.
        negative_samples: Passed to ``ai_models.get_prototypes_from_drawn_boxes``
            when truthy.
        confidence_threshold: Forwarded to ``predict_with_prototypes``.
        app_context: Context manager entered for the duration of the work.

    Returns:
        None. Progress and outcome are reported through
        ``database.update_video_status`` ('APPLYING_PROTOTYPES', then 'READY'
        or 'FAILED'). The task stops early when the video status is found to be
        'CANCELLING'.
    """
    if active_tasks.get(video_uuid):
        logging.warning(f"Cannot start applying prototypes for {video_uuid}, another task is active.")
        return

    active_tasks[video_uuid] = 'APPLYING_PROTOTYPES'
    logging.info(
        f"Starting to apply suggestions for class '{class_name}' to video {video_uuid} with threshold {confidence_threshold}")

    try:
        # NOTE(review): indentation was lost in the source listing; the whole
        # body is assumed to run inside the application context -- confirm.
        with app_context:
            database.update_video_status(video_uuid, 'APPLYING_PROTOTYPES', f"Initializing for '{class_name}'...")

            database.update_video_status(video_uuid, 'APPLYING_PROTOTYPES',
                                         f"Building positive prototypes for '{class_name}'...")
            positive_prototypes = ai_models.build_prototypes_for_class(class_name)
            if positive_prototypes is None or len(positive_prototypes) == 0:
                raise ValueError(f"Could not build positive prototypes for class '{class_name}'.")
            logging.info(f"Successfully built {len(positive_prototypes)} positive prototypes for '{class_name}'.")

            negative_prototypes = None
            if negative_samples:
                database.update_video_status(video_uuid, 'APPLYING_PROTOTYPES', "Building negative prototypes...")
                negative_prototypes = ai_models.get_prototypes_from_drawn_boxes(negative_samples)
                if negative_prototypes is not None and len(negative_prototypes) > 0:
                    logging.info(
                        f"Successfully built {len(negative_prototypes)} negative prototypes from user samples.")
                else:
                    logging.warning("User provided negative samples, but failed to build prototypes from them.")

            all_frames = database.get_video_frames(video_uuid)
            # A missing or None 'bboxes_text' counts as unlabeled instead of
            # raising, matching how the other tasks in this module read it.
            unlabeled_frames = [f for f in all_frames if not (f.get('bboxes_text') or '').strip()]
            total_frames = len(unlabeled_frames)
            logging.info(f"Found {total_frames} unlabeled frames to process in video {video_uuid}.")

            for i, frame_info in enumerate(unlabeled_frames):
                frame_number = frame_info['frame_number']

                # Cancellation is polled once per frame.
                current_status = database.get_video_entity(video_uuid)['status']
                if current_status == 'CANCELLING':
                    logging.info(f"Task for {video_uuid} cancelled by user.")
                    database.update_video_status(video_uuid, 'READY', 'Task was cancelled.')
                    return

                database.update_video_status(video_uuid, 'APPLYING_PROTOTYPES',
                                             f"Processing frame {i + 1}/{total_frames}")

                try:
                    predictions = ai_models.predict_with_prototypes(
                        video_uuid, frame_number, positive_prototypes,
                        negative_prototypes=negative_prototypes,
                        confidence_threshold=confidence_threshold
                    )

                    if predictions:
                        suggested_text = "\n".join(
                            f"{int(p['box'][0])},{int(p['box'][1])},{int(p['box'][2])},{int(p['box'][3])},{class_name},{p['score']:.4f}"
                            for p in predictions)
                        database.save_frame_suggestions(video_uuid, frame_number, suggested_text)

                except Exception as frame_e:
                    # Best effort: one bad frame must not abort the whole video.
                    logging.error(f"Failed to process frame {frame_number} for {video_uuid}: {frame_e}")

                # Drop this frame's cached preprocessing data whether or not
                # the prediction succeeded.
                ai_models.PREPROCESSED_DATA_CACHE.pop(f"{video_uuid}_{frame_number}", None)

            database.update_video_status(video_uuid, 'READY',
                                         f"Finished applying '{class_name}' suggestions. Review suggestions.")
            logging.info(f"Task for {video_uuid} completed successfully.")

    except Exception as e:
        error_message = f"Failed to apply prototypes to video {video_uuid}"
        logging.error(f"{error_message}: {e}")
        logging.error(traceback.format_exc())
        database.update_video_status(video_uuid, status="FAILED", message=str(e))
    finally:
        # Release the per-video lock only if this task still owns it.
        if active_tasks.get(video_uuid) == 'APPLYING_PROTOTYPES':
            del active_tasks[video_uuid]
|
|
|
|
|
|
def start_sam2_tracking_task(video_uuid, tracker_uuid, start_frame, end_frame, init_bboxes_text):
    """Run interactive SAM tracking over ``start_frame``..``end_frame``.

    Registers a session dict under ``tracking_sessions[tracker_uuid]`` and
    hands it to ``ultralytics_sam_tasks.track_video_ultralytics``. If another
    task holds the video, or the SAM module is unavailable, a FAILED session is
    recorded and nothing else happens.
    """
    refusal = None
    if active_tasks.get(video_uuid):
        logging.warning(f"A task is already running for video {video_uuid}.")
        refusal = 'Another task is active.'
    elif ultralytics_sam_tasks is None:
        logging.error("Ultralytics SAM Tasks module not available.")
        refusal = 'Ultralytics library not installed or configured on server.'

    if refusal is not None:
        tracking_sessions[tracker_uuid] = {'status': 'FAILED', 'message': refusal}
        return

    active_tasks[video_uuid] = tracker_uuid
    frame_span = (end_frame - start_frame) + 1
    state = dict(
        status='STARTING',
        progress=0,
        total=frame_span,
        results={},
        stop_requested=False,
        message='',
    )
    tracking_sessions[tracker_uuid] = state

    try:
        logging.info(
            f"Starting INTERACTIVE SAM tracking for video {video_uuid} from frame {start_frame} to {end_frame}")
        state['status'] = 'PROCESSING'

        # The tracker receives the session dict itself.
        # NOTE(review): presumably it updates 'status'/'progress' in place -- confirm.
        ultralytics_sam_tasks.track_video_ultralytics(
            video_uuid, start_frame, end_frame, init_bboxes_text, state
        )

        outcome = state.get('status', 'COMPLETED')
        logging.info(f"Interactive SAM tracking for {tracker_uuid} finished with status: {outcome}.")

    except Exception as err:
        logging.error(f"Error during Interactive SAM tracking for {video_uuid}: {err}\n{traceback.format_exc()}")
        state['status'] = 'FAILED'
        state['message'] = str(err)
    finally:
        logging.info(f"Cleaning up resources for Interactive SAM tracking task {tracker_uuid}...")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info("Emptied PyTorch CUDA cache.")

        # Only the owner of the lock removes it.
        if active_tasks.get(video_uuid) == tracker_uuid:
            del active_tasks[video_uuid]

        logging.info(f"Resource cleanup for task {tracker_uuid} complete.")
|
|
|
|
|
|
def start_sam2_batch_tracking_task(video_uuid, tracker_uuid, start_frame, end_frame, init_bboxes_text):
    """Run batch SAM tracking over ``start_frame``..``end_frame``.

    Registers a session dict under ``tracking_sessions[tracker_uuid]``, calls
    ``ultralytics_sam_tasks.run_batch_tracking_with_predictor`` and stores its
    return value in ``session['results']``. If another task holds the video, or
    the SAM module is unavailable, a FAILED session is recorded and the
    function returns immediately.

    Args:
        video_uuid: Video to track; key used in ``active_tasks``.
        tracker_uuid: Key of the session in ``tracking_sessions``.
        start_frame: First frame of the range (inclusive).
        end_frame: Last frame of the range (inclusive).
        init_bboxes_text: Initial boxes, forwarded unchanged to the predictor.
    """
    if active_tasks.get(video_uuid):
        logging.warning(f"A task is already running for video {video_uuid}.")
        tracking_sessions[tracker_uuid] = {'status': 'FAILED', 'message': 'Another task is active.'}
        return

    if ultralytics_sam_tasks is None:
        logging.error("Ultralytics SAM Tasks module not available for batch tracking.")
        tracking_sessions[tracker_uuid] = {'status': 'FAILED',
                                           'message': 'Ultralytics library not installed or configured on server.'}
        return

    active_tasks[video_uuid] = tracker_uuid
    session = {
        'status': 'BATCH_PROCESSING',
        'progress': 0,
        'total': (end_frame - start_frame) + 1,
        'results': {},
        'stop_requested': False,
        'message': 'Preparing temporary video clip...',
    }
    tracking_sessions[tracker_uuid] = session

    try:
        logging.info(
            f"Starting BATCH SAM tracking for video {video_uuid} from frame {start_frame} to {end_frame}")

        all_results = ultralytics_sam_tasks.run_batch_tracking_with_predictor(
            video_uuid,
            start_frame,
            end_frame,
            init_bboxes_text,
            session
        )

        session['results'] = all_results
        session['progress'] = session['total']
        session['status'] = 'COMPLETED'
        session['message'] = 'Batch processing complete. Ready for review.'
        logging.info(f"Batch SAM tracking for {tracker_uuid} finished successfully.")

    except Exception as e:
        logging.error(f"Error during Batch SAM tracking for {video_uuid}: {e}\n{traceback.format_exc()}")
        session['status'] = 'FAILED'
        session['message'] = str(e)
    finally:
        # Same GPU cleanup as the interactive SAM task, so cached blocks are
        # handed back after batch runs as well.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info("Emptied PyTorch CUDA cache.")

        # Only the owner of the lock removes it.
        if active_tasks.get(video_uuid) == tracker_uuid:
            del active_tasks[video_uuid]
        logging.info(f"Batch tracking task for {tracker_uuid} cleaned up.")
|
|
|
|
|
|
def extract_frames_task(video_uuid):
    """Decode a video and store every frame as a quality-75 JPEG.

    Records width, height, fps and frame count through
    ``database.update_video_after_extraction_start``, then saves frames through
    ``file_storage.save_frame_image`` while advancing
    ``database.update_extracted_frame_count``. The video status ends as 'READY'
    on success or 'FAILED' (with the error text) on any exception.

    If the container-reported frame count is non-positive or exceeds
    ``config.MAX_FRAMES_PER_VIDEO``, the frames are counted by grabbing through
    the stream; a counted total above the limit raises ``ValueError``, which is
    handled like any other failure.

    Args:
        video_uuid: Video to extract; key used in ``active_tasks``.
    """
    if active_tasks.get(video_uuid) == 'EXTRACTING':
        logging.warning(f"Extraction for {video_uuid} is already running.")
        return

    active_tasks[video_uuid] = 'EXTRACTING'
    logging.info(f"Starting frame extraction for {video_uuid}")

    vid = None
    try:
        # Resolved inside the try so that a storage error is reported as
        # FAILED and the active_tasks entry is still cleared.
        video_path = file_storage.get_video_path(video_uuid)

        vid = cv2.VideoCapture(video_path)
        if not vid.isOpened():
            raise IOError("Cannot open video file")

        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = vid.get(cv2.CAP_PROP_FPS)

        frame_count = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0 or frame_count > config.MAX_FRAMES_PER_VIDEO:
            # Do not trust the reported value: count frames by grabbing them,
            # then rewind for the real extraction pass.
            vid.set(cv2.CAP_PROP_POS_FRAMES, 0)
            frame_count = 0
            while vid.grab():
                frame_count += 1
            vid.set(cv2.CAP_PROP_POS_FRAMES, 0)

        if frame_count > config.MAX_FRAMES_PER_VIDEO:
            raise ValueError(f"Video has more than {config.MAX_FRAMES_PER_VIDEO} frames.")

        database.update_video_after_extraction_start(video_uuid, width, height, fps, frame_count)

        count = 0
        while True:
            success, frame = vid.read()
            if not success:
                break

            encoded_ok, buffer = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 75])
            if encoded_ok:
                file_storage.save_frame_image(video_uuid, count, buffer.tobytes())
                database.update_extracted_frame_count(video_uuid, count + 1)
            else:
                logging.warning(f"Could not JPEG-encode frame {count} of {video_uuid}; frame skipped.")

            # The index advances even when encoding fails so that stored frame
            # numbers stay aligned with positions in the video.
            count += 1

        database.update_video_status(video_uuid, 'READY')
        logging.info(f"Frame extraction for {video_uuid} completed successfully.")

    except Exception as e:
        logging.error(f"Error extracting frames for {video_uuid}: {e}")
        database.update_video_status(video_uuid, 'FAILED', str(e))
    finally:
        # Release the capture on every path, including failures.
        if vid is not None:
            vid.release()
        if active_tasks.get(video_uuid) == 'EXTRACTING':
            del active_tasks[video_uuid]
|
|
|
|
|
|
def _dequantize_tflite_output(raw, details):
    """Return ``raw`` as real values when the tensor is uint8 with quantization parameters."""
    if details['dtype'] == np.uint8 and details.get('quantization'):
        scale, zero_point = details['quantization']
        return (np.float32(raw) - zero_point) * scale
    return raw


def pre_annotate_video_task(video_uuid, model_uuid, options):
    """Label frames of a video with a TFLite detection model.

    Loads the model and its label file, runs inference on every frame whose
    number lies in ``options['start_frame']``..``options['end_frame']`` and
    saves detections scoring above ``options['confidence']`` as
    ``xmin,ymin,xmax,ymax,label`` lines via ``database.save_frame_bboxes``.

    Args:
        video_uuid: Video to annotate; key used in ``active_tasks``.
        model_uuid: Model whose file, label file and metadata are loaded.
        options: Dict with keys 'confidence', 'start_frame', 'end_frame' and
            'merge_strategy'. With 'skip_labeled', frames that already have
            non-empty ``bboxes_text`` are left untouched.

    Returns:
        None. The video status ends as 'READY' in every case; on failure the
        status message carries the error. The task stops early when the video
        status is found to be 'CANCELLING' (polled every 10 frames).
    """
    if active_tasks.get(video_uuid):
        logging.warning(f"Cannot start pre-annotation for {video_uuid}, another task is active.")
        return

    active_tasks[video_uuid] = 'PRE_ANNOTATING'
    logging.info(f"Starting pre-annotation for video {video_uuid} with options: {options}")

    try:
        confidence_threshold = options['confidence']
        start_frame = options['start_frame']
        end_frame = options['end_frame']
        merge_strategy = options['merge_strategy']

        # NOTE(review): the result is not used below; the lookup is kept as in
        # the original in case it is relied on to fail for unknown videos.
        video = database.get_video_entity(video_uuid)
        model_info = database.get_model_entity(model_uuid)
        model_type = model_info['model_type']

        database.update_video_status(video_uuid, 'PRE_ANNOTATING', f"Using model: {model_info['description']}")
        database.update_pre_annotation_info(video_uuid, model_uuid, model_info['description'])

        model_path = file_storage.get_model_path(model_uuid)
        label_path = file_storage.get_label_file_path(model_uuid)

        interpreter = tf.lite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()

        with open(label_path, 'r') as f:
            labels = [line.strip() for line in f]

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        # Input shape is indexed as [_, height, width, ...].
        height = input_details[0]['shape'][1]
        width = input_details[0]['shape'][2]

        all_frames = database.get_video_frames(video_uuid)
        frames_to_process = []
        for frame_info in all_frames:
            if start_frame <= frame_info['frame_number'] <= end_frame:
                if merge_strategy == 'skip_labeled' and frame_info.get('bboxes_text', '').strip():
                    continue
                frames_to_process.append(frame_info)

        total_frames_to_process = len(frames_to_process)
        logging.info(f"Total frames to process after filtering: {total_frames_to_process}")

        for i, frame_info in enumerate(frames_to_process):
            if i % 10 == 0:
                current_status = database.get_video_entity(video_uuid)['status']
                if current_status == 'CANCELLING':
                    logging.info(f"Pre-annotation for {video_uuid} cancelled by user.")
                    database.update_video_status(video_uuid, 'READY', 'Task was cancelled.')
                    return

            if (i + 1) % 20 == 0:
                progress_msg = f"Processed {i + 1}/{total_frames_to_process} frames"
                database.update_video_status(video_uuid, 'PRE_ANNOTATING', progress_msg)

            frame_path = file_storage.get_frame_path(video_uuid, frame_info['frame_number'])
            if not os.path.exists(frame_path):
                continue

            frame_img = cv2.imread(frame_path)
            if frame_img is None:
                # imread signals an unreadable file by returning None; skip the
                # frame rather than aborting the run on `.shape`.
                logging.warning(f"Could not read frame image, skipping: {frame_path}")
                continue

            imH, imW, _ = frame_img.shape
            frame_rgb = cv2.cvtColor(frame_img, cv2.COLOR_BGR2RGB)
            image_resized = cv2.resize(frame_rgb, (width, height))
            input_data = np.expand_dims(image_resized, axis=0)

            if model_type == 'float32':
                input_data = np.float32(input_data) / 255.0

            interpreter.set_tensor(input_details[0]['index'], input_data)
            interpreter.invoke()

            # Output tensors are read by position: 0 = scores, 1 = boxes,
            # 3 = classes. NOTE(review): this ordering is model-specific -- confirm.
            scores_raw = interpreter.get_tensor(output_details[0]['index'])[0]
            boxes_raw = interpreter.get_tensor(output_details[1]['index'])[0]
            classes = interpreter.get_tensor(output_details[3]['index'])[0]

            scores = _dequantize_tflite_output(scores_raw, output_details[0])
            boxes = _dequantize_tflite_output(boxes_raw, output_details[1])

            bboxes_text_lines = []
            for j, score in enumerate(scores):
                if score > confidence_threshold:
                    # Box components are scaled as (ymin, xmin, ymax, xmax)
                    # fractions of the image and clamped to its bounds.
                    ymin = int(max(0, boxes[j][0] * imH))
                    xmin = int(max(0, boxes[j][1] * imW))
                    ymax = int(min(imH, boxes[j][2] * imH))
                    xmax = int(min(imW, boxes[j][3] * imW))

                    object_id = int(classes[j])
                    if object_id < len(labels):
                        object_name = labels[object_id]
                        bboxes_text_lines.append(f"{xmin},{ymin},{xmax},{ymax},{object_name}")

            final_bboxes_text = "\n".join(bboxes_text_lines)
            database.save_frame_bboxes(video_uuid, frame_info['frame_number'], final_bboxes_text)

        database.update_video_status(video_uuid, 'READY', "Pre-annotation complete")
        logging.info(f"Pre-annotation for {video_uuid} completed successfully.")

    except Exception as e:
        logging.error(f"Error during pre-annotation for {video_uuid}: {e}", exc_info=True)
        # Status returns to READY (not FAILED); the message carries the failure.
        database.update_video_status(video_uuid, 'READY', f"Pre-annotation failed: {e}")
    finally:
        if active_tasks.get(video_uuid) == 'PRE_ANNOTATING':
            del active_tasks[video_uuid]
|
|
|
|
|
|
def start_tracking_task(video_uuid, tracker_uuid, tracker_name, scale, init_frame_number, init_bboxes_text):
    """Track boxes through a video with an OpenCV legacy tracker, one frame per client step.

    Starting at ``init_frame_number``, each iteration reads a frame, updates
    the trackers, publishes the resulting boxes in
    ``tracking_sessions[tracker_uuid]`` and then waits until
    ``session['current_frame']`` is changed from outside or a stop is
    requested. A session whose ``last_client_update`` is older than 60 seconds
    is marked 'TIMED OUT'.

    Args:
        video_uuid: Video to track; key used in ``active_tasks``.
        tracker_uuid: Key of the session in ``tracking_sessions``.
        tracker_name: One of 'CSRT', 'MedianFlow', 'MIL', 'MOSSE', 'TLD',
            'KCF', 'Boosting'.
        scale: Forwarded to ``parse_bboxes_text`` / ``format_bboxes_text``.
        init_frame_number: Frame at which tracking starts.
        init_bboxes_text: Initial boxes in the module's text format.

    Returns:
        None. The outcome is reported through the session's 'status'
        ('COMPLETED', 'STOPPED', 'TIMED OUT' or 'FAILED').
    """
    if active_tasks.get(video_uuid):
        logging.warning(f"A task (tracking/extraction) is already running for video {video_uuid}.")
        tracking_sessions[tracker_uuid] = {'status': 'FAILED', 'message': 'Another task is active.'}
        return

    active_tasks[video_uuid] = tracker_uuid
    vid = None

    try:
        # Imported once here instead of on every loop iteration.
        from bbox_writer import format_bboxes_text, parse_bboxes_text

        # All setup lives inside the try so that any failure releases the
        # active_tasks entry and is recorded on the session.
        video_path = file_storage.get_video_path(video_uuid)
        video_info = database.get_video_entity(video_uuid)

        tracker_fns = {
            'CSRT': cv2.legacy.TrackerCSRT_create, 'MedianFlow': cv2.legacy.TrackerMedianFlow_create,
            'MIL': cv2.legacy.TrackerMIL_create, 'MOSSE': cv2.legacy.TrackerMOSSE_create,
            'TLD': cv2.legacy.TrackerTLD_create, 'KCF': cv2.legacy.TrackerKCF_create,
            'Boosting': cv2.legacy.TrackerBoosting_create,
        }

        logging.info(f"Starting tracking for video {video_uuid} with tracker {tracker_name}")
        vid = cv2.VideoCapture(video_path)
        if not vid.isOpened():
            raise IOError("Cannot open video file")
        vid.set(cv2.CAP_PROP_POS_FRAMES, init_frame_number)

        session = {'status': 'RUNNING', 'current_frame': init_frame_number, 'bboxes_text': init_bboxes_text,
                   'last_client_update': time.time(), 'stop_requested': False}
        tracking_sessions[tracker_uuid] = session

        frame_number = init_frame_number
        trackers = None
        # NOTE(review): indentation was lost in the source listing; the nesting
        # below is the most plausible reconstruction -- confirm against the
        # original file.
        while not session['stop_requested']:
            success, frame = vid.read()
            if not success:
                session['status'] = 'COMPLETED'
                break

            # (Re)initialise the trackers on the first frame and whenever the
            # session's current frame equals the frame just read.
            if trackers is None or session['current_frame'] == frame_number:
                bboxes, classes = parse_bboxes_text(session['bboxes_text'], scale)
                tracker_fn = tracker_fns[tracker_name]
                trackers = []
                for bbox in bboxes:
                    tracker = tracker_fn()
                    tracker.init(frame, tuple(bbox))
                    trackers.append(tracker)

            new_bboxes = []
            for tracker in trackers:
                ok, bbox = tracker.update(frame)
                # A lost target is published as None.
                new_bboxes.append(np.array(bbox) if ok else None)

            session['bboxes_text'] = format_bboxes_text(new_bboxes, classes, scale, video_info['width'],
                                                        video_info['height'])
            session['current_frame'] = frame_number

            # Wait until 'current_frame' is changed from outside this loop.
            while session['current_frame'] == frame_number and not session['stop_requested']:
                time.sleep(0.1)
                if time.time() - session['last_client_update'] > 60:
                    logging.warning(f"Tracking session {tracker_uuid} timed out.")
                    session['status'] = 'TIMED OUT'
                    session['stop_requested'] = True

            frame_number += 1

    except Exception as e:
        logging.error(f"Error during tracking for {video_uuid}: {e}\n{traceback.format_exc()}")
        # Record the failure even when it happened before the session existed,
        # so whoever polls tracking_sessions can see it.
        failed_session = tracking_sessions.setdefault(tracker_uuid, {})
        failed_session['status'] = 'FAILED'
        failed_session['message'] = str(e)
    finally:
        # Release the capture on every path, including failures.
        if vid is not None:
            vid.release()
        if active_tasks.get(video_uuid) == tracker_uuid:
            del active_tasks[video_uuid]
        # Still RUNNING here means the loop ended through a stop request.
        if tracker_uuid in tracking_sessions and tracking_sessions[tracker_uuid]['status'] == 'RUNNING':
            tracking_sessions[tracker_uuid]['status'] = 'STOPPED'
        logging.info(
            f"Tracking task for {video_uuid} finished with status: {tracking_sessions.get(tracker_uuid, {}).get('status')}")
|
|
|
|
|
|
def build_augmentation_pipeline(options):
    """Assemble an albumentations ``Compose`` from the enabled entries of ``options``.

    Each supported key ('hflip', 'vflip', 'rotate90', 'rotate', 'ssr',
    'affine', 'crop', 'grayscale', 'hsv', 'bc', 'blur', 'noise', 'cutout')
    contributes one transform when its sub-dict has a truthy 'enabled' value.
    Transforms are added in that order. Returns None when albumentations is
    unavailable or nothing is enabled.
    """
    if A is None:
        return None

    def enabled_section(key):
        # The sub-dict for `key` when switched on, otherwise None.
        section = options.get(key, {})
        return section if section.get('enabled') else None

    steps = []

    # --- geometric transforms ---
    cfg = enabled_section('hflip')
    if cfg is not None:
        steps.append(A.HorizontalFlip(p=cfg['p']))

    cfg = enabled_section('vflip')
    if cfg is not None:
        steps.append(A.VerticalFlip(p=cfg['p']))

    cfg = enabled_section('rotate90')
    if cfg is not None:
        steps.append(A.RandomRotate90(p=cfg['p']))

    cfg = enabled_section('rotate')
    if cfg is not None:
        steps.append(A.Rotate(limit=cfg['limit'], p=cfg['p'], border_mode=cv2.BORDER_CONSTANT, value=0))

    cfg = enabled_section('ssr')
    if cfg is not None:
        steps.append(A.ShiftScaleRotate(
            shift_limit=cfg['shift'],
            scale_limit=cfg['scale'],
            rotate_limit=cfg['rotate'],
            p=cfg['p'],
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        ))

    cfg = enabled_section('affine')
    if cfg is not None:
        shear = cfg['shear']
        steps.append(A.Affine(shear={'x': (-shear, shear), 'y': (-shear, shear)}, p=cfg['p'], cval=0))

    cfg = enabled_section('crop')
    if cfg is not None:
        steps.append(A.RandomSizedBBoxSafeCrop(height=1024, width=1024, erosion_rate=0.2, p=cfg['p']))

    # --- colour transforms ---
    cfg = enabled_section('grayscale')
    if cfg is not None:
        steps.append(A.ToGray(p=cfg['p']))

    cfg = enabled_section('hsv')
    if cfg is not None:
        steps.append(A.HueSaturationValue(
            hue_shift_limit=cfg['h'],
            sat_shift_limit=cfg['s'],
            val_shift_limit=cfg['v'],
            p=cfg['p'],
        ))

    cfg = enabled_section('bc')
    if cfg is not None:
        steps.append(A.RandomBrightnessContrast(brightness_limit=cfg['b'], contrast_limit=cfg['c'], p=cfg['p']))

    # --- blur / noise ---
    cfg = enabled_section('blur')
    if cfg is not None:
        steps.append(A.GaussianBlur(blur_limit=(3, cfg['limit']), p=cfg['p']))

    cfg = enabled_section('noise')
    if cfg is not None:
        steps.append(A.GaussNoise(var_limit=(10.0, cfg['limit']), p=cfg['p']))

    # --- occlusion ---
    cfg = enabled_section('cutout')
    if cfg is not None:
        steps.append(BboxSafeCoarseDropout(
            max_holes=cfg['holes'],
            max_height=cfg['size'],
            max_width=cfg['size'],
            fill_value=0,
            p=cfg['p'],
        ))

    if not steps:
        return None

    box_params = A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.1)
    return A.Compose(steps, bbox_params=box_params)
|
|
|
|
|
|
def process_frame_worker(args):
    """Write one dataset sample (JPEG plus YOLO label file) for a frame.

    Args:
        args: Tuple ``(frame_info, target_img_dir, target_lbl_dir, class_map,
            augmentation_options)``. ``frame_info`` carries 'video_uuid',
            'frame_number', 'bboxes_text', 'width', 'height', and optionally
            'type' / 'augmented_id'. When 'type' is "augmented" and
            augmentation is enabled, the image and boxes go through the
            pipeline from ``build_augmentation_pipeline`` before being written.

    Returns:
        The path of the written image, or None when the frame is skipped
        (missing or unreadable source, no boxes, failed write) or any error
        occurs; errors are logged and not raised.
    """
    frame_info, target_img_dir, target_lbl_dir, class_map, augmentation_options = args

    augment_pipeline = None
    is_augmented = frame_info.get("type") == "augmented"
    if is_augmented and augmentation_options and augmentation_options.get("enabled", False):
        augment_pipeline = build_augmentation_pipeline(augmentation_options)

    try:
        if is_augmented:
            base_filename = frame_info["augmented_id"]
        else:
            base_filename = f"{frame_info['video_uuid']}_{frame_info['frame_number']:05d}"

        src_img_path = file_storage.get_frame_path(frame_info['video_uuid'], frame_info['frame_number'])
        if not os.path.exists(src_img_path):
            logging.warning(f"源文件未找到,跳过: {src_img_path}")
            return None

        image = cv2.imread(src_img_path)
        if image is None:
            # imread returns None for unreadable or corrupt files; report that
            # clearly instead of failing later inside cvtColor.
            logging.warning(f"Could not decode source image, skipping: {src_img_path}")
            return None
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        yolo_bboxes, class_indices = file_storage.get_yolo_bboxes(
            frame_info['bboxes_text'], frame_info['width'], frame_info['height'], class_map
        )

        if not yolo_bboxes:
            return None

        if is_augmented and augment_pipeline:
            transformed = augment_pipeline(image=image, bboxes=yolo_bboxes, class_labels=class_indices)
            image_aug_rgb = transformed['image']
            out_boxes = transformed['bboxes']
            out_labels = transformed['class_labels']
        else:
            # No pipeline available: the sample is written unmodified, under
            # the augmented id if one was assigned.
            image_aug_rgb = image
            out_boxes = yolo_bboxes
            out_labels = class_indices

        bboxes_aug_yolo = [(out_labels[i], *box) for i, box in enumerate(out_boxes)]

        final_image_bgr = cv2.cvtColor(image_aug_rgb, cv2.COLOR_RGB2BGR)
        output_image_path = os.path.join(target_img_dir, base_filename + '.jpg')
        if not cv2.imwrite(output_image_path, final_image_bgr):
            # imwrite reports failure through its return value, not an exception.
            logging.error(f"Failed to write image: {output_image_path}")
            return None

        yolo_content_lines = [f"{class_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}" for class_id, x, y, w, h in bboxes_aug_yolo]
        output_label_path = os.path.join(target_lbl_dir, base_filename + '.txt')
        with open(output_label_path, 'w') as f:
            f.write("\n".join(yolo_content_lines))

        return output_image_path

    except Exception as e:
        # Best effort: a single bad frame is logged and skipped.
        logging.error(f"处理帧 {frame_info.get('augmented_id') or frame_info.get('frame_number')} 时发生错误: {e}")
        logging.error(traceback.format_exc())
        return None
|
|
|
|
|
|
def create_dataset_task(dataset_uuid, video_uuids, eval_percent, test_percent, augmentation_options=None):
    """Build a YOLO-format dataset ZIP from the labeled frames of the given videos.

    Steps: collect every frame with non-empty ``bboxes_text``, derive a sorted
    class list, optionally multiply frames with augmented variants, shuffle and
    split into train/val/test, write images and labels in a process pool, emit
    ``data.yaml``, and archive the directory.

    Args:
        dataset_uuid: Dataset record to update and name of the archive.
        video_uuids: Videos whose labeled frames are included.
        eval_percent: Validation share in percent; None means 20.0.
        test_percent: Test share in percent; None means 10.0.
        augmentation_options: Optional dict; 'enabled' and 'multiply_factor'
            are read here, the rest is passed to ``process_frame_worker``.

    Returns:
        None. The outcome is reported through ``database.update_dataset_status``
        ('PROCESSING', then 'READY' with the zip path and label list, or
        'FAILED' with the error text). Exceptions are logged, not raised.
    """
    if augmentation_options is None:
        augmentation_options = {}

    logging.info(f"Starting dataset creation task for UUID: {dataset_uuid} with augmentations: {augmentation_options}")
    dataset_dir = None
    try:
        if eval_percent is None:
            eval_percent = 20.0
        if test_percent is None:
            test_percent = 10.0
        if eval_percent + test_percent >= 100.0:
            raise ValueError(
                f"The sum of validation ({eval_percent}%) and test ({test_percent}%) percentages must be less than 100.")

        database.update_dataset_status(dataset_uuid, status="PROCESSING", message="Gathering labeled frames...")

        frames_to_include = []
        all_labels = set()
        logging.info(f"Gathering frames from {len(video_uuids)} selected video(s)...")
        for video_uuid in video_uuids:
            video = database.get_video_entity(video_uuid)
            for frame in database.get_video_frames(video_uuid):
                if frame.get('bboxes_text') and frame['bboxes_text'].strip():
                    frames_to_include.append({
                        "video_uuid": video_uuid, "frame_number": frame['frame_number'],
                        "bboxes_text": frame['bboxes_text'], "width": video['width'], "height": video['height']
                    })
                    all_labels.update(extract_labels(frame['bboxes_text']))

        if not frames_to_include:
            raise ValueError("No labeled frames with valid bounding boxes were found in the selected videos.")

        # Class ids are positions in the alphabetically sorted label list.
        sorted_labels = sorted(all_labels)
        class_map = {name: i for i, name in enumerate(sorted_labels)}
        logging.info(f"Dataset classes (sorted): {sorted_labels}")

        is_aug_enabled = A is not None and augmentation_options.get("enabled", False)
        multiplication_factor = int(augmentation_options.get("multiply_factor", 1)) if is_aug_enabled else 1
        final_frames_to_process = []
        if is_aug_enabled and multiplication_factor > 1:
            # Each source frame yields itself plus (factor - 1) augmented copies.
            for frame_info in frames_to_include:
                final_frames_to_process.append({"type": "original", **frame_info})
                for i in range(multiplication_factor - 1):
                    aug_id = f"aug_{i}_{frame_info['video_uuid']}_{frame_info['frame_number']:05d}"
                    final_frames_to_process.append({"type": "augmented", "augmented_id": aug_id, **frame_info})
        else:
            final_frames_to_process = [{"type": "original", **frame_info} for frame_info in frames_to_include]

        # NOTE(review): the shuffle happens after multiplication, so augmented
        # copies of one source frame can land in different splits -- confirm
        # this is intended.
        random.shuffle(final_frames_to_process)
        total_count = len(final_frames_to_process)
        val_count = int(total_count * eval_percent / 100.0)
        test_count = int(total_count * test_percent / 100.0)

        val_data = final_frames_to_process[:val_count]
        test_data = final_frames_to_process[val_count:val_count + test_count]
        train_data = final_frames_to_process[val_count + test_count:]

        dataset_dir = file_storage.get_dataset_dir(dataset_uuid)
        if os.path.exists(dataset_dir):
            shutil.rmtree(dataset_dir)
        dir_map = {
            split: (os.path.join(dataset_dir, 'images', split), os.path.join(dataset_dir, 'labels', split))
            for split in ('train', 'val', 'test')
        }
        for img_dir, lbl_dir in dir_map.values():
            os.makedirs(img_dir, exist_ok=True)
            os.makedirs(lbl_dir, exist_ok=True)

        all_tasks = []
        for split_name, split_data in [('train', train_data), ('val', val_data), ('test', test_data)]:
            img_dir, lbl_dir = dir_map[split_name]
            for frame_info in split_data:
                all_tasks.append((frame_info, img_dir, lbl_dir, class_map, augmentation_options))

        worker_count = cpu_count()
        database.update_dataset_status(dataset_uuid, status="PROCESSING",
                                       message=f"Processing {len(all_tasks)} images across {worker_count} CPU cores...")
        logging.info(f"Starting parallel processing of {len(all_tasks)} images using up to {worker_count} cores.")

        processed_count = 0
        with Pool(processes=worker_count) as pool:
            for result in pool.imap_unordered(process_frame_worker, all_tasks):
                if result:
                    processed_count += 1
                    if processed_count % 50 == 0:
                        progress_msg = f"Processed {processed_count}/{len(all_tasks)} images..."
                        database.update_dataset_status(dataset_uuid, status="PROCESSING", message=progress_msg)

        logging.info(f"Parallel processing finished. Processed {processed_count} images successfully.")

        # yaml is imported unconditionally at module level, so no availability
        # check is needed here.
        yaml_content = {'path': f"../datasets/{dataset_uuid}", 'train': 'images/train', 'val': 'images/val',
                        'test': 'images/test', 'nc': len(sorted_labels), 'names': sorted_labels}
        with open(os.path.join(dataset_dir, 'data.yaml'), 'w') as f:
            yaml.dump(yaml_content, f, sort_keys=False)

        database.update_dataset_status(dataset_uuid, status="PROCESSING", message="Creating ZIP archive...")
        zip_path_base = os.path.join(config.STORAGE_DIR, 'datasets', dataset_uuid)
        zip_path = shutil.make_archive(zip_path_base, 'zip', dataset_dir)
        # The working directory is removed once the archive exists.
        shutil.rmtree(dataset_dir)

        logging.info(f"ZIP archive created at: {zip_path}")
        database.update_dataset_status(dataset_uuid, status="READY", zip_path=zip_path, sorted_label_list=sorted_labels)
        logging.info(f"Dataset {dataset_uuid} task completed successfully.")

    except Exception as e:
        error_message = f"Failed to create dataset {dataset_uuid}"
        logging.error(f"{error_message}: {e}")
        logging.error(traceback.format_exc())
        # Do not leave a half-written dataset directory behind.
        if dataset_dir is not None:
            shutil.rmtree(dataset_dir, ignore_errors=True)
        database.update_dataset_status(dataset_uuid, status="FAILED", message=str(e))