Source code for pasteur.amalgam.llm

import logging
import threading
import time
from typing import Any, Literal, Mapping, Type, TypedDict

import numpy as np
import pandas as pd
from pydantic.main import BaseModel

from pasteur.utils import LazyDataset
from pasteur.utils.progress import prange

# Minimum number of seconds between two progress prints issued by `_printer`.
PRINT_FREQ = 0.2
# Number of best-matching reference entities inserted into each sampling prompt.
TOP_K = 3
# NOTE(review): not referenced anywhere in the visible source — confirm usage.
MAX_EXPOSURE = 5
# Number of reference rows scored per partition while searching for matches.
PART_SIZE = 5000
# Passed by `_sample` to its workers as the `think` flag.
THINK = False
# Failure count at which `_sample`/`_evaluate` abort (only for "or" models).
MAX_FAILS = 20

logger = logging.getLogger(__name__)


[docs] class CacheTracker: def __init__(self, cache_time: float | None = None, cache_len: int | None = None): self.lock = threading.Lock() self.cache = {} self.idx = 0 self.cache_time = cache_time self.cache_len = cache_len self.cached_tokens = 0 def _get_cached_len(self, tokens: list[int], ctime=None): curr = self.cache for i, t in enumerate(tokens): if t not in curr: return i if ctime is not None and ctime > curr[t].get("_ctime", 0): return i if self.cache_len and self.idx - self.cache_len > curr[t].get("_cidx", 0): return i curr = curr[t] return len(tokens)
[docs] def get_cached_len(self, tokens: list[int]): with self.lock: i = self._get_cached_len( tokens, time.perf_counter() - self.cache_time if self.cache_time else None, ) self.cached_tokens += i return i
def _add_cached_tokens(self, tokens: list[int]): curr = self.cache for t in tokens: if t not in curr: curr[t] = {} curr["_ctime"] = time.perf_counter() curr["_cidx"] = self.idx curr = curr[t] self.idx += 1
[docs] def add_cached_tokens(self, tokens: list[int]): with self.lock: self._add_cached_tokens(tokens)
class AmalgamHFParams(TypedDict):
    """Parameters for a model loaded through llama.cpp (``type == "hf"``).

    `_load_llm_model` forwards ``repo_id``, ``filename``, ``n_ctx`` and
    ``n_gpu_layers`` to ``Llama.from_pretrained``.
    """

    # Discriminator that selects the llama.cpp branch of `_load_llm_model`.
    type: Literal["hf"]
    # Passed as `repo_id` to `Llama.from_pretrained`.
    repo_id: str
    # Passed as `filename` to `Llama.from_pretrained`.
    filename: str
    # Passed as `n_ctx` to `Llama.from_pretrained`.
    n_ctx: int
    # Passed as `n_gpu_layers` to `Llama.from_pretrained`.
    n_gpu_layers: int
    # Number of model instances the loaders create (read with a default of 1).
    workers: int
class AmalgamORParams(TypedDict):
    """Parameters for a model served through OpenRouter (``type == "or"``).

    `_load_llm_model` builds an OpenAI client pointed at
    ``https://openrouter.ai/api/v1`` and wraps it with ``outlines.from_openai``.
    """

    # Discriminator that selects the OpenRouter branch of `_load_llm_model`.
    type: Literal["or"]
    # Model name handed to `outlines.from_openai`.
    model: str
    # Number of model instances the loaders create (read with a default of 1).
    workers: int
# Structured output requested by `load_llm_model_eval` when `reason` is False:
# a single integer score from 1 to 5.
# (Documented with comments on purpose: a class docstring on a pydantic model
# would become part of its generated JSON schema.)
class EvalOutputType(BaseModel):
    score: Literal[1, 2, 3, 4, 5]
# Structured output requested by `load_llm_model_eval` when `reason` is True:
# a free-text reasoning field followed by an integer score from 1 to 5.
# (Documented with comments on purpose: a class docstring on a pydantic model
# would become part of its generated JSON schema.)
class EvalOutputReasonType(BaseModel):
    reasoning: str
    score: Literal[1, 2, 3, 4, 5]
def _load_llm_model(params: AmalgamHFParams | AmalgamORParams, output_type) -> Any: import outlines from outlines import Generator match params["type"]: case "hf": from llama_cpp import Llama from outlines.models import LlamaCpp llm = LlamaCpp( Llama.from_pretrained( repo_id=params["repo_id"], filename=params["filename"], n_ctx=params["n_ctx"], n_gpu_layers=params["n_gpu_layers"], batch_size=1, verbose=False, ) ) case "or": import openai # Create the model llm = outlines.from_openai( openai.OpenAI( base_url="https://openrouter.ai/api/v1", api_key=get_or_api_key(), ), params["model"], ) generator_thought = Generator(llm) # Generating the generator = Generator(llm, output_type) return { "model_type": params["type"], "llm": llm, "generator": generator, "generator_thought": generator_thought, } gpu_lock = threading.Lock() def _gpu_monitor_worker(name: str, stop: threading.Event, run_id: str): import subprocess import csv import tempfile process = subprocess.Popen( [ "nvidia-smi", "--query-gpu=timestamp,power.draw,power.draw.average,power.draw.instant,utilization.gpu,utilization.memory,memory.used,memory.total", "--format=csv,nounits", "-l=1", ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) with ( tempfile.TemporaryDirectory() as tmpdir, open(f"{tmpdir}/gpu_{name}.csv", "w", newline="") as tmpfile, ): writer = csv.writer(tmpfile) while not stop.is_set(): line = process.stdout.readline() if not line: continue writer.writerow(line.strip().split(", ")) process.terminate() process.wait() tmpfile.flush() try: import mlflow if mlflow.active_run() is None: mlflow.start_run(run_id=run_id) mlflow.log_artifact(tmpfile.name, artifact_path=f"_raw/energy") except Exception: logger.error("Error logging GPU energy info to MLflow.", exc_info=True)
class hold_gpu_lock:
    """Context manager that holds the module-wide GPU lock while monitoring the GPU.

    On entry it acquires ``gpu_lock`` and, best-effort, starts a daemon thread
    running `_gpu_monitor_worker` for the active MLflow run. Failure to start
    the monitor is logged and does not prevent the body from running. On exit
    it stops the monitor thread (if any) and releases the lock.

    Args:
        name: Label forwarded to the monitor thread; ``"unknown"`` when omitted.
    """

    def __init__(self, name: str | None = None):
        self.name = name
        self.t = None
        self.stop_event = threading.Event()

    def __enter__(self):
        gpu_lock.acquire()
        # Start from a clean state so the same instance can be entered again:
        # a stop event that is already set would end a new monitor immediately.
        self.stop_event = threading.Event()
        self.t = None
        try:
            import mlflow

            ar = mlflow.active_run()
            if ar is None:
                raise RuntimeError("MLflow active run is required for GPU monitoring.")
            run_id = ar.info.run_id

            thread = threading.Thread(
                target=_gpu_monitor_worker,
                args=(self.name or "unknown", self.stop_event, run_id),
                daemon=True,
            )
            thread.start()
            # Only remember threads that actually started: joining a thread
            # that never started raises and would skip the lock release.
            self.t = thread
        except Exception:
            logger.error("Error starting GPU monitor thread.", exc_info=True)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if self.t is not None:
                self.stop_event.set()
                self.t.join()
                self.t = None
        finally:
            # Release the lock even if waiting for the monitor was interrupted.
            gpu_lock.release()
def load_llm_model(
    params: AmalgamHFParams | AmalgamORParams,
    output_type,
):
    """Load the LLM instances used for sampling.

    One instance is created per worker (``params["workers"]``, default 1),
    each through `_load_llm_model` with the given *output_type*.

    Returns:
        A dict with ``"type": "generate"``, the backend under
        ``"model_type"`` and the loaded instances under ``"llms"``.
    """
    logger.info(f"Loading LLM model for sampling: {params}")

    worker_count = params.get("workers", 1)
    instances = [_load_llm_model(params, output_type) for _ in range(worker_count)]

    result = {"type": "generate"}
    result["model_type"] = params["type"]
    result["llms"] = instances
    return result
def load_llm_model_eval(
    params: AmalgamHFParams | AmalgamORParams,
    reason: bool = True,
):
    """Load the LLM instances used for evaluation.

    One instance is created per worker (``params["workers"]``, default 1).
    The structured output type is `EvalOutputReasonType` when *reason* is
    true and `EvalOutputType` otherwise.

    Returns:
        A dict with ``"type": "evaluate"``, the backend under
        ``"model_type"`` and the loaded instances under ``"llms"``.
    """
    logger.info(f"Loading LLM model for evaluation: {params}")

    instances = []
    remaining = params.get("workers", 1)
    while remaining > 0:
        if reason:
            schema = EvalOutputReasonType
        else:
            schema = EvalOutputType
        instances.append(_load_llm_model(params, schema))
        remaining -= 1

    result = {"type": "evaluate"}
    result["model_type"] = params["type"]
    result["llms"] = instances
    return result
def get_or_api_key() -> str:
    """Return the ``"openrouter"`` entry of the Kedro credentials.

    Raises:
        AssertionError: If the Kedro context has not been initialized.
    """
    from pasteur.kedro import context as kedro_context

    assert kedro_context is not None, "Kedro context is not initialized."
    credentials = kedro_context._get_config_credentials()
    return credentials["openrouter"]
def _printer(prompt, sample_num, sample_n, q, stop, task):
    """Consume streamed chunks from *q* and log a live preview of the output.

    Chunks arrive as ``(dtype, payload)`` tuples where ``dtype`` is
    ``"thought"`` or ``"data"``. A ``(None, None)`` tuple asks for one final,
    unthrottled print and a bare ``None`` ends the loop. Partial JSON is
    closed heuristically so it can be pretty-printed; previews that still do
    not parse are skipped.

    Args:
        prompt: Prompt being answered; lines longer than 300 characters are
            truncated in the preview.
        sample_num: 1-based number of the current sample.
        sample_n: Total number of samples.
        q: Queue the worker pushes chunks into.
        stop: Event that aborts printing when set.
        task: Label placed at the start of the log line.
    """
    import json

    MAX_LEN = 300
    prompt_reduced = "\n".join(
        line if len(line) <= MAX_LEN else line[:MAX_LEN] + "..."
        for line in prompt.split("\n")
    )
    decoder = json.JSONDecoder()

    thought = ""
    data = ""
    last_print = time.perf_counter()
    while token := q.get():
        if stop.is_set():
            break

        dtype, j = token
        end = dtype is None
        # Taken up front so it is defined even when the end marker is the
        # first item received (nothing was streamed).
        curr = time.perf_counter()
        if not end:
            if isinstance(j, str):
                frac = j
            else:
                assert "object" in j and j["object"] == "text_completion"
                frac = j["choices"][0]["text"]  # type: ignore

            if dtype == "thought":
                thought += frac
            else:
                data += frac

            # Flush the queue before printing
            # Printing is slow, so we only want to print the latest
            if not q.empty():
                continue
            if curr - last_print < PRINT_FREQ:
                continue

        # Try to correct invalid json as much as possible
        # This does not work in dicts, when before the :
        suffix = ""
        for d in data:
            if d == "{":
                suffix += "}"
            elif d == "[":
                suffix += "]"
            elif d == '"':
                if suffix.endswith('"'):
                    suffix = suffix[:-1]
                else:
                    suffix += '"'
            elif d == "}":
                if suffix.endswith("}"):
                    suffix = suffix[:-1]
            elif d == "]":
                if suffix.endswith("]"):
                    suffix = suffix[:-1]
        suffix = suffix[::-1]  # reverse

        sdata = data.rstrip()
        if sdata.endswith(","):
            # Drop the dangling comma itself; slicing the unstripped text
            # would remove trailing whitespace and leave the comma in place.
            full = sdata[:-1] + suffix
        elif sdata.endswith(":"):
            full = data + " null" + suffix
        else:
            full = data + suffix

        try:
            if full:
                obj, _ = decoder.raw_decode(full)
                pretty = json.dumps(obj, indent=2) + "\n"
            else:
                pretty = ""

            thought_str = ""
            if thought:
                thought_str = f"\nThought: {thought}"
            logger.info(
                f":ephemeral:{task} {sample_num}/{sample_n}. Prompt: {prompt_reduced}{thought_str}\nData:\n{pretty}"
            )
            last_print = curr
        except json.JSONDecodeError:
            # The heuristic could not produce valid JSON yet; try again on a
            # later chunk.
            pass


def _worker(
    generator,
    generator_thought,
    stop,
    in_q,
    out_q,
    sample_n,
    think,
    print,
    task,
):
    """Worker loop: take prompts from *in_q*, stream the LLM, report to *out_q*.

    For every prompt exactly one tuple
    ``(start, ttft_thought, ttft, end, chunks, failed)`` is put on *out_q*,
    where the times are ``time.perf_counter()`` values, ``chunks`` is the list
    of ``(dtype, payload)`` items streamed and ``failed`` is true when an
    exception occurred or *stop* was set.

    Args:
        generator: Structured generator; its ``stream`` yields the answer.
        generator_thought: Generator used for the thinking pass when *think*
            is true (may be ``None`` otherwise).
        stop: Event that ends the loop and interrupts streaming.
        in_q: Queue of ``(prompt, sample_num)`` tuples.
        out_q: Queue receiving the result tuples.
        sample_n: Total number of samples (only used for the preview).
        think: Whether to run the thinking pass before generating.
        print: Whether to start a `_printer` thread for each prompt.
        task: Label forwarded to `_printer`.
    """
    import queue

    while not stop.is_set():
        failed = True
        try:
            prompt, sample_num = in_q.get(timeout=0.1)
        except queue.Empty:
            continue

        start = time.perf_counter()
        ttft = None
        ttft_thought = None
        full_prompt = prompt
        pq = queue.Queue()
        if print:
            t = threading.Thread(
                target=_printer,
                args=(prompt, sample_num, sample_n, pq, stop, task),
                daemon=True,
            )
            t.start()
        else:
            t = None

        data = []
        try:
            if think:
                full_prompt = prompt + "\\think<think>"
                for j in generator_thought.stream(full_prompt, max_tokens=None, stop="</think>"):  # type: ignore
                    if ttft_thought is None:
                        ttft_thought = time.perf_counter()
                    if stop.is_set():
                        break
                    full_prompt += j
                    pq.put(("thought", j))
                    data.append(("thought", j))

            for j in generator.stream(full_prompt, max_tokens=None):  # type: ignore
                if ttft is None:
                    ttft = time.perf_counter()
                if stop.is_set():
                    break
                pq.put(("data", j))
                data.append(("data", j))
            failed = stop.is_set()
        except Exception:
            import traceback

            logger.error(f"Error in thought worker:\n{traceback.format_exc()}")
        finally:
            end = time.perf_counter()
            out_q.put((start, ttft_thought, ttft, end, data, failed))
            # First ask the printer for a final print, then end its loop.
            pq.put((None, None))
            pq.put(None)
            if t is not None:
                t.join()


def _process_name(name):
    """Turn ``*_count`` column names into ``"NUMBER OF ..."`` labels."""
    if "_count" not in name:
        return name
    return "NUMBER OF " + name.replace("_count", "").upper()


def _process_val(val):
    """Convert an entity value into its prompt representation.

    Dicts are processed recursively, non-strings are converted with ``str``
    and strings that look like a unit-width integer interval (``"[2, 3)"``)
    become the integer lower bound. Every other string is returned unchanged.
    """
    if isinstance(val, dict):
        return {k: _process_val(v) for k, v in val.items()}

    if not isinstance(val, str):
        return str(val)

    if not val.startswith("[") or not val.endswith(")"):
        return val

    # Convert bounds like [2, 3) to integer 2 to be simpler to parse
    try:
        v1, v2 = (float(s.strip()) for s in val[1:-1].split(","))
    except ValueError:
        # Bracketed text that is not a pair of numbers is not an interval.
        return val
    if v1.is_integer() and v2.is_integer() and v2 == v1 + 1:
        return int(v1)
    return val


def _sample(
    gen,
    prompt: str,
    counts,
    meta: Any,
    syn: pd.DataFrame,
    ref: pd.DataFrame,
    data: dict[str, LazyDataset],
    _ctx,
):
    """Generate one entity per row of *syn* with the loaded LLMs.

    For every row of *syn* the ``TOP_K`` highest-scoring rows of *ref* are
    looked up, their stored entities are placed in the prompt together with
    the row itself, and the prompt is sent to the worker threads. Outputs that
    parse as JSON are collected; token and timing statistics are logged and,
    if an MLflow run is active, recorded as parameters.

    Args:
        gen: Result of `load_llm_model` (needs ``"llms"`` and ``"model_type"``).
        prompt: Template containing ``<seed>``, ``<samples>`` and
            ``<samples_n>`` placeholders.
        counts: Mapping of column name to an array of value counts; columns
            ending in ``_common``/``_cmn`` are ignored for scoring.
        meta: Metadata; ``meta["flat"]["meta"]`` is passed to
            ``create_table_mapping``.
        syn: Rows to generate entities for.
        ref: Reference rows the examples are selected from.
        data: ``data["ids"]`` and ``data["data"]`` hold, per partition,
            callables returning the id and entity frames.
        _ctx: Shared dict with the ``"stop"`` event and the ``"t"`` thread list.

    Returns:
        A dict with an ``"ids"`` frame and a ``"data"`` frame whose
        ``"entity"`` column holds the generated entities as strings.

    Raises:
        RuntimeError: When ``MAX_FAILS`` failures were reached with an
            ``"or"`` model.
    """
    import json
    import queue

    from pasteur.extras.encoders import create_table_mapping, process_entity

    out = []
    decoder = json.JSONDecoder()

    topk_idx = []
    topk_scores = []

    # Rarer values get higher scores; each column contributes at most
    # `max_per_val`.
    # NOTE(review): columns containing "_total_count" are weighted x5 below, so
    # the uint16 accumulator may wrap for some inputs — confirm.
    max_per_val = 2**16 // len(counts)
    norm_scores = {
        k: np.where(
            v > 0, max_per_val * v[v != 0].min() / np.where(v > 0, v, 1), 0
        ).astype(np.uint16)
        for k, v in counts.items()
        if not k.endswith("_common") and not k.endswith("_cmn")
    }

    # Map every reference id to the number of the partition that stores it.
    id_map = []
    part_map = {}
    for i, (k, v) in enumerate(data["ids"].items()):
        part_map[i] = k
        uv = v()
        id_map.append(uv.rename(columns={next(iter(uv.columns)): "id"}).assign(part=i))
    id_map = pd.concat(id_map).set_index("id")

    # Score the reference rows in partitions of PART_SIZE to bound memory.
    for p in range(0, len(ref) // PART_SIZE + 1):
        start = p * PART_SIZE
        end = min((p + 1) * PART_SIZE, len(ref))
        scores = np.zeros((end - start, len(syn)), np.uint16)
        for k, v in norm_scores.items():
            mult = 5 if "_total_count" in k else 1
            vals = v[ref.iloc[start:end][k].values[:, None]]
            eqs = (
                ref.iloc[start:end][k].values[:, None] == syn.loc[:, k].values[None, :]
            )
            scores += mult * (vals * eqs).astype(np.uint16)

        idx = scores.argsort(axis=0)[-TOP_K:, :]
        scores_topk = np.take_along_axis(scores, idx, axis=0)
        # Plain 1-D indexing: identical to `take_along_axis` on a 1-D array,
        # without depending on the `axis` default of recent NumPy releases.
        idx_index = np.array(ref.index)[start + idx.flatten()].reshape(idx.shape)
        topk_idx.append(idx_index)
        topk_scores.append(scores_topk)

    # Merge the per-partition winners and keep the overall best TOP_K.
    full_idx = np.concat(topk_idx)
    full_scores = np.concat(topk_scores)
    lookups = np.take_along_axis(
        full_idx, full_scores.argsort(axis=0)[-TOP_K:, :], axis=0
    )

    stop = _ctx["stop"]
    jdata = data["data"]
    n_samples = syn.shape[0]
    fails = 0

    in_q = queue.Queue()
    out_q = queue.Queue()
    for i, llm in enumerate(gen["llms"]):
        t = threading.Thread(
            target=_worker,
            args=(
                llm["generator"],
                llm["generator_thought"],
                stop,
                in_q,
                out_q,
                n_samples,
                THINK,
                i == 0,  # only the first worker prints previews
                "Sampling Entity",
            ),
            daemon=True,
        )
        t.start()
        _ctx["t"].append(t)

    prompts = []
    for i in range(n_samples):
        samples = []
        for k in lookups[:, i].tolist():
            arr = jdata[part_map[int(id_map.loc[k].iloc[0])]]()
            val = arr.loc[k, next(iter(arr.columns))]
            samples.append(str(val))
        samples = "\n".join(samples)

        base_data = process_entity(
            "table",
            i,
            create_table_mapping("table", {}, {"table": meta["flat"]["meta"]}, {}),
            {"table": syn},
            {},
            {},
        )
        sample_num = i + 1
        seed = "\n".join(
            f"{_process_name(k)}: {_process_val(v)}" for k, v in base_data.items()
        )
        fprompt = (
            prompt.replace("<seed>", seed)
            .replace("<samples>", samples)
            .replace("<samples_n>", str(TOP_K))
        )
        in_q.put((fprompt, sample_num))
        prompts.append(fprompt)

    # Grab energy info
    tracker = CacheTracker()
    # FIXME: Generalize this
    tokenizer = gen["llms"][0]["llm"].tokenizer

    cached_tokens = 0
    input_tokens = 0
    output_tokens = 0
    in_time = 0
    out_time = 0
    for i in prange(n_samples, desc="Processing entities"):
        start, ttft_thought, ttft, end, chunks, failed = out_q.get()
        # NOTE(review): results arrive in completion order, so with several
        # workers `prompts[i]` is not necessarily the prompt of this result;
        # the token statistics are approximate in that case.
        ptokens = tokenizer.encode(prompts[i])[0]
        ctokens = tracker.get_cached_len(ptokens)
        cached_tokens += ctokens
        input_tokens += len(ptokens) - ctokens
        in_time += (ttft if ttft is not None else start) - start
        out_time += end - (ttft if ttft is not None else start)

        if chunks:
            text = ""
            for dtype, frac in chunks:
                if dtype != "data":
                    continue
                if isinstance(frac, str):
                    data_str = frac
                else:
                    assert "object" in frac and frac["object"] == "text_completion"
                    data_str = frac["choices"][0]["text"]  # type: ignore
                text += data_str

            otokens = tokenizer.encode(text)[0]
            output_tokens += len(otokens)
            tracker.add_cached_tokens(ptokens + otokens)

            if not failed:
                try:
                    out.append(decoder.decode(text))
                except json.JSONDecodeError:
                    fails += 1
            else:
                fails += 1
        else:
            # A request that produced nothing (e.g. an immediate API error)
            # is a failure too; skipping it would keep the abort check below
            # from ever triggering.
            fails += 1

        if fails >= MAX_FAILS and gen["model_type"] == "or":
            logger.error(
                f"Sampling failed {fails} times for sample {i+1}. Aborting further sampling."
            )
            raise RuntimeError("Maximum sampling failures reached.")

    stop.set()
    for t in _ctx["t"]:
        t.join()

    try:
        input_tps = input_tokens / in_time if in_time > 0 else 0
        output_tps = output_tokens / out_time if out_time > 0 else 0
        logger.info(
            f"""\
Entities Generated: {len(out)}, failed: {fails}, total: {n_samples}.
# Token information
Cached: {cached_tokens:12,d}
Input: {input_tokens:12,d}
Output: {output_tokens:12,d}
Total: {input_tokens + output_tokens:12,d}
# Time spent
Input time: {in_time if in_time else float('NaN'):7,.2f} s
Output time: {out_time if out_time else float('NaN'):7,.2f} s
Total time: {in_time + out_time if in_time + out_time else float('NaN'):7,.2f} s
# Throughput
Input tokens per second: {input_tps if input_tps else float('NaN'):8,.2f} t/s
Output tokens per second: {output_tps if output_tps else float('NaN'):8,.2f} t/s
"""
        )

        import mlflow

        if mlflow.active_run() is not None:
            mlflow.log_param("sampling.cached_tokens", cached_tokens)
            mlflow.log_param("sampling.input_tokens", input_tokens)
            mlflow.log_param("sampling.input_time", in_time)
            mlflow.log_param("sampling.input_tps", input_tps)
            mlflow.log_param("sampling.output_tokens", output_tokens)
            mlflow.log_param("sampling.output_time", out_time)
            mlflow.log_param("sampling.output_tps", output_tps)
            mlflow.log_param("sampling.sample_n", len(out))
            mlflow.log_param("sampling.failures", fails)
    except Exception:
        logger.error("Error logging sampling performance to MLflow.", exc_info=True)

    df = pd.DataFrame({"entity": map(str, out)})
    return {
        "ids": pd.DataFrame({"id": df.index}),
        "data": df,
    }
def sample(
    gen,
    prompt: str,
    counts,
    meta: Any,
    pgm_samples: pd.DataFrame,
    ref: pd.DataFrame,
    data: dict[str, LazyDataset],
):
    """Run `_sample` and always shut its worker threads down afterwards.

    A fresh context (stop event plus thread list) is handed to `_sample`.
    Whether it returns or raises, the stop event is set and every thread
    registered in the context is joined.

    Returns:
        Whatever `_sample` returns.
    """
    stop_event = threading.Event()
    workers = []
    ctx = {"t": workers, "stop": stop_event}

    try:
        result = _sample(gen, prompt, counts, meta, pgm_samples, ref, data, ctx)
        return result
    finally:
        ctx["stop"].set()
        for worker in ctx["t"]:
            worker.join()
def _evaluate(
    gen,
    prompt: str,
    counts,
    wrk_flat: pd.DataFrame,
    wrk_json: dict[str, LazyDataset],
    eval_flat: pd.DataFrame,
    eval_json: dict[str, LazyDataset],
    max_samples: int | None,
    top_k: int,
    split: str,
    _ctx,
):
    """Ask the loaded LLMs to evaluate entities against their closest matches.

    For every row of *eval_flat* (optionally limited to *max_samples*) the
    *top_k* highest-scoring rows of *wrk_flat* are looked up, their stored
    entities and the evaluated entity are placed in the prompt, and the prompt
    is sent to the worker threads. Outputs that parse as JSON are collected;
    token and timing statistics are logged and, if an MLflow run is active,
    recorded as parameters under ``eval.<split>.*``.

    Args:
        gen: Result of `load_llm_model_eval` (needs ``"llms"`` and
            ``"model_type"``).
        prompt: Template containing ``<eval>``, ``<samples>`` and
            ``<samples_n>`` placeholders.
        counts: Mapping of column name to an array of value counts; columns
            ending in ``_common``/``_cmn`` are ignored for scoring.
        wrk_flat: Rows the examples are selected from.
        wrk_json: ``"ids"`` and ``"data"`` partitions for *wrk_flat*.
        eval_flat: Rows to evaluate.
        eval_json: ``"ids"`` and ``"data"`` partitions for *eval_flat*.
        max_samples: Evaluate only the first *max_samples* rows, or all of
            them when ``None``.
        top_k: Number of examples per prompt.
        split: Name of the evaluated split, used in messages and MLflow keys.
        _ctx: Shared dict with the ``"stop"`` event and the ``"t"`` thread list.

    Returns:
        The list of decoded model outputs.

    Raises:
        RuntimeError: When ``MAX_FAILS`` failures were reached with an
            ``"or"`` model.
    """
    import json
    import queue

    out = []
    decoder = json.JSONDecoder()

    topk_idx = []
    topk_scores = []

    # Rarer values get higher scores; each column contributes at most
    # `max_per_val`.
    # NOTE(review): columns containing "_total_count" are weighted x5 below, so
    # the uint16 accumulator may wrap for some inputs — confirm.
    max_per_val = 2**16 // len(counts)
    norm_scores = {
        k: np.where(
            v > 0, max_per_val * v[v != 0].min() / np.where(v > 0, v, 1), 0
        ).astype(np.uint16)
        for k, v in counts.items()
        if not k.endswith("_common") and not k.endswith("_cmn")
    }

    # Map every working-set id to the number of the partition that stores it.
    id_map = []
    part_map = {}
    for i, (k, v) in enumerate(wrk_json["ids"].items()):
        part_map[i] = k
        uv = v()
        id_map.append(uv.rename(columns={next(iter(uv.columns)): "id"}).assign(part=i))
    id_map = pd.concat(id_map).set_index("id")

    # Same mapping for the evaluated entities.
    id_map_eval = []
    part_map_eval = {}
    for i, (k, v) in enumerate(eval_json["ids"].items()):
        part_map_eval[i] = k
        uv = v()
        id_map_eval.append(
            uv.rename(columns={next(iter(uv.columns)): "id"}).assign(part=i)
        )
    id_map_eval = pd.concat(id_map_eval).set_index("id")

    if max_samples is not None:
        eval_flat = eval_flat.iloc[:max_samples]

    # Score the working rows in partitions of PART_SIZE to bound memory.
    for p in range(0, len(wrk_flat) // PART_SIZE + 1):
        start = p * PART_SIZE
        end = min((p + 1) * PART_SIZE, len(wrk_flat))
        scores = np.zeros((end - start, len(eval_flat)), np.uint16)
        for k, v in norm_scores.items():
            mult = 5 if "_total_count" in k else 1
            vals = v[wrk_flat.iloc[start:end][k].values[:, None]]
            eqs = (
                wrk_flat.iloc[start:end][k].values[:, None]
                == eval_flat.loc[:, k].values[None, :]
            )
            scores += mult * (vals * eqs).astype(np.uint16)

        idx = scores.argsort(axis=0)[-top_k:, :]
        scores_topk = np.take_along_axis(scores, idx, axis=0)
        # Plain 1-D indexing: identical to `take_along_axis` on a 1-D array,
        # without depending on the `axis` default of recent NumPy releases.
        idx_index = np.array(wrk_flat.index)[start + idx.flatten()].reshape(idx.shape)
        topk_idx.append(idx_index)
        topk_scores.append(scores_topk)

    # Merge the per-partition winners and keep the overall best top_k.
    full_idx = np.concat(topk_idx)
    full_scores = np.concat(topk_scores)
    lookups = np.take_along_axis(
        full_idx, full_scores.argsort(axis=0)[-top_k:, :], axis=0
    )

    stop = _ctx["stop"]
    jdata = wrk_json["data"]
    jdata_eval = eval_json["data"]
    n_samples = eval_flat.shape[0]
    fails = 0

    in_q = queue.Queue()
    out_q = queue.Queue()
    for i, llm in enumerate(gen["llms"]):
        t = threading.Thread(
            target=_worker,
            args=(
                llm["generator"],
                None,  # llm["generator_thought"],
                stop,
                in_q,
                out_q,
                n_samples,
                False,
                i == 0,  # only the first worker prints previews
                f"Evaluating {split} Entity",
            ),
            daemon=True,
        )
        t.start()
        _ctx["t"].append(t)

    prompts = []
    for i in range(n_samples):
        samples = []
        for k in lookups[:, i].tolist():
            arr = jdata[part_map[int(id_map.loc[k].iloc[0])]]()
            val = arr.loc[k, next(iter(arr.columns))]
            samples.append(str(val))
        samples = "\n".join(samples)

        arr = jdata_eval[
            part_map_eval[int(id_map_eval.loc[eval_flat.index[i]].iloc[0])]
        ]()
        val = arr.loc[eval_flat.index[i], next(iter(arr.columns))]
        eval_str = str(val)

        fprompt = (
            prompt.replace("<eval>", eval_str)
            .replace("<samples>", samples)
            .replace("<samples_n>", str(top_k))
        )
        sample_num = i + 1
        in_q.put((fprompt, sample_num))
        prompts.append(fprompt)

    # Grab energy info
    tracker = CacheTracker()
    # FIXME: Generalize this
    tokenizer = gen["llms"][0]["llm"].tokenizer

    cached_tokens = 0
    input_tokens = 0
    output_tokens = 0
    in_time = 0
    out_time = 0
    for i in prange(n_samples, desc=f"Evaluating {split.capitalize()} entities"):
        start, ttft_thought, ttft, end, chunks, failed = out_q.get()
        # NOTE(review): results arrive in completion order, so with several
        # workers `prompts[i]` is not necessarily the prompt of this result;
        # the token statistics are approximate in that case.
        ptokens = tokenizer.encode(prompts[i])[0]
        ctokens = tracker.get_cached_len(ptokens)
        cached_tokens += ctokens
        input_tokens += len(ptokens) - ctokens
        in_time += (ttft if ttft is not None else start) - start
        out_time += end - (ttft if ttft is not None else start)

        if chunks:
            text = ""
            for dtype, frac in chunks:
                if dtype != "data":
                    continue
                if isinstance(frac, str):
                    data_str = frac
                else:
                    assert "object" in frac and frac["object"] == "text_completion"
                    data_str = frac["choices"][0]["text"]  # type: ignore
                text += data_str

            otokens = tokenizer.encode(text)[0]
            output_tokens += len(otokens)
            tracker.add_cached_tokens(ptokens + otokens)

            if not failed:
                try:
                    out.append(decoder.decode(text))
                except json.JSONDecodeError:
                    fails += 1
            else:
                fails += 1
        else:
            # A request that produced nothing (e.g. an immediate API error)
            # is a failure too; skipping it would keep the abort check below
            # from ever triggering.
            fails += 1

        if fails >= MAX_FAILS and gen["model_type"] == "or":
            logger.error(
                f"Evaluation failed {fails} times for sample {i+1}. Aborting further evaluation."
            )
            raise RuntimeError("Maximum sampling failures reached.")

    stop.set()
    for t in _ctx["t"]:
        t.join()

    try:
        input_tps = input_tokens / in_time if in_time > 0 else 0
        output_tps = output_tokens / out_time if out_time > 0 else 0
        logger.info(
            f"""\
{split.capitalize()} Entities evaluated: {len(out)}, failed: {fails}, total: {n_samples}.
# Token information
Cached: {cached_tokens:12,d}
Input: {input_tokens:12,d}
Output: {output_tokens:12,d}
Total: {input_tokens + output_tokens:12,d}
# Time spent
Input time: {in_time if in_time else float('NaN'):7,.2f} s
Output time: {out_time if out_time else float('NaN'):7,.2f} s
Total time: {in_time + out_time if in_time + out_time else float('NaN'):7,.2f} s
# Throughput
Input tokens per second: {input_tps if input_tps else float('NaN'):8,.2f} t/s
Output tokens per second: {output_tps if output_tps else float('NaN'):8,.2f} t/s
"""
        )

        import mlflow

        if mlflow.active_run() is not None:
            mlflow.log_param(f"eval.{split}.cached_tokens", cached_tokens)
            mlflow.log_param(f"eval.{split}.input_tokens", input_tokens)
            mlflow.log_param(f"eval.{split}.input_time", in_time)
            mlflow.log_param(f"eval.{split}.input_tps", input_tps)
            mlflow.log_param(f"eval.{split}.output_tokens", output_tokens)
            mlflow.log_param(f"eval.{split}.output_time", out_time)
            mlflow.log_param(f"eval.{split}.output_tps", output_tps)
            mlflow.log_param(f"eval.{split}.sample_n", len(out))
            mlflow.log_param(f"eval.{split}.failures", fails)
    except Exception:
        logger.error("Error logging evaluation performance to MLflow.", exc_info=True)

    return out
def evaluate(
    gen,
    prompt: str,
    counts,
    wrk_flat: pd.DataFrame,
    wrk_json: dict[str, LazyDataset],
    eval_flat: pd.DataFrame,
    eval_json: dict[str, LazyDataset],
    max_samples: int | None = None,
    top_k: int = 3,
    split: str = "ref",
):
    """Run `_evaluate` and always shut its worker threads down afterwards.

    A fresh context (stop event plus thread list) is handed to `_evaluate`.
    Whether it returns or raises, the stop event is set and every thread
    registered in the context is joined.

    Returns:
        Whatever `_evaluate` returns.
    """
    stop_event = threading.Event()
    workers = []
    ctx = {"t": workers, "stop": stop_event}

    try:
        result = _evaluate(
            gen,
            prompt,
            counts,
            wrk_flat,
            wrk_json,
            eval_flat,
            eval_json,
            max_samples,
            top_k,
            split,
            ctx,
        )
        return result
    finally:
        ctx["stop"].set()
        for worker in ctx["t"]:
            worker.join()