| import os |
| import random |
| import io |
| import av |
| import cv2 |
| import decord |
| import imageio |
| from decord import VideoReader |
| import torch |
| import numpy as np |
| import math |
| import torch.nn.functional as F |
# Switch decord's output bridge to torch; read_frames_decord calls
# `.permute` on the batch returned by VideoReader.get_batch, which needs a
# torch tensor rather than a decord NDArray.
decord.bridge.set_bridge("torch")

from transformers import AutoConfig, AutoModel

# NOTE(review): hard-coded cluster path, and the model is loaded at import
# time, so importing this module requires the checkpoint to be reachable.
_CHECKPOINT_DIR = "/fs-computility/video/heyinan/iv2hf/"
config = AutoConfig.from_pretrained(_CHECKPOINT_DIR, trust_remote_code=True)
model = AutoModel.from_pretrained(_CHECKPOINT_DIR, trust_remote_code=True).to(config.device)
|
|
|
|
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):
    """Choose which frame indices to decode from a video.

    Args:
        num_frames: Number of indices to return for 'rand' / 'middle'
            sampling. If the clip has fewer frames than this, the last
            chosen index is repeated until the list has ``num_frames`` items.
        vlen: Total number of frames in the video.
        sample: Sampling strategy:
            - 'rand': one random frame from each of ``num_frames``
              equal-width intervals of the clip;
            - 'middle': the middle frame of each interval (or
              ``interval_start + fix_start`` when ``fix_start`` is given);
            - 'fps<N>' (e.g. 'fps1'): N frames per second, taken at the
              centre of each 1/N-second slot of the clip.
        fix_start: Offset added to each interval start instead of taking the
            middle frame. Ignored for 'rand'. Not clamped to the interval.
        input_fps: Frame rate of the video; converts seconds to frames.
        max_num_frames: For 'fps<N>' only, keep at most this many indices;
            values <= 0 mean no limit.
        start, end: Optional clip boundaries in seconds, clamped to
            ``[0, vlen]`` after conversion to frame indices.

    Returns:
        A list of integer frame indices (Python or NumPy ints depending on
        the branch) in non-decreasing order.

    Raises:
        ValueError: if the clip is empty after applying ``start``/``end``,
            or if ``sample`` is not a recognised strategy.
    """
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, clip_length)
        # Split [start_frame, end_frame) into acc_samples contiguous intervals,
        # each stored as an inclusive (lo, hi) pair.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = [(lo, nxt - 1) for lo, nxt in zip(intervals[:-1], intervals[1:])]

        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(lo, hi + 1)) for lo, hi in ranges]
            except IndexError:
                # Defensive: should float rounding in linspace -> int truncation
                # leave an interval empty, random.choice raises IndexError; fall
                # back to a sorted random subset of the whole clip.
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [lo + fix_start for lo, _ in ranges]
        else:  # sample == 'middle'
            frame_indices = [(lo + hi) // 2 for lo, hi in ranges]

        # Short clips yield fewer than num_frames indices; pad by repeating
        # the last one so callers always get num_frames entries.
        if len(frame_indices) < num_frames:
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif sample.startswith("fps"):
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps
        # Timestamps (seconds from clip start) at the centre of each slot.
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        # Rounding can push the last timestamp past the clip; drop those.
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unsupported sample strategy: {sample!r}")
    return frame_indices
|
|
def read_frames_decord(
    video_path, num_frames, sample='middle', fix_start=None,
    max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    """Decode a sampled subset of frames from a video file with decord.

    Args:
        video_path: Path of the video file to open.
        num_frames: Number of frames requested; forwarded to get_frame_indices.
        sample: Sampling strategy forwarded to get_frame_indices
            ('rand', 'middle', or 'fps<N>').
        fix_start: Optional per-interval offset, forwarded to get_frame_indices.
        max_num_frames: Cap for 'fps<N>' sampling; <= 0 means no cap.
        client: Unused; kept so existing call sites keep working.
        trimmed30: Unused; kept so existing call sites keep working.
        start, end: Optional clip boundaries in seconds.

    Returns:
        Tuple ``(frames, frame_indices, duration)``:
        - frames: the decoded batch permuted with ``(0, 3, 1, 2)``
          (presumably (T, C, H, W) given decord's channels-last batches and
          the torch bridge; TODO confirm);
        - frame_indices: the indices that were decoded;
        - duration: length of the whole video in seconds (frames / avg fps).
    """
    # NOTE(review): .webm files are forced onto a single decoding thread,
    # presumably to work around a decoder issue; other files use 0
    # (decord's default thread setting).
    thread_count = 1 if video_path.endswith('.webm') else 0
    reader = VideoReader(video_path, num_threads=thread_count)

    total_frames = len(reader)
    avg_fps = reader.get_avg_fps()
    video_seconds = total_frames / float(avg_fps)

    chosen = get_frame_indices(
        num_frames,
        total_frames,
        sample=sample,
        fix_start=fix_start,
        input_fps=avg_fps,
        max_num_frames=max_num_frames,
        start=start,
        end=end,
    )

    # Move the channel axis from last to position 1.
    batch = reader.get_batch(chosen).permute(0, 3, 1, 2)
    return batch, chosen, video_seconds
|
|
def get_text_feature(model, texts):
    """Tokenize ``texts`` and encode them with ``model``'s text encoder.

    ``model`` must provide ``tokenizer`` (whose output supports ``.to``),
    ``device`` and ``encode_text``. The tokenized batch is moved to
    ``model.device`` before encoding; whatever ``encode_text`` returns is
    returned unchanged.
    """
    tokens = model.tokenizer(texts)
    return model.encode_text(tokens.to(model.device))
| |
def get_similarity(video_feature, text_feature):
    """Return the cosine-similarity matrix between text and video embeddings.

    Both inputs are L2-normalised along their last dimension and the result
    is ``text @ video.T``. For 2-D inputs (assumed (count, dim); TODO
    confirm), entry ``[i, j]`` is the cosine similarity of text ``i`` and
    video ``j``.
    """
    unit_text = F.normalize(text_feature, dim=-1)
    unit_video = F.normalize(video_feature, dim=-1)
    return torch.matmul(unit_text, unit_video.T)
|
|
def get_top_videos(model, text_features, video_features, video_paths, texts, top_k=5):
    """Rank videos for each text query by cosine similarity and print the hits.

    Args:
        model: Unused; kept so existing call sites keep working.
        text_features: Text embeddings, one row per entry of ``texts``
            (assumed shape (num_texts, dim); TODO confirm).
        video_features: Video embeddings, one row per entry of
            ``video_paths`` (assumed shape (num_videos, dim)).
        video_paths: Labels for the video rows, in the same order.
        texts: Query strings, in the same order as ``text_features`` rows.
        top_k: Maximum number of hits reported per query. Clamped to the
            number of videos, because ``torch.topk`` raises when ``k``
            exceeds the size of the ranked dimension.

    Returns:
        Dict mapping each text to a list of
        ``{"video": path, "prob": score, "rank": r}`` dicts ordered by rank
        (1 = best). Note that ``"prob"`` holds the raw cosine similarity,
        not a softmax probability.
    """
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # sim_matrix[i, j] = cosine similarity of text i and video j.
    sim_matrix = text_features @ video_features.T

    k = min(top_k, sim_matrix.shape[1])
    top_values, top_indices = torch.topk(sim_matrix, k, dim=1)

    retrieval_infos = {}
    for i, (scores, indices) in enumerate(zip(top_values.tolist(), top_indices.tolist())):
        text = texts[i]
        print("\n", text)
        retrieval_infos[text] = []
        for rank, (score, video_idx) in enumerate(zip(scores, indices), start=1):
            print("top", rank, ":", video_paths[video_idx], "~prob:", score)
            retrieval_infos[text].append({"video": video_paths[video_idx], "prob": score, "rank": rank})
    return retrieval_infos
|
|
if __name__ == "__main__":
    # Demo: embed a few local videos and text queries, then print the
    # best-matching videos for each query.
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']

    # Inference only: disable autograd for both the vision and the text
    # encoder. Previously only vision ran under no_grad, so the text path
    # built an autograd graph nobody used.
    with torch.no_grad():
        video_features = []
        for video_path in demo_videos:
            frames, _, _ = read_frames_decord(video_path, 8)
            frames = model.transform(frames).unsqueeze(0).to(model.device)
            video_features.append(model.encode_vision(frames, test=True))

        text_features = get_text_feature(model, texts)

    # Stack per-video embeddings and match the text embeddings' dtype/device
    # so the similarity matmul operates on compatible tensors.
    video_features = torch.cat(video_features, dim=0).to(
        device=config.device, dtype=text_features.dtype
    )
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
|
|
|
|
|
|