Podcast Semantic Search

Transcribe podcast audio with OpenAI Whisper, embed the transcript segments with a sentence-transformer model, and index them by their semantic value in a vector space.

A working example can be found here.

import time
import os
from dotenv import load_dotenv
from tqdm.auto import tqdm
import torch
import whisper
from sentence_transformers import SentenceTransformer
import pinecone
# custom GC Storage Bucket functions
from storage import download_file_to_memory, upload_file_from_memory

load_dotenv()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
trans_model = whisper.load_model("medium", device=device)
emb_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

pinecone.init(
    api_key=os.getenv("PINECONE_DEV_API_KEY"),
    environment=os.getenv("PINECONE_DEV_ENV", "asia-southeast1-gcp-free"),
)
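# (assumption, not in the original script) the index is expected to exist
# already; if it might not, it can be created once with the embedding model's
# output size (384 dimensions for all-MiniLM-L6-v2) and cosine similarity
if os.getenv("PINECONE_DEV_INDEX") not in pinecone.list_indexes():
    pinecone.create_index(
        os.getenv("PINECONE_DEV_INDEX"), dimension=384, metric="cosine"
    )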
index = pinecone.Index(os.getenv("PINECONE_DEV_INDEX"))

# TODO get podcasts
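# whichever feed or API they come from, each item is assumed to be a dict with
# at least the fields used below: "id", "guid", "audioUrl", and "createdAt"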
podcasts = []

for pod in podcasts:
    # check if we've already transcribed
    result = download_file_to_memory("bucket-name", f"{pod['guid']}.json")
    if result is None:
        # transcribe the audio
        try:
            tic = time.perf_counter()
            result = trans_model.transcribe(pod["audioUrl"])
            toc = time.perf_counter()
            print(f"Transcribed in {(toc - tic) / 60:0.4f} minutes")
        except Exception as e:
            print(f"Failed to transcribe {pod['audioUrl']}: {e}")
            continue
        # save the transcription result to storage
        upload_file_from_memory("bucket-name", f"{pod['guid']}.json", result)
    # attach podcast metadata to each transcribed segment
    data = []
    for seg in result["segments"]:
        meta = {
            "id": f"{pod['id']}-${seg['start']}",
            "audio_url": pod["audioUrl"],
            "created_at": pod["createdAt"],
            **seg,
        }
        data.append(meta)
    # TODO combine sentences into groups of 6 with 1 or 2 sentences of overlap for context
    # batch upsert embeds into vector space
    for i in tqdm(range(0, len(data), 64)):
        # find end position of batch, clamped to the end of the data
        i_end = min(len(data), i + 64)
        # extract the metadata like text, start/end positions, etc
        batch_meta = [
            {
                "text": data[x]["text"],
                "start": data[x]["start"],
                "end": data[x]["end"],
                "audio_url": data[x]["audio_url"],
                "created_at": data[x]["created_at"],
            }
            for x in range(i, i_end)
        ]
        # extract only text to be encoded by embedding model
        batch_text = [row["text"] for row in data[i:i_end]]
        # create the embedding vectors
        batch_embeds = emb_model.encode(batch_text).tolist()
        print("Embedded text")
        # extract IDs to be attached to each embedding and metadata
        batch_ids = [row["id"] for row in data[i:i_end]]
        # 'upsert' (insert) IDs, embeddings, and metadata to index
        to_upsert = list(zip(batch_ids, batch_embeds, batch_meta))
        index.upsert(to_upsert)
        print("Upserted embeddeds")