Files
fish_monitor/Zero2YoloYard-main/database.py
T
2025-12-25 17:41:18 +08:00

541 lines
18 KiB
Python

import json
import logging
import sqlite3
import time
import uuid
import config
import file_storage
from bbox_writer import extract_labels
def get_db_connection():
conn = sqlite3.connect(config.DATABASE_FILE)
conn.row_factory = sqlite3.Row
return conn
def _add_column_if_not_exists(cursor, table_name, column_name, column_type):
cursor.execute(f"PRAGMA table_info({table_name})")
columns = [row['name'] for row in cursor.fetchall()]
if column_name not in columns:
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}")
print(f"Added column '{column_name}' to table '{table_name}'.")
def migrate_db():
conn = get_db_connection()
cursor = conn.cursor()
_add_column_if_not_exists(cursor, 'datasets', 'eval_percent', 'REAL')
_add_column_if_not_exists(cursor, 'datasets', 'test_percent', 'REAL')
_add_column_if_not_exists(cursor, 'models', 'label_filename', 'TEXT')
_add_column_if_not_exists(cursor, 'models', 'model_type', 'TEXT')
_add_column_if_not_exists(cursor, 'videos', 'last_pre_annotation_info', 'TEXT')
_add_column_if_not_exists(cursor, 'video_frames', 'tags', 'TEXT')
_add_column_if_not_exists(cursor, 'video_frames', 'suggested_bboxes_text', 'TEXT')
conn.commit()
conn.close()
def init_db():
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS videos (
video_uuid TEXT PRIMARY KEY,
description TEXT NOT NULL UNIQUE,
video_filename TEXT,
file_size INTEGER,
create_time_ms INTEGER,
status TEXT,
status_message TEXT,
width INTEGER,
height INTEGER,
fps REAL,
frame_count INTEGER,
extracted_frame_count INTEGER DEFAULT 0,
included_frame_count INTEGER DEFAULT 0,
labeled_frame_count INTEGER DEFAULT 0,
last_pre_annotation_info TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS video_frames (
frame_id INTEGER PRIMARY KEY AUTOINCREMENT,
video_uuid TEXT,
frame_number INTEGER,
bboxes_text TEXT,
suggested_bboxes_text TEXT,
tags TEXT,
include_frame_in_dataset INTEGER,
FOREIGN KEY (video_uuid) REFERENCES videos (video_uuid) ON DELETE CASCADE
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS datasets (
dataset_uuid TEXT PRIMARY KEY,
description TEXT NOT NULL UNIQUE,
video_uuids TEXT,
create_time_ms INTEGER,
status TEXT,
status_message TEXT,
zip_path TEXT,
sorted_label_list TEXT,
eval_percent REAL,
test_percent REAL
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS models (
model_uuid TEXT PRIMARY KEY,
description TEXT NOT NULL UNIQUE,
create_time_ms INTEGER,
label_filename TEXT,
model_type TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS annotation_tasks (
task_uuid TEXT PRIMARY KEY,
video_uuid TEXT NOT NULL,
assigned_to TEXT NOT NULL,
description TEXT,
start_frame INTEGER NOT NULL,
end_frame INTEGER NOT NULL,
status TEXT,
create_time_ms INTEGER,
FOREIGN KEY (video_uuid) REFERENCES videos (video_uuid) ON DELETE CASCADE
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS class_labels (
label_id INTEGER PRIMARY KEY AUTOINCREMENT,
label_name TEXT NOT NULL UNIQUE,
create_time_ms INTEGER
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS class_tags (
tag_id INTEGER PRIMARY KEY AUTOINCREMENT,
tag_name TEXT NOT NULL UNIQUE,
create_time_ms INTEGER
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS frame_labels (
frame_id INTEGER NOT NULL,
label_name TEXT NOT NULL,
FOREIGN KEY (frame_id) REFERENCES video_frames (frame_id) ON DELETE CASCADE
)
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_frame_labels_name ON frame_labels (label_name)')
conn.commit()
conn.close()
def create_video_entry(description, video_filename, file_size, create_time_ms):
conn = get_db_connection()
video_uuid = str(uuid.uuid4().hex)
conn.execute(
'INSERT INTO videos (video_uuid, description, video_filename, file_size, create_time_ms, status) VALUES (?, ?, ?, ?, ?, ?)',
(video_uuid, description, video_filename, file_size, create_time_ms, 'UPLOADING')
)
conn.commit()
conn.close()
return video_uuid
def get_ready_videos_with_labels():
conn = get_db_connection()
videos = conn.execute(
'SELECT * FROM videos WHERE status = "READY" AND labeled_frame_count > 0 ORDER BY create_time_ms DESC').fetchall()
conn.close()
return [dict(row) for row in videos]
def get_all_video_list():
conn = get_db_connection()
videos = conn.execute('SELECT * FROM videos ORDER BY create_time_ms DESC').fetchall()
conn.close()
return [dict(row) for row in videos]
def get_video_entity(video_uuid):
conn = get_db_connection()
video = conn.execute('SELECT * FROM videos WHERE video_uuid = ?', (video_uuid,)).fetchone()
conn.close()
return dict(video) if video else None
def update_video_status(video_uuid, status, message=""):
conn = get_db_connection()
conn.execute('UPDATE videos SET status = ?, status_message = ? WHERE video_uuid = ?', (status, message, video_uuid))
conn.commit()
conn.close()
def update_pre_annotation_info(video_uuid, model_uuid, model_desc):
conn = get_db_connection()
info = {
"model_uuid": model_uuid,
"model_desc": model_desc,
"time_ms": int(time.time() * 1000)
}
conn.execute('UPDATE videos SET last_pre_annotation_info = ? WHERE video_uuid = ?', (json.dumps(info), video_uuid))
conn.commit()
conn.close()
def update_video_after_extraction_start(video_uuid, width, height, fps, frame_count):
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
'UPDATE videos SET width=?, height=?, fps=?, frame_count=?, included_frame_count=?, status=? WHERE video_uuid=?',
(width, height, fps, frame_count, frame_count, 'EXTRACTING', video_uuid)
)
frames_to_insert = [(video_uuid, i, '', '', '', 1) for i in range(frame_count)]
cursor.executemany(
'INSERT INTO video_frames (video_uuid, frame_number, bboxes_text, suggested_bboxes_text, tags, include_frame_in_dataset) VALUES (?, ?, ?, ?, ?, ?)',
frames_to_insert
)
conn.commit()
conn.close()
def update_extracted_frame_count(video_uuid, count):
conn = get_db_connection()
conn.execute('UPDATE videos SET extracted_frame_count = ? WHERE video_uuid = ?', (count, video_uuid))
conn.commit()
conn.close()
def delete_video(video_uuid):
conn = get_db_connection()
conn.execute('DELETE FROM videos WHERE video_uuid = ?', (video_uuid,))
conn.commit()
conn.close()
def get_video_frames(video_uuid):
conn = get_db_connection()
frames = conn.execute('SELECT * FROM video_frames WHERE video_uuid = ? ORDER BY frame_number ASC',
(video_uuid,)).fetchall()
conn.close()
return [dict(row) for row in frames]
def save_frame_bboxes(video_uuid, frame_number, bboxes_text):
conn = get_db_connection()
cursor = conn.cursor()
frame = cursor.execute(
'SELECT frame_id FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
(video_uuid, frame_number)
).fetchone()
if not frame:
conn.close()
logging.error(f"无法为 video {video_uuid}, frame {frame_number} 找到 frame_id。")
return
frame_id = frame['frame_id']
cursor.execute(
'UPDATE video_frames SET bboxes_text = ?, suggested_bboxes_text = ? WHERE frame_id = ?',
(bboxes_text, '', frame_id)
)
cursor.execute('DELETE FROM frame_labels WHERE frame_id = ?', (frame_id,))
unique_labels_in_frame = set(extract_labels(bboxes_text))
if unique_labels_in_frame:
labels_to_insert = [(frame_id, label) for label in unique_labels_in_frame]
cursor.executemany(
'INSERT INTO frame_labels (frame_id, label_name) VALUES (?, ?)',
labels_to_insert
)
new_labeled_count = cursor.execute(
"SELECT COUNT(*) FROM video_frames WHERE video_uuid = ? AND ((bboxes_text IS NOT NULL AND bboxes_text != '') OR (tags IS NOT NULL AND tags != '[]' AND tags != ''))",
(video_uuid,)
).fetchone()[0]
cursor.execute(
'UPDATE videos SET labeled_frame_count = ? WHERE video_uuid = ?',
(new_labeled_count, video_uuid)
)
conn.commit()
conn.close()
def save_frame_suggestions(video_uuid, frame_number, suggested_bboxes_text):
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
'UPDATE video_frames SET suggested_bboxes_text = ? WHERE video_uuid = ? AND frame_number = ?',
(suggested_bboxes_text, video_uuid, frame_number)
)
conn.commit()
conn.close()
def save_frame_tags(video_uuid, frame_number, tags_json_string):
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
'UPDATE video_frames SET tags = ? WHERE video_uuid = ? AND frame_number = ?',
(tags_json_string, video_uuid, frame_number)
)
new_labeled_count = cursor.execute(
"SELECT COUNT(*) FROM video_frames WHERE video_uuid = ? AND ((bboxes_text IS NOT NULL AND bboxes_text != '') OR (tags IS NOT NULL AND tags != '[]' AND tags != ''))",
(video_uuid,)
).fetchone()[0]
cursor.execute(
'UPDATE videos SET labeled_frame_count = ? WHERE video_uuid = ?',
(new_labeled_count, video_uuid)
)
conn.commit()
conn.close()
def add_frames_from_upload(video_uuid, frame_files):
conn = get_db_connection()
cursor = conn.cursor()
try:
video_data = cursor.execute('SELECT frame_count FROM videos WHERE video_uuid = ?', (video_uuid,)).fetchone()
if not video_data:
raise ValueError("Video UUID not found in database.")
start_frame_number = video_data['frame_count']
frames_to_insert = []
for i, file_storage_obj in enumerate(frame_files):
new_frame_number = start_frame_number + i
image_bytes = file_storage_obj.read()
file_storage.save_frame_image(video_uuid, new_frame_number, image_bytes)
frames_to_insert.append((video_uuid, new_frame_number, '', '', '', 1))
if frames_to_insert:
cursor.executemany(
'INSERT INTO video_frames (video_uuid, frame_number, bboxes_text, suggested_bboxes_text, tags, include_frame_in_dataset) VALUES (?, ?, ?, ?, ?, ?)',
frames_to_insert
)
new_total_frames = start_frame_number + len(frames_to_insert)
cursor.execute(
'UPDATE videos SET frame_count = ?, extracted_frame_count = ?, included_frame_count = ? WHERE video_uuid = ?',
(new_total_frames, new_total_frames, new_total_frames, video_uuid)
)
conn.commit()
return len(frames_to_insert)
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
def create_annotation_task(video_uuid, assigned_to, description, start_frame, end_frame):
conn = get_db_connection()
cursor = conn.cursor()
existing_tasks = cursor.execute(
'SELECT start_frame, end_frame FROM annotation_tasks WHERE video_uuid = ?', (video_uuid,)
).fetchall()
for task in existing_tasks:
if start_frame <= task['end_frame'] and end_frame >= task['start_frame']:
conn.close()
raise ValueError(f"Frame range overlaps with an existing task ({task['start_frame']}-{task['end_frame']}).")
task_uuid = str(uuid.uuid4().hex)
create_time_ms = int(time.time() * 1000)
cursor.execute(
'''INSERT INTO annotation_tasks
(task_uuid, video_uuid, assigned_to, description, start_frame, end_frame, status, create_time_ms)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)''',
(task_uuid, video_uuid, assigned_to, description, start_frame, end_frame, 'PENDING', create_time_ms)
)
conn.commit()
conn.close()
return task_uuid
def get_tasks_for_video(video_uuid):
conn = get_db_connection()
tasks = conn.execute('SELECT * FROM annotation_tasks WHERE video_uuid = ? ORDER BY create_time_ms DESC',
(video_uuid,)).fetchall()
conn.close()
return [dict(row) for row in tasks]
def get_task_entity(task_uuid):
conn = get_db_connection()
task = conn.execute('SELECT * FROM annotation_tasks WHERE task_uuid = ?', (task_uuid,)).fetchone()
conn.close()
return dict(task) if task else None
def delete_task(task_uuid):
conn = get_db_connection()
conn.execute('DELETE FROM annotation_tasks WHERE task_uuid = ?', (task_uuid,))
conn.commit()
conn.close()
def update_task_status(task_uuid, status):
conn = get_db_connection()
conn.execute('UPDATE annotation_tasks SET status = ? WHERE task_uuid = ?', (status, task_uuid))
conn.commit()
conn.close()
def add_class_label(label_name):
conn = get_db_connection()
try:
create_time_ms = int(time.time() * 1000)
conn.execute('INSERT INTO class_labels (label_name, create_time_ms) VALUES (?, ?)',
(label_name, create_time_ms))
conn.commit()
except sqlite3.IntegrityError:
pass
finally:
conn.close()
def get_all_class_labels():
conn = get_db_connection()
labels = conn.execute('SELECT label_name FROM class_labels ORDER BY label_name ASC').fetchall()
conn.close()
return [row['label_name'] for row in labels]
def delete_class_label(label_name):
conn = get_db_connection()
conn.execute('DELETE FROM class_labels WHERE label_name = ?', (label_name,))
conn.commit()
conn.close()
def get_all_frames_with_class(class_name):
conn = get_db_connection()
query = """
SELECT T1.*, T2.width, T2.height
FROM video_frames AS T1
INNER JOIN videos AS T2 ON T1.video_uuid = T2.video_uuid
INNER JOIN frame_labels AS T3 ON T1.frame_id = T3.frame_id
WHERE T3.label_name = ? \
"""
frames = conn.execute(query, (class_name,)).fetchall()
conn.close()
return [dict(row) for row in frames]
def add_class_tag(tag_name):
conn = get_db_connection()
try:
create_time_ms = int(time.time() * 1000)
conn.execute('INSERT INTO class_tags (tag_name, create_time_ms) VALUES (?, ?)', (tag_name, create_time_ms))
conn.commit()
except sqlite3.IntegrityError:
pass
finally:
conn.close()
def get_all_class_tags():
conn = get_db_connection()
tags = conn.execute('SELECT tag_name FROM class_tags ORDER BY tag_name ASC').fetchall()
conn.close()
return [row['tag_name'] for row in tags]
def delete_class_tag(tag_name):
conn = get_db_connection()
conn.execute('DELETE FROM class_tags WHERE tag_name = ?', (tag_name,))
conn.commit()
conn.close()
def create_dataset_entry(description, video_uuids, create_time_ms, eval_percent, test_percent):
conn = get_db_connection()
dataset_uuid = str(uuid.uuid4().hex)
conn.execute(
'INSERT INTO datasets (dataset_uuid, description, video_uuids, create_time_ms, status, eval_percent, test_percent) VALUES (?, ?, ?, ?, ?, ?, ?)',
(dataset_uuid, description, json.dumps(video_uuids), create_time_ms, 'PENDING', eval_percent, test_percent)
)
conn.commit()
conn.close()
return dataset_uuid
def update_dataset_status(dataset_uuid, status, message="", zip_path="", sorted_label_list=None):
conn = get_db_connection()
if sorted_label_list is not None:
conn.execute(
'UPDATE datasets SET status=?, status_message=?, zip_path=?, sorted_label_list=? WHERE dataset_uuid=?',
(status, message, zip_path, json.dumps(sorted_label_list), dataset_uuid)
)
else:
conn.execute(
'UPDATE datasets SET status=?, status_message=?, zip_path=? WHERE dataset_uuid=?',
(status, message, zip_path, dataset_uuid)
)
conn.commit()
conn.close()
def get_dataset_list():
conn = get_db_connection()
datasets = conn.execute('SELECT * FROM datasets ORDER BY create_time_ms DESC').fetchall()
conn.close()
return [dict(row) for row in datasets]
def get_dataset_entity(dataset_uuid):
conn = get_db_connection()
dataset = conn.execute('SELECT * FROM datasets WHERE dataset_uuid = ?', (dataset_uuid,)).fetchone()
conn.close()
return dict(dataset) if dataset else None
def delete_dataset(dataset_uuid):
conn = get_db_connection()
conn.execute('DELETE FROM datasets WHERE dataset_uuid = ?', (dataset_uuid,))
conn.commit()
conn.close()
def import_model_metadata(description, label_filename, model_type, create_time_ms):
conn = get_db_connection()
model_uuid = str(uuid.uuid4().hex)
conn.execute(
'INSERT INTO models (model_uuid, description, label_filename, model_type, create_time_ms) VALUES (?, ?, ?, ?, ?)',
(model_uuid, description, label_filename, model_type, create_time_ms)
)
conn.commit()
conn.close()
return model_uuid
def get_model_list():
conn = get_db_connection()
models = conn.execute('SELECT * FROM models ORDER BY create_time_ms DESC').fetchall()
conn.close()
return [dict(row) for row in models]
def get_model_entity(model_uuid):
conn = get_db_connection()
model = conn.execute('SELECT * FROM models WHERE model_uuid = ?', (model_uuid,)).fetchone()
conn.close()
return dict(model) if model else None
def delete_model(model_uuid):
conn = get_db_connection()
conn.execute('DELETE FROM models WHERE model_uuid = ?', (model_uuid,))
conn.commit()
conn.close()
def get_frame_numbers_for_video(video_uuid):
conn = get_db_connection()
frames = conn.execute('SELECT frame_number FROM video_frames WHERE video_uuid = ? ORDER BY frame_number ASC',
(video_uuid,)).fetchall()
conn.close()
return [row['frame_number'] for row in frames]