Comment développer une interface utilisateur à jetons pour votre LLM avec Go, FastAPI et JS

Les modèles génératifs mettent parfois un certain temps à renvoyer un résultat, il est donc intéressant d'exploiter le token streaming afin de voir le résultat apparaître à la volée dans l'interface utilisateur. Voici comment vous pouvez réaliser un tel frontend pour votre LLM avec Go, FastAPI et Javascript.

Développeur sur PC

Qu'est-ce que le streaming de jetons ?

Pour rappel, un jeton est une entité unique qui peut être un petit mot, une partie de mot ou une ponctuation. En moyenne, 1 jeton est composé de 4 caractères, et 100 jetons équivalent à peu près à 75 mots. Les modèles de traitement du langage naturel doivent transformer votre texte en jetons pour pouvoir le traiter.

Lors de l'utilisation d'un modèle d'IA de génération de texte (également connu sous le nom de modèle "génératif"), le temps de réponse peut être assez élevé, en fonction de votre matériel et de la taille de votre modèle. Par exemple, dans le cas d'un grand modèle de langage (également connu sous le nom de "LLM") comme LLaMA 30B, déployé sur un GPU NVIDIA A100 en fp16, le modèle génère 100 tokens en 3 secondes environ. Ainsi, si vous attendez de votre modèle génératif qu'il génère un texte volumineux composé de centaines ou de milliers de mots, la latence sera élevée et vous devrez peut-être attendre plus de 10 secondes pour obtenir le texte complet. plus de 10 secondes pour obtenir la réponse complète.

Attendre si longtemps pour obtenir une réponse peut être un problème du point de vue de l'expérience utilisateur. La solution dans ce cas est le streaming de jetons !

Le streaming de jetons consiste à générer chaque nouveau jeton à la volée au lieu d'attendre que toute la réponse soit prête. C'est ce que vous sur l'application ChatGPT ou sur l'assistant ChatDolphin de NLP Cloud, par exemple. Les mots apparaissent dès qu'ils sont générés par le modèle. Essayez l'assistant ChatDolphin AI ici.

Streaming de jetons avec ChatDolphin sur NLP Cloud Streaming de jetons avec l'assistant ChatDolphin sur NLP Cloud. Essayez-le ici.

Sélection d'un moteur d'inférence prenant en charge le flux de jetons

La première étape consiste à tirer parti d'un moteur d'inférence qui prend en charge le flux de jetons.

Voici quelques options que vous pourriez envisager :

Voici un exemple utilisant la méthode HuggingFace generate() :

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text)

Dans cet exemple, nous générons une sortie avec le modèle GPT-2 et nous imprimons chaque jeton dans la console dès qu'il arrive.

Streaming de la réponse avec FastAPI

Maintenant que vous avez choisi un moteur d'inférence, vous devez servir votre modèle et renvoyer les jetons en flux.

Votre modèle sera très probablement exécuté dans un environnement Python. Vous aurez donc besoin d'un serveur Python pour renvoyer les jetons et les rendre disponibles via une API HTTP. FastAPI est devenu un choix de facto pour de telles situations.

Nous utilisons ici Uvicorn et la StreamingResponse de FastAPI afin de servir chaque jeton dès qu'il est généré. Voici un exemple :

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

app = FastAPI()

async def generate():
    inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text

@app.get("/")
async def main():
    return StreamingResponse(generate())

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

Vous pouvez tester votre serveur de streaming grâce à la commande cURL suivante :

curl -N localhost:8000

Nous avons maintenant un modèle d'IA qui fonctionne et qui renvoie correctement les jetons.

Nous pourrions lire directement ces jetons à partir d'une application client dans un navigateur. Mais nous n'allons pas le faire, pour deux raisons.

Tout d'abord, il est important de découpler le modèle d'IA du reste de la pile car nous ne voulons pas redémarrer le modèle à chaque fois que chaque fois que nous allons apporter une petite modification à l'API. Gardez à l'esprit que les modèles d'IA génératifs modernes sont très lourds et prennent souvent plusieurs minutes à redémarrer.

La deuxième raison est que Python n'est pas nécessairement le meilleur choix lorsqu'il s'agit de construire une application concurrente à haut débit comme nous allons le faire. Ce choix peut être discuté, bien sûr, et c'est peut-être aussi une question de goût !

Transférer des jetons à travers une passerelle Go

Comme mentionné ci-dessus, il est important d'ajouter une passerelle entre votre modèle et votre client final, et Go est un bon langage de programmation pour une telle application. En production, vous voudrez peut-être aussi ajouter un proxy inverse entre la passerelle Go et le client final, et un équilibreur de charge entre la passerelle Go et le client final. proxy inverse entre la passerelle Go et le client final, et un équilibreur de charge entre votre passerelle Go et votre modèle d'IA afin de répartir la charge sur plusieurs répliques. de répartir la charge sur plusieurs répliques de votre modèle. Mais cela sort du cadre de notre article !

Notre application Go sera également chargée de rendre la page HTML finale.

Cette application fait une demande à l'application FastAPI, reçoit les jetons de FastAPI et transmet chaque jeton au frontend à l'aide d'événements envoyés par le serveur (SSE). au frontend à l'aide de Server Sent Events (SSE). SSE est plus simple que les websockets car il est unidirectionnel. Le SSE est plus simple que les websockets car il est unidirectionnel. C'est un bon choix lorsque vous voulez construire une application qui envoie des informations à un client, sans écouter une réponse potentielle du client. réponse potentielle du client.

Voici le code Go (le modèle HTML/JS/CSS sera présenté dans la section suivante) :

package main

import (
    "bufio"
    "fmt"
    "html/template"
    "io"
    "log"
    "net/http"
    "strings"
)

var (
    templates      *template.Template
    streamedTextCh chan string
)

func init() {
    // Parse all templates in the templates folder.
    templates = template.Must(template.ParseGlob("templates/*.html"))

    streamedTextCh = make(chan string)
}

// generateText calls FastAPI and returns every token received on the fly through
// a dedicated channel (streamedTextCh).
// If the EOF character is received from FastAPI, it means that text generation is over.
func generateText(streamedTextCh chan<- string) {
    var buf io.Reader = nil

    req, err := http.NewRequest("GET", "http://127.0.0.1:8000", buf)
    if err != nil {
        log.Fatal(err)
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    reader := bufio.NewReader(resp.Body)

outerloop:
    for {
        chunk, err := reader.ReadBytes('\x00')
        if err != nil {
            if err == io.EOF {
                break outerloop
            }
            log.Println(err)
            break outerloop
        }

        output := string(chunk)

        streamedTextCh <- output
    }
}

// formatServerSentEvent creates a proper SSE compatible body.
// Server sent events need to follow a specific formatting that
// uses "event:" and "data:" prefixes.
func formatServerSentEvent(event, data string) (string, error) {
    sb := strings.Builder{}

    _, err := sb.WriteString(fmt.Sprintf("event: %s\n", event))
    if err != nil {
        return "", err
    }
    _, err = sb.WriteString(fmt.Sprintf("data: %v\n\n", data))
    if err != nil {
        return "", err
    }

    return sb.String(), nil
}

// generate is an infinite loop that waits for new tokens received 
// from the streamedTextCh. Once a new token is received,
// it is automatically pushed to the frontend as a server sent event. 
func generate(w http.ResponseWriter, r *http.Request) {
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "SSE not supported", http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "text/event-stream")

    for text := range streamedTextCh {
        event, err := formatServerSentEvent("streamed-text", text)
        if err != nil {
            http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
            return
        }

        _, err = fmt.Fprint(w, event)
        if err != nil {
            http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
            return
        }

        flusher.Flush()
    }
}

// start starts an asynchronous request to the AI engine.
func start(w http.ResponseWriter, r *http.Request) {
    go generateText(streamedTextCh)
}

func home(w http.ResponseWriter, r *http.Request) {
    if err := templates.ExecuteTemplate(w, "home.html", nil); err != nil {
        log.Println(err.Error())
        http.Error(w, "", http.StatusInternalServerError)
        return
    }
}

func main() {
    http.HandleFunc("/generate", generate)
    http.HandleFunc("/start", start).Methods("POST")
    http.HandleFunc("/", home).Methods("GET")

    log.Fatal(http.ListenAndServe(":8000", r))
}                

Notre page "/home" rend la page HTML/CSS/JS (présentée plus loin). La page "/start" reçoit une requête POST de l'application JS qui déclenche une requête à notre modèle d'IA. JS qui déclenche une requête à notre modèle d'IA. Et notre page "/generate" renvoie le résultat à l'application JS par le biais d'événements envoyés par le serveur.

Une fois que la fonction start() reçoit une requête POST du frontend, elle crée automatiquement une goroutine qui va faire une requête à notre application FastAPI. à notre application FastAPI.

La fonction generateText() appelle FastAPI et renvoie chaque jeton reçu à la volée via un canal dédié (streamedTextCh). Si le caractère EOF est reçu de FastAPI, cela signifie que la génération de texte est terminée.

La fonction generate() est une boucle infinie qui attend les nouveaux jetons reçus du canal streamedTextCh. Une fois qu'un nouveau jeton est reçu, il est automatiquement poussé vers le frontend en tant qu'événement envoyé par le serveur, il est automatiquement poussé vers le frontend en tant qu'événement envoyé par le serveur. Les événements envoyés par le serveur doivent suivre un formatage spécifique qui utilise les préfixes "event :" et "data :". d'où la fonction formatServerSentEvent().

Pour que l'ESS soit complète, nous avons besoin d'un client Javascript capable d'écouter les événements envoyés par le serveur en s'abonnant à la page "generate". Voir la section suivante pour comprendre comment y parvenir.

Réception de jetons avec Javascript dans le navigateur

Vous devez maintenant créer un répertoire "templates" et y ajouter un fichier "home.html".

Voici le contenu de "home.html" :

<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Our Streamed Tokens App</title>
</head>
<body>
    <div id="response-section"></div>    
    <form method="POST">
        <button onclick="start()">Start</button>
    </form>
</body>
<script>
    // Disable the default behavior of the HTML form.
    document.querySelector('form').addEventListener('submit', function(e) {
        e.preventDefault()
    })

    // Make a request to the /start to trigger the request to the AI model.
    async function start() {
        try {
            const response = await fetch("/start", {
            method: "POST",
            })
        } catch (error) {
            console.error("Error when starting process:", error)
        }
    }

    // Listen to SSE by subscribing to the /generate page, and
    // put the result in the #response-section div.
    const evtSource = new EventSource("generate")
    evtSource.addEventListener("streamed-text", (event) => {
        document.getElementById('response-section').innerHTML = event.data
    })
</script>
</html>

Comme vous pouvez le constater, l'écoute de l'ESS dans le navigateur est assez simple.

Tout d'abord, vous devez vous abonner à notre point de terminaison SSE (la page "/generate"). Ensuite, vous devez ajouter un écouteur d'événements qui lira les jetons dès qu'ils seront reçus. qu'ils sont reçus.

Les navigateurs modernes essaient automatiquement de se reconnecter la source de l'événement en cas de problèmes de connexion.

Conclusion

Vous savez maintenant comment créer une application moderne d'IA générative qui diffuse dynamiquement du texte dans le navigateur, à la manière de ChatGPT ! dynamiquement du texte dans le navigateur, à la ChatGPT !

Comme vous l'avez remarqué, une telle application n'est pas nécessairement simple car plusieurs couches sont impliquées. couches sont impliquées. Et bien sûr, le code ci-dessus est simplifié à l'extrême pour les besoins de l'exemple. l'exemple.

La principale difficulté liée à la diffusion en continu de jetons concerne la gestion des défaillances du réseau. La plupart de ces ces défaillances se produiront entre le backend Go et le frontend Javascript. Vous devrez devrez explorer des stratégies de reconnexion plus avancées et vous assurer que les erreurs sont correctement signalées à l'interface utilisateur.

J'espère que ce tutoriel vous a été utile !

Vincent
Developer Advocate chez NLP Cloud