[fix] lower on query, [add] metadata response, [add] context distance & reference links

Evan Reichard 2023-10-19 18:56:48 -04:00
parent 05c5546c10
commit 40daf46c03
7 changed files with 174 additions and 77 deletions


@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
To normalize data, you can use Minyma's `normalize` CLI command:
```bash
-minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
+minyma normalize \
+    --filename ./pubmed_manuscripts.jsonl \
+    --normalizer pubmed \
+    --database chroma \
+    --datapath ./chroma
```
The above example does the following:


@ -1,8 +1,8 @@
+from os import path
import click
import signal
import sys
from importlib.metadata import version
-from minyma.config import Config
from minyma.oai import OpenAIConnector
from minyma.vdb import ChromaDB
from flask import Flask
@ -15,14 +15,15 @@ def signal_handler(sig, frame):
def create_app():
-global oai, cdb
+global oai, vdb
+from minyma.config import Config
import minyma.api.common as api_common
import minyma.api.v1 as api_v1
app = Flask(__name__)
-cdb = ChromaDB(Config.DATA_PATH)
-oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
+vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
+oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
app.register_blueprint(api_common.bp)
app.register_blueprint(api_v1.bp)
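
With the factory above, the app can be exercised the usual Flask way; a minimal sketch (module layout assumed from the `minyma.api` imports in this diff):

```python
# Assumes create_app lives in minyma/__init__.py, as the imports suggest
from minyma import create_app

app = create_app()   # wires ChromaDB + OpenAIConnector, registers blueprints
app.run(debug=True)  # or equivalently: flask --app minyma run
```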


@ -17,8 +17,28 @@ def get_response():
if message == "":
return {"error": "Empty Message"}
-oai_response = minyma.oai.query(message)
-return oai_response
+resp = minyma.oai.query(message)
+# Derive LLM Data
+llm_resp = resp.get("llm", {})
+llm_choices = llm_resp.get("choices", [])
+# Derive VDB Data
+vdb_resp = resp.get("vdb", {})
+combined_context = [{
+    "id": vdb_resp.get("ids")[i],
+    "distance": vdb_resp.get("distances")[i],
+    "doc": vdb_resp.get("docs")[i],
+    "metadata": vdb_resp.get("metadatas")[i],
+} for i, _ in enumerate(vdb_resp.get("docs", []))]
+# Return Data
+return {
+    "response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
+    "context": combined_context,
+    "usage": llm_resp.get("usage"),
+}
"""
@ -34,5 +54,5 @@ def get_related():
if message == "":
return {"error": "Empty Message"}
-related_documents = minyma.cdb.get_related(message)
+related_documents = minyma.vdb.get_related(message)
return related_documents
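
For reference, a plausible shape of the payload the reworked `get_response` endpoint now returns; every value below is invented for illustration, only the keys come from the diff:

```python
# Hypothetical response from the endpoint above (all values fabricated)
example = {
    "response": "Vitamin D deficiency has been associated with ...",  # first LLM choice, or None
    "context": [
        {
            "id": "42",                                  # line number assigned by the normalizer
            "distance": 0.31,                            # vector distance reported by Chroma
            "doc": "OBJECTIVE: To assess whether ...",   # retrieved document text
            "metadata": {"file": "./data/PMC1234567.txt"},
        },
    ],
    "usage": {"prompt_tokens": 512, "completion_tokens": 87, "total_tokens": 599},
}
```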


@ -1,12 +1,16 @@
from io import TextIOWrapper
import json
-class DataNormalizer:
+class DataNormalizer():
def __init__(self, file: TextIOWrapper):
-    pass
+    self.file = file
+def __len__(self) -> int:
+    return 0
def __iter__(self):
-    pass
+    yield None
class PubMedNormalizer(DataNormalizer):
"""
@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
normalized inside the iterator.
"""
def __init__(self, file: TextIOWrapper):
self.file = file
+self.length = 0
+def __len__(self):
+    last_pos = self.file.tell()
+    self.length = sum(1 for _ in self.file)
+    self.file.seek(last_pos)
+    return self.length
def __iter__(self):
count = 0
@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
count += 1
# ID = Line Number
-yield { "doc": norm_text, "id": str(count - 1) }
+yield {
+    "id": str(count - 1),
+    "doc": norm_text,
+    "metadata": {
+        "file": l.get("file")
+    },
+}
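
Taken together, the normalizer now reports its length (by counting lines, then seeking back to where it was) and attaches the source file as per-record metadata. A minimal usage sketch, with an invented file name:

```python
# Hypothetical usage of PubMedNormalizer (file name invented)
with open("pubmed_manuscripts.jsonl", encoding="utf-8") as f:
    normalizer = PubMedNormalizer(f)
    total = len(normalizer)    # tell() -> count lines -> seek() back, so iteration still works
    for record in normalizer:  # {"id": ..., "doc": ..., "metadata": {"file": ...}}
        print(record["id"], record["metadata"]["file"])
```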


@ -5,8 +5,8 @@ from minyma.vdb import VectorDB
# Stolen LangChain Prompt
PROMPT_TEMPLATE = """
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to
make up an answer.
{context}
@ -19,6 +19,7 @@ class OpenAIConnector:
def __init__(self, api_key: str, vdb: VectorDB):
self.vdb = vdb
self.model = "gpt-3.5-turbo"
+self.word_cap = 1000
openai.api_key = api_key
def query(self, question: str) -> Any:
@ -30,8 +31,9 @@ class OpenAIConnector:
if len(all_docs) == 0:
return { "error": "No Context Found" }
-# Join on new line, generate main prompt
-context = '\n'.join(all_docs)
+# Join on new line (cap @ word limit), generate main prompt
+reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
+context = '\n'.join(reduced_docs)
prompt = PROMPT_TEMPLATE.format(context = context, question = question)
# Query OpenAI ChatCompletion
@ -41,4 +43,4 @@ class OpenAIConnector:
)
# Return Response
-return response
+return { "llm": response, "vdb": related }
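
The new `word_cap` is a blunt guard against overrunning the model's context window: each retrieved document is cut to its first 1,000 whitespace-delimited tokens before being joined into the prompt. The same trick in isolation:

```python
# Standalone version of the truncation above (sample text invented)
word_cap = 1000
doc = "BACKGROUND: Vitamin D is a fat-soluble vitamin that ..."  # imagine thousands of words
reduced = " ".join(doc.split()[:word_cap])  # keeps only the first 1000 tokens
```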


@ -72,41 +72,41 @@
</main>
<script>
const LOADING_SVG = `<svg
width="24"
height="24"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
fill="currentColor"
>
<style>
.spinner_qM83 {
animation: spinner_8HQG 1.05s infinite;
}
.spinner_oXPr {
animation-delay: 0.1s;
}
.spinner_ZTLf {
animation-delay: 0.2s;
}
@keyframes spinner_8HQG {
0%,
57.14% {
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
transform: translate(0);
}
28.57% {
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
transform: translateY(-6px);
}
100% {
transform: translate(0);
}
}
</style>
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
</svg>`;
/**
* Wrapper API Call
@ -125,9 +125,9 @@
// Wrapping Element
let wrapEl = document.createElement("div");
wrapEl.innerHTML = `<div class="flex">
<span class="font-bold w-24 grow-0 shrink-0"></span>
<span class="whitespace-break-spaces w-full"></span>
</div>`;
// Get Elements
let nameEl = wrapEl.querySelector("span");
@ -158,7 +158,63 @@
})
.then((data) => {
console.log("SUCCESS:", data);
-content.innerText = data.choices[0].message.content;
+// Create Response Element
+let responseEl = document.createElement("p");
+responseEl.setAttribute(
+  "class",
+  "whitespace-break-spaces border-b pb-3 mb-3"
+);
+responseEl.innerText = data.response;
+// Create Context Element
+let contextEl = document.createElement("div");
+contextEl.innerHTML = `
+  <h1 class="font-bold">Context:</h1>
+  <ul class="list-disc ml-6"></ul>`;
+let ulEl = contextEl.querySelector("ul");
+// Create Context Links
+data.context
+  // Capture PubMed ID & Distance
+  .map((item) => [
+    item.metadata.file.match("\/(.*)\.txt$"),
+    item.distance,
+  ])
+  // Filter Non-Matches
+  .filter(([match]) => match)
+  // Get Match Value & Round Distance (2)
+  .map(([match, distance]) => [
+    match[1],
+    Math.round(distance * 100) / 100,
+  ])
+  // Create Links
+  .forEach(([pmid, distance]) => {
+    let newEl = document.createElement("li");
+    let linkEl = document.createElement("a");
+    linkEl.setAttribute("target", "_blank");
+    linkEl.setAttribute(
+      "class",
+      "text-blue-500 hover:text-blue-600"
+    );
+    linkEl.setAttribute(
+      "href",
+      "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
+    );
+    linkEl.textContent = "[" + distance + "] " + pmid;
+    newEl.append(linkEl);
+    ulEl.append(newEl);
+  });
+// Add to DOM
+content.setAttribute("class", "w-full");
+content.innerHTML = "";
+content.append(responseEl);
+content.append(contextEl);
})
.catch((e) => {
console.log("ERROR:", e);
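
The link-building chain above pulls a PubMed Central ID out of the stored file path and rounds the distance for display. The same transformation, sketched in Python with an invented path:

```python
import re

# Hypothetical context entry mirroring the metadata added in this commit
item = {"distance": 0.3141, "metadata": {"file": "./PMC1234567.txt"}}

match = re.search(r"/(.*)\.txt$", item["metadata"]["file"])
if match:
    pmid = match.group(1)                  # -> "PMC1234567"
    distance = round(item["distance"], 2)  # -> 0.31
    url = "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
    label = f"[{distance}] {pmid}"
```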


@ -1,7 +1,6 @@
from chromadb.api import API
from itertools import islice
-from os import path
-from tqdm.auto import tqdm
+from tqdm import tqdm
from typing import Any, cast
import chromadb
@ -29,47 +28,45 @@ class VectorDB:
ChromaDB VectorDB Type
"""
class ChromaDB(VectorDB):
-def __init__(self, base_path: str):
-    chroma_path = path.join(base_path, "chroma")
-    self.client: API = chromadb.PersistentClient(path=chroma_path)
-    self.word_limit = 1000
+def __init__(self, path: str):
+    self.client: API = chromadb.PersistentClient(path=path)
+    self.word_cap = 2500
self.collection_name: str = "vdb"
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
-def get_related(self, question) -> Any:
+def get_related(self, question: str) -> Any:
"""Returns line separated related docs"""
results = self.collection.query(
-query_texts=[question],
+query_texts=[question.lower()],
n_results=2
)
all_docs: list = cast(list, results.get("documents", [[]]))[0]
+all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
all_distances: list = cast(list, results.get("distances", [[]]))[0]
all_ids: list = cast(list, results.get("ids", [[]]))[0]
return {
-    "distances":all_distances,
+    "distances": all_distances,
+    "metadatas": all_metadata,
"docs": all_docs,
"ids": all_ids
}
-def load_documents(self, normalizer: DataNormalizer):
-    # 10 Item Chunking
-    for items in tqdm(chunk(normalizer, 50)):
+def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
+    length = len(normalizer) / chunk_size
+    for items in tqdm(chunk(normalizer, chunk_size), total=length):
ids = []
documents = []
+metadatas = []
# Limit words per document to accommodate context token limits
for item in items:
-doc = " ".join(item.get("doc").split()[:self.word_limit])
-documents.append(doc)
+documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
ids.append(item.get("id"))
+metadatas.append(item.get("metadata", {}))
-# Ideally we parse out metadata from each document
-# and pass to the metadata kwarg. However, each
-# document appears to have a slightly different format,
-# so it's difficult to parse out.
self.collection.add(
+    ids=ids,
    documents=documents,
-    ids=ids
+    metadatas=metadatas,
)
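
The `chunk` helper called in `load_documents` is not part of this diff; given the `islice` import at the top of the file, a plausible implementation is the classic batching recipe (a sketch, not necessarily the repo's exact code):

```python
from itertools import islice

def chunk(iterable, size):
    """Yield successive lists of at most `size` items from any iterable."""
    it = iter(iterable)
    while batch := list(islice(it, size)):
        yield batch
```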