from typing import Dict, Any

import base64
import os
import sys
import tempfile

# Intended to disable flash-attention 2; must be set before transformers is
# imported (it is pulled in by the videollama2 imports below).
os.environ["TRANSFORMERS_NO_FLASH_ATTN_2"] = "1"

# Make the bundled VideoLLaMA2 package importable from the repository root.
sys.path.append('./')

from videollama2 import model_init, mm_infer
from videollama2.utils import disable_torch_init


class EndpointHandler:
    """Inference handler that wraps VideoLLaMA2 for video/image understanding."""

    def __init__(self, path=""):
        # Skip torch's default weight initialization to speed up model loading.
        disable_torch_init()

        # Fall back to the public checkpoint when no local model path is provided.
        self.model_path = path or "DAMO-NLP-SG/VideoLLaMA2-7B-16F"
        self.model, self.processor, self.tokenizer = model_init(self.model_path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expected input format:
        {
            "video": "<base64 string>",   # base64-encoded video file
            "prompt": "natural-language instruction describing the video content"
        }
        or
        {
            "image": "<base64 string>",   # base64-encoded image file
            "prompt": "natural-language instruction describing the image content"
        }
        """
        # Hugging Face endpoints may nest the actual payload under "inputs".
        data = data.get("inputs", data)

        # Pick the modality from whichever media field is present.
        if "video" in data:
            modal = "video"
            file_b64 = data["video"]
        elif "image" in data:
            modal = "image"
            file_b64 = data["image"]
        else:
            return {"error": "The request must contain a 'video' or 'image' field"}

        prompt = data.get("prompt", "Describe the content.")

        # Accept either a server-local file path or a base64-encoded payload.
        if os.path.exists(file_b64):
            tmp_path = file_b64
            cleanup = False
        else:
            # Decode the base64 payload into a temporary file the processor can
            # read; it is removed after inference.
            suffix = ".mp4" if modal == "video" else ".png"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_file.write(base64.b64decode(file_b64))
                tmp_path = tmp_file.name
            cleanup = True

        try:
            # Preprocess the media file, then run multimodal inference.
            inputs = self.processor[modal](tmp_path)
            output = mm_infer(
                inputs,
                prompt,
                model=self.model,
                tokenizer=self.tokenizer,
                do_sample=False,
                modal=modal,
            )
        finally:
            # Remove the temporary file created from the base64 payload.
            if cleanup:
                os.remove(tmp_path)

        return {
            "modal": modal,
            "prompt": prompt,
            "result": output,
        }
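

# Minimal local smoke test: a sketch, not part of the serving contract. It
# assumes a small video named "sample.mp4" sits next to this script.
if __name__ == "__main__":
    with open("sample.mp4", "rb") as f:
        payload = {
            "inputs": {
                "video": base64.b64encode(f.read()).decode("utf-8"),
                "prompt": "Describe what happens in this clip.",
            }
        }
    handler = EndpointHandler()
    print(handler(payload))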