[fix] lower on query, [add] metadata response, [add] context distance & reference links
parent 05c5546c10
commit 40daf46c03
@@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
To normalize data, you can use Minyma's `normalize` CLI command:

```bash
minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
minyma normalize \
    --filename ./pubmed_manuscripts.jsonl \
    --normalizer pubmed \
    --database chroma \
    --datapath ./chroma
```

The above example does the following:

@@ -1,8 +1,8 @@
from os import path
import click
import signal
import sys
from importlib.metadata import version
from minyma.config import Config
from minyma.oai import OpenAIConnector
from minyma.vdb import ChromaDB
from flask import Flask
@@ -15,14 +15,15 @@ def signal_handler(sig, frame):


def create_app():
    global oai, cdb
    global oai, vdb

    from minyma.config import Config
    import minyma.api.common as api_common
    import minyma.api.v1 as api_v1

    app = Flask(__name__)
    cdb = ChromaDB(Config.DATA_PATH)
    oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
    vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
    oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)

    app.register_blueprint(api_common.bp)
    app.register_blueprint(api_v1.bp)

@@ -17,8 +17,28 @@ def get_response():
    if message == "":
        return {"error": "Empty Message"}

    oai_response = minyma.oai.query(message)
    return oai_response
    resp = minyma.oai.query(message)

    # Derive LLM Data
    llm_resp = resp.get("llm", {})
    llm_choices = llm_resp.get("choices", [])

    # Derive VDB Data
    vdb_resp = resp.get("vdb", {})
    combined_context = [{
            "id": vdb_resp.get("ids")[i],
            "distance": vdb_resp.get("distances")[i],
            "doc": vdb_resp.get("docs")[i],
            "metadata": vdb_resp.get("metadatas")[i],
    } for i, _ in enumerate(vdb_resp.get("docs", []))]

    # Return Data
    return {
        "response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
        "context": combined_context,
        "usage": llm_resp.get("usage"),
    }



"""
@@ -34,5 +54,5 @@ def get_related():
    if message == "":
        return {"error": "Empty Message"}

    related_documents = minyma.cdb.get_related(message)
    related_documents = minyma.vdb.get_related(message)
    return related_documents

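For reference, a response from the updated `get_response` handler now has roughly the shape below. This is an illustrative sketch assembled from the fields in the hunk above; the document text, distance, file path, and usage numbers are made-up values, not output from the project.

```python
# Hypothetical example of the reshaped get_response() payload (all values illustrative)
{
    "response": "Tocilizumab is an IL-6 receptor antagonist used to ...",  # first LLM choice's message content
    "context": [
        {
            "id": "42",                                   # document ID assigned by the normalizer (line number)
            "distance": 0.31,                             # vector distance reported by the VDB
            "doc": "... matched manuscript text ...",     # related document returned by get_related()
            "metadata": {"file": "data/PMC1234567.txt"},  # per-document metadata (source file)
        }
    ],
    "usage": {"prompt_tokens": 512, "completion_tokens": 87, "total_tokens": 599},
}
```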
@@ -1,12 +1,16 @@
from io import TextIOWrapper
import json

class DataNormalizer:
class DataNormalizer():
    def __init__(self, file: TextIOWrapper):
        pass
        self.file = file

    def __len__(self) -> int:
        return 0

    def __iter__(self):
        pass
        yield None


class PubMedNormalizer(DataNormalizer):
    """
@@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
    normalized inside the iterator.
    """
    def __init__(self, file: TextIOWrapper):
         self.file = file
        self.file = file
        self.length = 0

    def __len__(self):
        last_pos = self.file.tell()
        self.length = sum(1 for _ in self.file)
        self.file.seek(last_pos)
        return self.length

    def __iter__(self):
        count = 0
@@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
            count += 1

            # ID = Line Number
            yield { "doc": norm_text, "id": str(count - 1) }
            yield {
               "id": str(count - 1),
               "doc": norm_text,
               "metadata": {
                   "file": l.get("file")
               },
           }

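Each item yielded by the updated `PubMedNormalizer.__iter__` now bundles the source file as metadata. A sketch of one yielded record, with made-up values, under the assumption that each raw JSONL line carries a `file` field as the diff above suggests:

```python
# Hypothetical record yielded by PubMedNormalizer (values illustrative)
{
    "id": "0",                                    # zero-based line number used as the document ID
    "doc": "normalized manuscript text ...",      # normalized text stored in the vector DB
    "metadata": {"file": "data/PMC1234567.txt"},  # "file" key copied from the raw JSONL record
}
```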
@@ -5,8 +5,8 @@ from minyma.vdb import VectorDB

# Stolen LangChain Prompt
PROMPT_TEMPLATE = """
Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you don't know, don't try to 
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to
make up an answer.

{context}
@@ -19,6 +19,7 @@ class OpenAIConnector:
    def __init__(self, api_key: str, vdb: VectorDB):
        self.vdb = vdb
        self.model = "gpt-3.5-turbo"
        self.word_cap = 1000
        openai.api_key = api_key

    def query(self, question: str) -> Any:
@@ -30,8 +31,9 @@ class OpenAIConnector:
        if len(all_docs) == 0:
            return { "error": "No Context Found" }

        # Join on new line, generate main prompt
        context = '\n'.join(all_docs)
        # Join on new line (cap @ word limit), generate main prompt
        reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
        context = '\n'.join(reduced_docs)
        prompt = PROMPT_TEMPLATE.format(context = context, question = question)

        # Query OpenAI ChatCompletion
@@ -41,4 +43,4 @@ class OpenAIConnector:
        )

        # Return Response
        return response
        return { "llm": response, "vdb": related }

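The new `word_cap` keeps each related document to its first 1000 whitespace-separated tokens before the documents are joined into the prompt context. A minimal illustration of that reduction (the sample documents are made up):

```python
# Minimal sketch of the per-document word cap applied in OpenAIConnector.query()
word_cap = 1000                                  # matches self.word_cap in the diff above
all_docs = ["token " * 1500, "short document"]   # illustrative related documents
reduced_docs = [" ".join(doc.split()[:word_cap]) for doc in all_docs]
context = "\n".join(reduced_docs)                # capped docs joined into the prompt context
```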
@@ -72,41 +72,41 @@
    </main>
    <script>
      const LOADING_SVG = `<svg
	width="24"
	height="24"
	viewBox="0 0 24 24"
	xmlns="http://www.w3.org/2000/svg"
	fill="currentColor"
      >
	<style>
	  .spinner_qM83 {
	    animation: spinner_8HQG 1.05s infinite;
	  }
	  .spinner_oXPr {
	    animation-delay: 0.1s;
	  }
	  .spinner_ZTLf {
	    animation-delay: 0.2s;
	  }
	  @keyframes spinner_8HQG {
	    0%,
	    57.14% {
	      animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
	      transform: translate(0);
	    }
	    28.57% {
	      animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
	      transform: translateY(-6px);
	    }
	    100% {
	      transform: translate(0);
	    }
	  }
	</style>
	<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
	<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
	<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
      </svg>`;
       width="24"
       height="24"
       viewBox="0 0 24 24"
       xmlns="http://www.w3.org/2000/svg"
       fill="currentColor"
            >
       <style>
         .spinner_qM83 {
           animation: spinner_8HQG 1.05s infinite;
         }
         .spinner_oXPr {
           animation-delay: 0.1s;
         }
         .spinner_ZTLf {
           animation-delay: 0.2s;
         }
         @keyframes spinner_8HQG {
           0%,
           57.14% {
             animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
             transform: translate(0);
           }
           28.57% {
             animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
             transform: translateY(-6px);
           }
           100% {
             transform: translate(0);
           }
         }
       </style>
       <circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
       <circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
       <circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
            </svg>`;

      /**
       * Wrapper API Call
@@ -125,9 +125,9 @@
        // Wrapping Element
        let wrapEl = document.createElement("div");
        wrapEl.innerHTML = `<div class="flex">
	   <span class="font-bold w-24 grow-0 shrink-0"></span>
	   <span class="whitespace-break-spaces w-full"></span>
	 </div>`;
          <span class="font-bold w-24 grow-0 shrink-0"></span>
          <span class="whitespace-break-spaces w-full"></span>
        </div>`;

        // Get Elements
        let nameEl = wrapEl.querySelector("span");
@@ -158,7 +158,63 @@
        })
          .then((data) => {
            console.log("SUCCESS:", data);
            content.innerText = data.choices[0].message.content;

            // Create Response Element
            let responseEl = document.createElement("p");
            responseEl.setAttribute(
              "class",
              "whitespace-break-spaces border-b pb-3 mb-3"
            );
            responseEl.innerText = data.response;

            // Create Context Element
            let contextEl = document.createElement("div");
            contextEl.innerHTML = `
	      <h1 class="font-bold">Context:</h1>
	      <ul class="list-disc ml-6"></ul>`;
            let ulEl = contextEl.querySelector("ul");

            // Create Context Links
            data.context

              // Capture PubMed ID & Distance
              .map((item) => [
                item.metadata.file.match("\/(.*)\.txt$"),
                item.distance,
              ])

              // Filter Non-Matches
              .filter(([match]) => match)

              // Get Match Value & Round Distance (2)
              .map(([match, distance]) => [
                match[1],
                Math.round(distance * 100) / 100,
              ])

              // Create Links
              .forEach(([pmid, distance]) => {
                let newEl = document.createElement("li");
                let linkEl = document.createElement("a");
                linkEl.setAttribute("target", "_blank");
                linkEl.setAttribute(
                  "class",
                  "text-blue-500 hover:text-blue-600"
                );
                linkEl.setAttribute(
                  "href",
                  "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
                );
                linkEl.textContent = "[" + distance + "] " + pmid;
                newEl.append(linkEl);
                ulEl.append(newEl);
              });

            // Add to DOM
            content.setAttribute("class", "w-full");
            content.innerHTML = "";
            content.append(responseEl);
            content.append(contextEl);
          })
          .catch((e) => {
            console.log("ERROR:", e);

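On the front end, the PubMed ID is pulled out of each context item's `metadata.file` and turned into a PMC link labeled with the rounded distance. The same extraction, sketched in Python for clarity (the file path and distance are made-up values):

```python
# Python sketch of the front-end PMID extraction and link construction (illustrative values)
import re

file_path, distance = "data/PMC1234567.txt", 0.3141
match = re.search(r"/(.*)\.txt$", file_path)  # same pattern as metadata.file.match("\/(.*)\.txt$")
if match:
    pmid = match.group(1)                     # -> "PMC1234567"
    href = "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
    label = f"[{round(distance, 2)}] {pmid}"  # distance rounded to 2 decimals for display
```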
@@ -1,7 +1,6 @@
from chromadb.api import API
from itertools import islice
from os import path
from tqdm.auto import tqdm
from tqdm import tqdm
from typing import Any, cast
import chromadb

@@ -29,47 +28,45 @@ class VectorDB:
ChromaDV VectorDB Type
"""
class ChromaDB(VectorDB):
    def __init__(self, base_path: str):
        chroma_path = path.join(base_path, "chroma")
        self.client: API = chromadb.PersistentClient(path=chroma_path)
        self.word_limit = 1000
    def __init__(self, path: str):
        self.client: API = chromadb.PersistentClient(path=path)
        self.word_cap = 2500
        self.collection_name: str = "vdb"
        self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)

    def get_related(self, question) -> Any:
    def get_related(self, question: str) -> Any:
        """Returns line separated related docs"""
        results = self.collection.query(
            query_texts=[question],
            query_texts=[question.lower()],
            n_results=2
        )

        all_docs: list = cast(list, results.get("documents", [[]]))[0]
        all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
        all_distances: list = cast(list, results.get("distances", [[]]))[0]
        all_ids: list = cast(list, results.get("ids", [[]]))[0]

        return {
            "distances":all_distances, 
            "distances": all_distances,
            "metadatas": all_metadata,
            "docs": all_docs,
            "ids": all_ids
        }

    def load_documents(self, normalizer: DataNormalizer):
        # 10 Item Chunking
        for items in tqdm(chunk(normalizer, 50)):
    def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
        length = len(normalizer) / chunk_size
        for items in tqdm(chunk(normalizer, chunk_size), total=length):
            ids = []
            documents = []
            metadatas = []

            # Limit words per document to accommodate context token limits
            for item in items:
                doc = " ".join(item.get("doc").split()[:self.word_limit])
                documents.append(doc)
                documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
                ids.append(item.get("id"))
                metadatas.append(item.get("metadata", {}))

            # Ideally we parse out metadata from each document
            # and pass to the metadata kwarg. However, each
            # document appears to have a slightly different format,
            # so it's difficult to parse out.
            self.collection.add(
                ids=ids,
                documents=documents,
                ids=ids
                metadatas=metadatas,
            )

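`load_documents` relies on a `chunk` helper that isn't part of this diff; given the `islice` import at the top of the module, it is presumably something like the batching generator below (an assumption, not the project's actual code):

```python
# Assumed shape of the chunk() helper used by load_documents (not shown in this diff)
from itertools import islice

def chunk(iterable, size):
    it = iter(iterable)
    while True:
        batch = list(islice(it, size))  # take up to `size` items per batch
        if not batch:
            return
        yield batch
```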