237 lines
8.2 KiB
Python
Raw Permalink Normal View History

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from dataclasses import dataclass, field
from typing import Optional, ClassVar, List, Tuple
import numpy as np
from numpy.typing import NDArray
from PIL import Image
import librosa
import librosa.util
import cv2
import os
from openai import OpenAI
import base64
# import ray # 注释掉Ray导入
# # 使用本地模式初始化Ray避免分布式通信问题
# ray.init(local_mode=True, ignore_reinit_error=True)
# 设置环境变量,禁用在线检查
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
@dataclass(frozen=True)
class VideoAsset:
name: str
num_frames: int = -1
_NAME_TO_FILE: ClassVar[dict[str, str]] = {
"baby_reading": "sample_demo_1.mp4",
}
@property
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]
@property
def pil_images(self) -> list[Image.Image]:
video_path = download_video_asset(self.filename)
return video_to_pil_images_list(video_path, self.num_frames)
@property
def np_ndarrays(self) -> NDArray:
video_path = download_video_asset(self.filename)
return video_to_ndarrays(video_path, self.num_frames)
def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
video_path = download_video_asset(self.filename)
return librosa.load(video_path, sr=sampling_rate)[0]
@dataclass(frozen=True)
class LocalVideoAsset:
local_path: str
name: str = "local_video"
num_frames: int = -1
@property
def filename(self) -> str:
return self.local_path
@property
def pil_images(self) -> list[Image.Image]:
return video_to_pil_images_list(self.filename, self.num_frames)
@property
def np_ndarrays(self) -> NDArray:
return video_to_ndarrays(self.filename, self.num_frames)
def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
try:
if not os.path.exists(self.filename):
print(f"音频文件不存在: {self.filename}")
return np.zeros(1) # 返回空数组
return librosa.load(self.filename, sr=sampling_rate)[0]
except Exception as e:
print(f"加载音频时出错: {e}")
return np.zeros(1) # 出错时返回空数组
# 辅助函数实现
def download_video_asset(filename: str) -> str:
# 如果路径是绝对路径或相对路径,直接返回
if filename.startswith("/") or filename.startswith("./"):
return filename
# 否则执行下载逻辑(原实现)
return f"/path/to/downloaded/{filename}"
def video_to_pil_images_list(video_path: str, num_frames: int) -> list[Image.Image]:
"""将视频转换为PIL图像列表"""
if not os.path.exists(video_path):
print(f"视频文件不存在: {video_path}")
return []
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"无法打开视频: {video_path}")
return []
# 获取视频帧数
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames / fps if fps > 0 else 0
print(f"视频信息: 总帧数={total_frames}, FPS={fps:.2f}, 时长={duration:.2f}")
# 如果指定了帧数,设置采样间隔;否则读取所有帧
if num_frames > 0 and num_frames < total_frames:
frame_interval = total_frames / num_frames
print(f"将提取 {num_frames} 帧,采样间隔为每 {frame_interval:.2f}")
else:
frame_interval = 1
num_frames = total_frames
print(f"将提取所有 {total_frames}")
pil_images = []
frame_count = 0
success = True
last_progress = -1
while success and len(pil_images) < num_frames:
# 读取下一帧
success, frame = cap.read()
if not success:
break
# 按间隔采样帧
if frame_count % max(1, int(frame_interval)) == 0:
# OpenCV使用BGR转为RGB
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转为PIL图像
pil_image = Image.fromarray(rgb_frame)
pil_images.append(pil_image)
# 显示进度每10%显示一次)
progress = int(len(pil_images) / num_frames * 10)
if progress > last_progress:
print(f"提取进度: {len(pil_images)}/{num_frames} ({len(pil_images)/num_frames*100:.1f}%)")
last_progress = progress
frame_count += 1
cap.release()
print(f"从视频中共提取了 {len(pil_images)}")
return pil_images
def video_to_ndarrays(video_path: str, num_frames: int) -> NDArray:
"""将视频转换为NumPy数组"""
pil_images = video_to_pil_images_list(video_path, num_frames)
if not pil_images:
print(f"未能从视频中提取帧: {video_path}")
return np.zeros((1, 224, 224, 3))
# 将PIL图像列表转换为NumPy数组
arrays = []
for img in pil_images:
# 调整图像大小为统一尺寸
img_resized = img.resize((224, 224))
# 转换为NumPy数组
arr = np.array(img_resized)
arrays.append(arr)
# 堆叠为单个NumPy数组形状为[num_frames, height, width, channels]
stacked_array = np.stack(arrays, axis=0)
print(f"NumPy数组形状: {stacked_array.shape}")
return stacked_array
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def encode_audio(audio_path):
with open(audio_path, "rb") as audio_file:
return base64.b64encode(audio_file.read()).decode("utf-8")
def encode_npy(npy_path):
data = np.load(npy_path)
data_bytes = data.tobytes()
return base64.b64encode(data_bytes).decode('utf-8')
# 使用示例
if __name__ == "__main__":
#使用本地视频资源
local_video = LocalVideoAsset(
local_path="/root/autodl-tmp/hot_video_analyse/source/sample_demo_1.mp4",
num_frames = -1 # 限制帧数以加快测试速度
)
print("本地资源:", local_video.filename)
#获取PIL图像列表实际会调用download_video_asset和转换函数
pil_images = local_video.pil_images
print("PIL图像数量:", len(pil_images))
print(pil_images[0])
#获取NumPy数组
print("\n=== 加载视频帧 ===")
np_arrays = local_video.np_ndarrays
print(f"视频数组形状: {np_arrays.shape}")
scene_change_arrays = np.load('/root/autodl-tmp/hot_video_analyse/source/scene_change/scene_change_arrays.npy')
print(f"场景变化帧数组形状: {scene_change_arrays.shape}")
print(len(scene_change_arrays))
#获取音频数据
audio = local_video.get_audio(sampling_rate=16000)
print("音频数据形状:", audio.shape)
print(type(audio))
# 指定每段视频的最大帧数为 4。可以调整此值。
llm = LLM("/root/autodl-tmp/llm", limit_mm_per_prompt={"image": 1})
# 创建请求负载。
#video_frames = scene_change_arrays # load your video making sure it only has the number of frames specified earlier.
video_frames = scene_change_arrays # 加载视频,确保帧数不超过之前指定的数量。
# first place, in data preprocessing
audios, images, videos = process_mm_info(conversations, use_audio_in_video=True)
outputs = llm.generate(
[
{
# 对于视频/图像数据,需要包含 <video> 或 <image> 占位符
"prompt": "USER: <video>\n描述这个视频的内容。考虑这些帧都是同一个视频的不同场景。\nASSISTANT:",
"multi_modal_data": {"video": video_frames},
},
{
# 对于音频数据,需要包含 <audio> 占位符
"prompt": "USER: <audio>\n描述一下这个音频的内容,包括背景音乐、声音效果等。\nASSISTANT:",
"multi_modal_data": {"audio": audio},
}
]
)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)