From 78c555569bdc66e34720fb8f6b5faaf56bd7e713 Mon Sep 17 00:00:00 2001 From: Carl Xu Date: Tue, 25 Feb 2025 20:32:08 -0500 Subject: [PATCH 1/4] completed script --- src/mouse_tracker.py | 137 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 src/mouse_tracker.py diff --git a/src/mouse_tracker.py b/src/mouse_tracker.py new file mode 100644 index 0000000..d4f58d6 --- /dev/null +++ b/src/mouse_tracker.py @@ -0,0 +1,137 @@ +import math +import numpy as np +import cv2 +import pandas as pd +import template_matcher +import os +import glob + +def extract_frames(video_path, start_time, end_time): + video_capture = cv2.VideoCapture(video_path) + fps = video_capture.get(cv2.CAP_PROP_FPS) + start_frame = int(start_time * fps) + end_frame = int(end_time * fps) + interval = int(0.1 * fps) # 100ms interval + + grayscale_frames = [] # List to store grayscale frames + current_frame = start_frame + video_capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame) + while current_frame <= end_frame: + ret, frame = video_capture.read() + if not ret: + break + # Convert frame to grayscale + gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + binary_frame = template_matcher.binarize_image(gray_frame) + grayscale_frames.append(binary_frame) + current_frame += interval + video_capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame) + video_capture.release() + return np.array(grayscale_frames) # Return an array of grayscale frames + +def closest_point(point, pt_list): + x, y = point + dists = [] + for pt_x, pt_y in pt_list: + dists.append(int(round(math.sqrt((pt_x - x)**2 + (pt_y - y)**2)))) + return pt_list[dists.index(min(dists))] + +def load_templates(templates_folder, white_threshold): + images = [] + image_patterns = ["*.jp*g", "*.png", "*.bmp"] + + image_paths = [] + for pattern in image_patterns: + image_paths.extend(glob.glob(os.path.join(templates_folder, pattern))) + + for img_path in image_paths: + image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) + if image is not None: + images.append((img_path, image)) + else: + print(f"Warning: Failed to load image '{img_path}'.") + if not images: + print("Error: No valid template images loaded.") + return [(path, template_matcher.binarize_image(img, white_threshold)) for path, img in images] + +def get_cursor_loc(frames, idx, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement): + frame = frames[idx] + detected_regions = [] + while not detected_regions: + for template_path, template in templates: + result = cv2.matchTemplate(frame, template, cv2.TM_CCOEFF_NORMED) + locations = np.where(result >= detection_threshold) + + for pt in zip(*locations[::-1]): + # w, h = template.shape[::-1] + # detected_regions.append((pt[0], pt[1], w, h, template_path)) + detected_regions.append((int(pt[0]), int(pt[1]))) + if detection_threshold <= break_threshold: + break + detection_threshold -= detection_threshold_decrement + if len(detected_regions) == 1: # one + return detected_regions[0][0], detected_regions[0][1] + elif not detected_regions: # none + return prev_x, prev_y + else: # more than one -> pick the closest + return closest_point((prev_x, prev_y), detected_regions) + +def make_tr_dict(frames, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement): + tr_dist_dict = {} + # if only 1 match, then get x and y, if none, decrease threshold until 1 match, if multiple matches then discard/infer from previous + frame_counter = 1 + for tr in range(600): # 15*60/1.5 = 600 TR + dists = [] + for ms in range(15): # 15 units of 100ms in each TR + curr_x, curr_y = get_cursor_loc(frames, frame_counter, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement) + dists.append(int(round(math.sqrt((curr_x - prev_x)**2 + (curr_y - prev_y)**2)))) + prev_x, prev_y = curr_x, curr_y + frame_counter += 1 + + # store the data by hashmapping by TR (like 1: (sum of dists for 15 intervals each interval being 100ms)) + tr_dist_dict[tr+1] = sum(dists) + return tr_dist_dict + +def dict_to_csv(tr_dist_dict, output_filepath): + df = pd.DataFrame.from_dict(tr_dist_dict, orient='index', columns=["Distance"]) + df.index.name = "TR" + df.to_csv(output_filepath) + +def main(): + # paths + input_filepath = 'Web47_run1.mkv' + if not os.path.exists(input_filepath): + print(f"File '{input_filepath}' doesn't exist.") + exit() + output_filepath = 'tr_dist_dict_10_50_new.csv' + if os.path.exists(output_filepath): + print(f"File '{output_filepath}' exists.") + exit() + templates_folderpath = 'templates' + if not os.path.exists(templates_folderpath): + print(f"File '{templates_folderpath}' doesn't exist.") + exit() + + # thresholds + detection_threshold = 0.9 + detection_threshold_decrement = 0.025 + break_threshold = 0.8 + white_threshold = 200 + + # extract frames + frames = extract_frames(input_filepath, 25, 925) + + # get templates + templates = load_templates(templates_folderpath, white_threshold) + + # get initial frame + prev_x, prev_y = get_cursor_loc(frames, 0, templates, 750, 400, detection_threshold, break_threshold, detection_threshold_decrement) + + # make dict + tr_dist_dict = make_tr_dict(frames, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement) + + # save dict + dict_to_csv(tr_dist_dict, output_filepath) + +if __name__ == "__main__": + main() \ No newline at end of file From dabfe5e55951fdc7a8b80fc9c0c09db949c1d3c1 Mon Sep 17 00:00:00 2001 From: Carl Xu Date: Mon, 10 Mar 2025 12:15:41 -0400 Subject: [PATCH 2/4] updated mouse_tracker to manual cursor track for frames that didnt detect a mouse --- src/mouse_tracker.py | 107 ++++++++++++++++++++++++++-------------- src/template_matcher.py | 2 +- 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/src/mouse_tracker.py b/src/mouse_tracker.py index d4f58d6..8420d3c 100644 --- a/src/mouse_tracker.py +++ b/src/mouse_tracker.py @@ -5,29 +5,32 @@ import template_matcher import os import glob +import matplotlib.pyplot as plt -def extract_frames(video_path, start_time, end_time): +def extract_frames(video_path, white_threshold, start_time, end_time): video_capture = cv2.VideoCapture(video_path) fps = video_capture.get(cv2.CAP_PROP_FPS) start_frame = int(start_time * fps) end_frame = int(end_time * fps) interval = int(0.1 * fps) # 100ms interval - grayscale_frames = [] # List to store grayscale frames + rgb_frames = [] # List to store rgb frames + binary_frames = [] # List to store grayscale frames current_frame = start_frame video_capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame) while current_frame <= end_frame: ret, frame = video_capture.read() if not ret: break + rgb_frames.append(frame) # Convert frame to grayscale gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - binary_frame = template_matcher.binarize_image(gray_frame) - grayscale_frames.append(binary_frame) + binary_frame = template_matcher.binarize_image(gray_frame, white_threshold) + binary_frames.append(binary_frame) current_frame += interval video_capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame) video_capture.release() - return np.array(grayscale_frames) # Return an array of grayscale frames + return np.array(binary_frames), np.array(rgb_frames) # Return an array of grayscale frames def closest_point(point, pt_list): x, y = point @@ -54,36 +57,54 @@ def load_templates(templates_folder, white_threshold): print("Error: No valid template images loaded.") return [(path, template_matcher.binarize_image(img, white_threshold)) for path, img in images] -def get_cursor_loc(frames, idx, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement): - frame = frames[idx] +def pick_pixel(frame): + selected_pixel = (-1, -1) + def on_click(event): + global selected_pixel + if event.xdata is not None and event.ydata is not None: + col = int(round(event.xdata)) + row = int(round(event.ydata)) + if 0 <= row < 768 and 0 <= col < 1280: + selected_pixel = (row, col) + plt.close() + fig, ax = plt.subplots() + ax.imshow(frame) + fig.canvas.mpl_connect('button_press_event', on_click) + plt.show() + return selected_pixel + +def get_cursor_loc(binary_frames, rgb_frames, idx, templates, prev_x, prev_y, detection_threshold): + frame = binary_frames[idx] detected_regions = [] - while not detected_regions: - for template_path, template in templates: - result = cv2.matchTemplate(frame, template, cv2.TM_CCOEFF_NORMED) - locations = np.where(result >= detection_threshold) - - for pt in zip(*locations[::-1]): - # w, h = template.shape[::-1] - # detected_regions.append((pt[0], pt[1], w, h, template_path)) - detected_regions.append((int(pt[0]), int(pt[1]))) - if detection_threshold <= break_threshold: - break - detection_threshold -= detection_threshold_decrement + for template_path, template in templates: + result = cv2.matchTemplate(frame, template, cv2.TM_CCOEFF_NORMED) + locations = np.where(result >= detection_threshold) + + for pt in zip(*locations[::-1]): + detected_regions.append((int(pt[0]), int(pt[1]))) if len(detected_regions) == 1: # one return detected_regions[0][0], detected_regions[0][1] - elif not detected_regions: # none - return prev_x, prev_y - else: # more than one -> pick the closest - return closest_point((prev_x, prev_y), detected_regions) + else: # none or multiple + # display frame + # option to pick no visible cursor + # return cursor location or previous cursor location + curr_x, curr_y = pick_pixel(rgb_frames[idx]) + if curr_x != -1 and curr_y != -1: + return curr_x, curr_y + else: + return prev_x, prev_y + # else: # more than one -> pick the closest + # return closest_point((prev_x, prev_y), detected_regions) -def make_tr_dict(frames, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement): +def make_tr_dict(binary_frames, rgb_frames, templates, prev_x, prev_y, detection_threshold): tr_dist_dict = {} # if only 1 match, then get x and y, if none, decrease threshold until 1 match, if multiple matches then discard/infer from previous frame_counter = 1 - for tr in range(600): # 15*60/1.5 = 600 TR + #for tr in range(600): # 15*60/1.5 = 600 TR + for tr in range(10): # 15*60/1.5 = 600 TR dists = [] for ms in range(15): # 15 units of 100ms in each TR - curr_x, curr_y = get_cursor_loc(frames, frame_counter, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement) + curr_x, curr_y = get_cursor_loc(binary_frames, rgb_frames, frame_counter, templates, prev_x, prev_y, detection_threshold) dists.append(int(round(math.sqrt((curr_x - prev_x)**2 + (curr_y - prev_y)**2)))) prev_x, prev_y = curr_x, curr_y frame_counter += 1 @@ -97,41 +118,51 @@ def dict_to_csv(tr_dist_dict, output_filepath): df.index.name = "TR" df.to_csv(output_filepath) -def main(): +def main(input_filepath, output_filepath, templates_folderpath, start_time, end_time): # paths - input_filepath = 'Web47_run1.mkv' if not os.path.exists(input_filepath): print(f"File '{input_filepath}' doesn't exist.") exit() - output_filepath = 'tr_dist_dict_10_50_new.csv' if os.path.exists(output_filepath): print(f"File '{output_filepath}' exists.") exit() - templates_folderpath = 'templates' if not os.path.exists(templates_folderpath): print(f"File '{templates_folderpath}' doesn't exist.") exit() # thresholds - detection_threshold = 0.9 - detection_threshold_decrement = 0.025 - break_threshold = 0.8 - white_threshold = 200 + detection_threshold = 0.85 + white_threshold = 127 # extract frames - frames = extract_frames(input_filepath, 25, 925) + binary_frames, rgb_frames = extract_frames(input_filepath, white_threshold, start_time, end_time) # get templates templates = load_templates(templates_folderpath, white_threshold) + # get frame dimensions + dim_y, dim_x = binary_frames[0].shape + # get initial frame - prev_x, prev_y = get_cursor_loc(frames, 0, templates, 750, 400, detection_threshold, break_threshold, detection_threshold_decrement) + prev_x, prev_y = get_cursor_loc(binary_frames, rgb_frames, 0, templates, int(round(dim_x/2)), int(round(dim_y/2)), detection_threshold) # make dict - tr_dist_dict = make_tr_dict(frames, templates, prev_x, prev_y, detection_threshold, break_threshold, detection_threshold_decrement) + tr_dist_dict = make_tr_dict(binary_frames, rgb_frames, templates, prev_x, prev_y, detection_threshold) # save dict dict_to_csv(tr_dist_dict, output_filepath) if __name__ == "__main__": - main() \ No newline at end of file + # input_filepath = 'Web47_run1.mkv' + # output_filepath = 'Web47_1_mouse_dists_selection.csv' + # templates_folderpath = 'binary-image-matching/src/templates' + # start_time = 25 + # end_time = 925 + # main(input_filepath, output_filepath, templates_folderpath, start_time, end_time) + + input_filepath = 'Web08_run2.mkv' + output_filepath = 'Web08_2_mouse_dists_selection.csv' + templates_folderpath = 'binary-image-matching/src/templates' + start_time = 185 + end_time = 1085 + main(input_filepath, output_filepath, templates_folderpath, start_time, 200) \ No newline at end of file diff --git a/src/template_matcher.py b/src/template_matcher.py index e5c2210..63dc43d 100644 --- a/src/template_matcher.py +++ b/src/template_matcher.py @@ -32,7 +32,7 @@ def load_images(path, grayscale=True): return images -def binarize_image(image, threshold=200): +def binarize_image(image, threshold): """ Converts an image to a binary format using a specified threshold. From 21a51a26bd51c50d02559f549229f52e0304902e Mon Sep 17 00:00:00 2001 From: Carl Xu Date: Sun, 6 Apr 2025 16:16:49 -0400 Subject: [PATCH 3/4] update --- src/mouse_tracker.py | 68 +++++++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/src/mouse_tracker.py b/src/mouse_tracker.py index 8420d3c..4ac6e3b 100644 --- a/src/mouse_tracker.py +++ b/src/mouse_tracker.py @@ -15,14 +15,14 @@ def extract_frames(video_path, white_threshold, start_time, end_time): interval = int(0.1 * fps) # 100ms interval rgb_frames = [] # List to store rgb frames - binary_frames = [] # List to store grayscale frames + binary_frames = [] # List to store binary frames current_frame = start_frame video_capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame) while current_frame <= end_frame: ret, frame = video_capture.read() if not ret: break - rgb_frames.append(frame) + rgb_frames.append(frame) # bgr # Convert frame to grayscale gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) binary_frame = template_matcher.binarize_image(gray_frame, white_threshold) @@ -68,42 +68,70 @@ def on_click(event): selected_pixel = (row, col) plt.close() fig, ax = plt.subplots() - ax.imshow(frame) + ax.imshow(frame, cmap='gray') fig.canvas.mpl_connect('button_press_event', on_click) + manager = plt.get_current_fig_manager() + manager.window.state('zoomed') plt.show() return selected_pixel def get_cursor_loc(binary_frames, rgb_frames, idx, templates, prev_x, prev_y, detection_threshold): frame = binary_frames[idx] detected_regions = [] - for template_path, template in templates: - result = cv2.matchTemplate(frame, template, cv2.TM_CCOEFF_NORMED) - locations = np.where(result >= detection_threshold) - - for pt in zip(*locations[::-1]): - detected_regions.append((int(pt[0]), int(pt[1]))) + while detection_threshold >= 0.7: + for template_path, template in templates: + w, h = template.shape[::-1] + result = cv2.matchTemplate(frame, template, cv2.TM_CCOEFF_NORMED) + locations = np.where(result >= detection_threshold) + + for pt in zip(*locations[::-1]): + detected_regions.append((int(pt[0]), int(pt[1]))) + cv2.rectangle(rgb_frames[idx], pt, (pt[0] + w, pt[1] + h), (0,0,255), 1) + if(len(detected_regions) != 0): + break + detection_threshold -= 0.05 if len(detected_regions) == 1: # one return detected_regions[0][0], detected_regions[0][1] - else: # none or multiple - # display frame - # option to pick no visible cursor - # return cursor location or previous cursor location + elif len(detected_regions) > 10 or len(detected_regions) == 0: # large multiple + print(len(detected_regions)) + curr_x, curr_y = pick_pixel(binary_frames[idx]) curr_x, curr_y = pick_pixel(rgb_frames[idx]) if curr_x != -1 and curr_y != -1: return curr_x, curr_y else: return prev_x, prev_y - # else: # more than one -> pick the closest - # return closest_point((prev_x, prev_y), detected_regions) + else: # if multiple, check if regions are close together (within 3 pixels) + # if close, return average + close = True + for idx1 in range(len(detected_regions)): + region = detected_regions[idx1] + for idx2 in range(idx1, len(detected_regions)): + if int(round(math.sqrt((region[idx1] - region[idx2])**2 + (region[idx1] - region[idx2])**2))) > 10: + close = False + break + if close == False: + break + if close: + x_list = [region[0] for region in detected_regions] + y_list = [region[1] for region in detected_regions] + return (int(round(np.mean(x_list))), int(round(np.mean(y_list)))) + else: + print(len(detected_regions)) + curr_x, curr_y = pick_pixel(binary_frames[idx]) + curr_x, curr_y = pick_pixel(rgb_frames[idx]) + if curr_x != -1 and curr_y != -1: + return curr_x, curr_y + else: + return prev_x, prev_y def make_tr_dict(binary_frames, rgb_frames, templates, prev_x, prev_y, detection_threshold): tr_dist_dict = {} # if only 1 match, then get x and y, if none, decrease threshold until 1 match, if multiple matches then discard/infer from previous frame_counter = 1 #for tr in range(600): # 15*60/1.5 = 600 TR - for tr in range(10): # 15*60/1.5 = 600 TR + for tr in range(20): # TESTING dists = [] - for ms in range(15): # 15 units of 100ms in each TR + for ms in range(15): # 15 units of 100ms in each TR (NEED TO CHANGE IF DECREASE INTERVAL) curr_x, curr_y = get_cursor_loc(binary_frames, rgb_frames, frame_counter, templates, prev_x, prev_y, detection_threshold) dists.append(int(round(math.sqrt((curr_x - prev_x)**2 + (curr_y - prev_y)**2)))) prev_x, prev_y = curr_x, curr_y @@ -131,8 +159,8 @@ def main(input_filepath, output_filepath, templates_folderpath, start_time, end_ exit() # thresholds - detection_threshold = 0.85 - white_threshold = 127 + detection_threshold = 0.90 + white_threshold = 207 # extract frames binary_frames, rgb_frames = extract_frames(input_filepath, white_threshold, start_time, end_time) @@ -165,4 +193,4 @@ def main(input_filepath, output_filepath, templates_folderpath, start_time, end_ templates_folderpath = 'binary-image-matching/src/templates' start_time = 185 end_time = 1085 - main(input_filepath, output_filepath, templates_folderpath, start_time, 200) \ No newline at end of file + main(input_filepath, output_filepath, templates_folderpath, start_time, 215) \ No newline at end of file From 244d727bd36e6069936643b042ad0b929e942e30 Mon Sep 17 00:00:00 2001 From: Carl Xu Date: Sun, 6 Apr 2025 16:20:14 -0400 Subject: [PATCH 4/4] update --- src/mouse_tracker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mouse_tracker.py b/src/mouse_tracker.py index 4ac6e3b..ab52e4d 100644 --- a/src/mouse_tracker.py +++ b/src/mouse_tracker.py @@ -68,7 +68,7 @@ def on_click(event): selected_pixel = (row, col) plt.close() fig, ax = plt.subplots() - ax.imshow(frame, cmap='gray') + ax.imshow(frame) fig.canvas.mpl_connect('button_press_event', on_click) manager = plt.get_current_fig_manager() manager.window.state('zoomed') @@ -93,8 +93,6 @@ def get_cursor_loc(binary_frames, rgb_frames, idx, templates, prev_x, prev_y, de if len(detected_regions) == 1: # one return detected_regions[0][0], detected_regions[0][1] elif len(detected_regions) > 10 or len(detected_regions) == 0: # large multiple - print(len(detected_regions)) - curr_x, curr_y = pick_pixel(binary_frames[idx]) curr_x, curr_y = pick_pixel(rgb_frames[idx]) if curr_x != -1 and curr_y != -1: return curr_x, curr_y