from typing import Any, Dict

import torch
from transformers import AutoModel, AutoProcessor
|
class EndpointHandler:
    """Custom inference handler that turns a text prompt into audio with suno/bark."""

    def __init__(self, path: str = ""):
        """Load the Bark processor and model.

        Args:
            path: Unused; kept for compatibility with the endpoint-handler
                interface, which always passes the model repository path.
        """
        # Fall back to CPU so the handler also works on GPU-less hosts;
        # behavior is unchanged when CUDA is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained("suno/bark")
        self.model = AutoModel.from_pretrained("suno/bark").to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                ``"inputs"`` holds the text prompt; ``"voice_preset"`` is an
                optional Bark speaker preset.

        Returns:
            A dict with the generated waveform as a nested list of floats
            under the key ``"generated_audio"``.
        """
        # NOTE(review): if "inputs" is absent this falls back to the whole
        # payload dict — presumably the upstream server guarantees the key;
        # confirm against the caller.
        text = data.pop("inputs", data)
        voice_preset = data.get("voice_preset", None)

        # Build the processor kwargs once instead of duplicating the call;
        # only forward voice_preset when the client actually supplied one.
        processor_kwargs: Dict[str, Any] = {
            "text": [text],
            "return_tensors": "pt",
        }
        if voice_preset:
            processor_kwargs["voice_preset"] = voice_preset
        inputs = self.processor(**processor_kwargs).to(self.device)

        # Mixed-precision autocast is only meaningful (and safe) on CUDA.
        if self.device == "cuda":
            with torch.autocast("cuda"):
                outputs = self.model.generate(**inputs)
        else:
            outputs = self.model.generate(**inputs)

        # Move to host memory and convert to plain Python lists so the
        # result is JSON-serializable.
        prediction = outputs.cpu().numpy().tolist()

        return {"generated_audio": prediction}